fea: support age timer for knock

This commit is contained in:
Daniel Ding
2024-01-02 15:53:19 +08:00
parent 1af91f2f65
commit 77fa149380
8 changed files with 156 additions and 56 deletions

View File

@@ -44,6 +44,6 @@ func Commands(app *api.App) {
Policy{}.Commands(app) Policy{}.Commands(app)
Version{}.Commands(app) Version{}.Commands(app)
Log{}.Commands(app) Log{}.Commands(app)
ZGuest{}.Commands(app) Guest{}.Commands(app)
Knock{}.Commands(app) Knock{}.Commands(app)
} }

View File

@@ -9,11 +9,11 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
type ZGuest struct { type Guest struct {
Cmd Cmd
} }
func (u ZGuest) Url(prefix, name string) string { func (u Guest) Url(prefix, name string) string {
name, network := api.SplitName(name) name, network := api.SplitName(name)
if name == "" { if name == "" {
return prefix + "/api/network/" + network + "/guest" return prefix + "/api/network/" + network + "/guest"
@@ -21,7 +21,7 @@ func (u ZGuest) Url(prefix, name string) string {
return prefix + "/api/network/" + network + "/guest/" + name return prefix + "/api/network/" + network + "/guest/" + name
} }
func (u ZGuest) Add(c *cli.Context) error { func (u Guest) Add(c *cli.Context) error {
username := c.String("name") username := c.String("name")
if !strings.Contains(username, "@") { if !strings.Contains(username, "@") {
return libol.NewErr("invalid username") return libol.NewErr("invalid username")
@@ -39,7 +39,7 @@ func (u ZGuest) Add(c *cli.Context) error {
return nil return nil
} }
func (u ZGuest) Remove(c *cli.Context) error { func (u Guest) Remove(c *cli.Context) error {
username := c.String("name") username := c.String("name")
if !strings.Contains(username, "@") { if !strings.Contains(username, "@") {
return libol.NewErr("invalid username") return libol.NewErr("invalid username")
@@ -57,7 +57,7 @@ func (u ZGuest) Remove(c *cli.Context) error {
return nil return nil
} }
func (u ZGuest) Tmpl() string { func (u Guest) Tmpl() string {
return `# total {{ len . }} return `# total {{ len . }}
{{ps -24 "username"}} {{ps -24 "address"}} {{ps -24 "username"}} {{ps -24 "address"}}
{{- range . }} {{- range . }}
@@ -66,7 +66,7 @@ func (u ZGuest) Tmpl() string {
` `
} }
func (u ZGuest) List(c *cli.Context) error { func (u Guest) List(c *cli.Context) error {
network := c.String("network") network := c.String("network")
url := u.Url(c.String("url"), "@"+network) url := u.Url(c.String("url"), "@"+network)
@@ -80,11 +80,11 @@ func (u ZGuest) List(c *cli.Context) error {
return u.Out(items, c.String("format"), u.Tmpl()) return u.Out(items, c.String("format"), u.Tmpl())
} }
func (u ZGuest) Commands(app *api.App) { func (u Guest) Commands(app *api.App) {
app.Command(&cli.Command{ app.Command(&cli.Command{
Name: "zguest", Name: "guest",
Aliases: []string{"zg"}, Aliases: []string{"gu"},
Usage: "zGuest configuration", Usage: "ZTrust Guest configuration",
Subcommands: []*cli.Command{ Subcommands: []*cli.Command{
{ {
Name: "add", Name: "add",

View File

@@ -26,6 +26,7 @@ func (u Knock) Add(c *cli.Context) error {
socket := c.String("socket") socket := c.String("socket")
knock := &schema.KnockRule{ knock := &schema.KnockRule{
Protocol: c.String("protocol"), Protocol: c.String("protocol"),
Age: c.Int("age"),
} }
knock.Name, knock.Network = api.SplitName(username) knock.Name, knock.Network = api.SplitName(username)
knock.Dest, knock.Port = api.SplitSocket(socket) knock.Dest, knock.Port = api.SplitSocket(socket)
@@ -60,9 +61,9 @@ func (u Knock) Remove(c *cli.Context) error {
func (u Knock) Tmpl() string { func (u Knock) Tmpl() string {
return `# total {{ len . }} return `# total {{ len . }}
{{ps -24 "username"}} {{ps -8 "protocol"}} {{ps -24 "socket"}} {{ps -24 "createAt"}} {{ps -24 "username"}} {{ps -8 "protocol"}} {{ps -24 "socket"}} {{ps -4 "age"}} {{ps -24 "createAt"}}
{{- range . }} {{- range . }}
{{p2 -24 "%s@%s" .Name .Network}} {{ps -8 .Protocol}} {{p2 -24 "%s:%s" .Dest .Port}} {{ut .CreateAt}} {{p2 -24 "%s@%s" .Name .Network}} {{ps -8 .Protocol}} {{p2 -24 "%s:%s" .Dest .Port}} {{pi -4 .Age}} {{ut .CreateAt}}
{{- end }} {{- end }}
` `
} }
@@ -94,6 +95,7 @@ func (u Knock) Commands(app *api.App) {
&cli.StringFlag{Name: "name"}, &cli.StringFlag{Name: "name"},
&cli.StringFlag{Name: "protocol"}, &cli.StringFlag{Name: "protocol"},
&cli.StringFlag{Name: "socket"}, &cli.StringFlag{Name: "socket"},
&cli.IntFlag{Name: "age", Value: 60},
}, },
Action: u.Add, Action: u.Add,
}, },

View File

@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/luscis/openlan/pkg/cache"
"github.com/luscis/openlan/pkg/libol" "github.com/luscis/openlan/pkg/libol"
"github.com/luscis/openlan/pkg/schema" "github.com/luscis/openlan/pkg/schema"
) )
@@ -77,6 +78,17 @@ func (h ZTrust) AddGuest(w http.ResponseWriter, r *http.Request) {
guest.Name = vars["user"] guest.Name = vars["user"]
libol.Info("ZTrust.AddGuest %s@%s", guest.Name, id) libol.Info("ZTrust.AddGuest %s@%s", guest.Name, id)
if guest.Address == "" {
client := cache.VPNClient.Get(id, guest.Name)
if client != nil {
guest.Address = client.Address
guest.Device = client.Device
}
}
if guest.Address == "" {
http.Error(w, "invalid address", http.StatusBadRequest)
return
}
if err := ztrust.AddGuest(guest.Name, guest.Address); err == nil { if err := ztrust.AddGuest(guest.Name, guest.Address); err == nil {
ResponseJson(w, "success") ResponseJson(w, "success")
@@ -164,7 +176,7 @@ func (h ZTrust) AddKnock(w http.ResponseWriter, r *http.Request) {
name := vars["user"] name := vars["user"]
libol.Info("ZTrust.AddKnock %s@%s", rule.Name, id) libol.Info("ZTrust.AddKnock %s@%s", rule.Name, id)
if err := ztrust.Knock(name, rule.Protocol, rule.Dest, rule.Port, 0); err == nil { if err := ztrust.Knock(name, rule.Protocol, rule.Dest, rule.Port, rule.Age); err == nil {
ResponseJson(w, "success") ResponseJson(w, "success")
} else { } else {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

11
pkg/cache/openvpn.go vendored
View File

@@ -157,6 +157,17 @@ func (o *vpnClient) List(name string) <-chan *schema.VPNClient {
return c return c
} }
func (o *vpnClient) Get(name, user string) *schema.VPNClient {
username := user + "@" + name
clients := o.readStatus(name)
for _, client := range clients {
if client.Name == username {
return client
}
}
return nil
}
func (o *vpnClient) clientFile(name string) string { func (o *vpnClient) clientFile(name string) string {
files, _ := filepath.Glob(o.Dir(name, "*client.ovpn")) files, _ := filepath.Glob(o.Dir(name, "*client.ovpn"))
if len(files) > 0 { if len(files) > 0 {

View File

@@ -225,9 +225,15 @@ func (w *WorkerImpl) Start(v api.Switcher) {
if !(w.vpn == nil || w.ztrust == nil) { if !(w.vpn == nil || w.ztrust == nil) {
w.ztrust.Start() w.ztrust.Start()
fire.Mangle.Pre.AddRule(cn.IpRule{
Input: vpn.Device,
CtState: "RELATED,ESTABLISHED",
Comment: "Forwarding Accpted",
})
fire.Mangle.Pre.AddRule(cn.IpRule{ fire.Mangle.Pre.AddRule(cn.IpRule{
Input: vpn.Device, Input: vpn.Device,
Jump: w.ztrust.Chain(), Jump: w.ztrust.Chain(),
Comment: "Goto Zero Trust",
}) })
} }
fire.Start() fire.Start()

View File

@@ -66,7 +66,7 @@ push "{{ . }}"
ifconfig-pool-persist {{ .Protocol }}{{ .Port }}ipp ifconfig-pool-persist {{ .Protocol }}{{ .Port }}ipp
tls-auth {{ .TlsAuth }} 0 tls-auth {{ .TlsAuth }} 0
cipher {{ .Cipher }} cipher {{ .Cipher }}
status {{ .Protocol }}{{ .Port }}server.status 5 status {{ .Protocol }}{{ .Port }}server.status 2
{{- if .CertNot }} {{- if .CertNot }}
client-cert-not-required client-cert-not-required
{{- else }} {{- else }}
@@ -98,7 +98,7 @@ push "route {{ . }}"
ifconfig-pool-persist {{ .Protocol }}{{ .Port }}ipp ifconfig-pool-persist {{ .Protocol }}{{ .Port }}ipp
tls-auth {{ .TlsAuth }} 0 tls-auth {{ .TlsAuth }} 0
cipher {{ .Cipher }} cipher {{ .Cipher }}
status {{ .Protocol }}{{ .Port }}server.status 5 status {{ .Protocol }}{{ .Port }}server.status 2
client-config-dir {{ .ClientConfigDir }} client-config-dir {{ .ClientConfigDir }}
verb 3 verb 3
` `
@@ -313,7 +313,7 @@ func (o *OpenVPN) Clean() {
} }
} }
} }
files := []string{o.FileStats(true), o.FileIpp(true)} files := []string{o.FileStats(true), o.FileIpp(true), o.FileClient(true)}
for _, file := range files { for _, file := range files {
if err := libol.FileExist(file); err == nil { if err := libol.FileExist(file); err == nil {
if err := os.Remove(file); err != nil { if err := os.Remove(file); err != nil {

View File

@@ -2,6 +2,7 @@ package cswitch
import ( import (
"fmt" "fmt"
"sync"
"time" "time"
"github.com/luscis/openlan/pkg/libol" "github.com/luscis/openlan/pkg/libol"
@@ -11,8 +12,8 @@ import (
) )
type KnockRule struct { type KnockRule struct {
createAt int64 createAt time.Time
age int age int64
protocol string protocol string
destination string destination string
port string port string
@@ -23,6 +24,27 @@ func (r *KnockRule) Id() string {
return fmt.Sprintf("%s:%s:%s", r.protocol, r.destination, r.port) return fmt.Sprintf("%s:%s:%s", r.protocol, r.destination, r.port)
} }
func (r *KnockRule) Expire() bool {
now := time.Now()
if r.createAt.Unix()+int64(r.age) < now.Unix() {
return true
}
return false
}
func (r *KnockRule) Rule() cn.IpRule {
if r.rule == nil {
r.rule = &cn.IpRule{
Dest: r.destination,
DstPort: r.port,
Proto: r.protocol,
Jump: "ACCEPT",
Comment: "Knock at " + r.createAt.UTC().String(),
}
}
return *r.rule
}
type ZGuest struct { type ZGuest struct {
network string network string
username string username string
@@ -30,6 +52,7 @@ type ZGuest struct {
rules map[string]*KnockRule rules map[string]*KnockRule
chain *cn.FireWallChain chain *cn.FireWallChain
out *libol.SubLogger out *libol.SubLogger
lock sync.Mutex
} }
func NewZGuest(network, name string) *ZGuest { func NewZGuest(network, name string) *ZGuest {
@@ -55,55 +78,87 @@ func (g *ZGuest) Start() {
} }
func (g *ZGuest) AddSource(source string) { func (g *ZGuest) AddSource(source string) {
g.lock.Lock()
defer g.lock.Unlock()
g.sources[source] = source g.sources[source] = source
} }
func (g *ZGuest) HasSource(source string) bool {
g.lock.Lock()
defer g.lock.Unlock()
if _, ok := g.sources[source]; ok {
return true
}
return false
}
func (g *ZGuest) DelSource(source string) { func (g *ZGuest) DelSource(source string) {
g.lock.Lock()
defer g.lock.Unlock()
if _, ok := g.sources[source]; ok { if _, ok := g.sources[source]; ok {
delete(g.sources, source) delete(g.sources, source)
} }
} }
func (g *ZGuest) AddRuleX(rule cn.IpRule) { func (g *ZGuest) addRuleX(rule cn.IpRule) {
if err := g.chain.AddRuleX(rule); err != nil { if err := g.chain.AddRuleX(rule); err != nil {
g.out.Warn("ZTrust.AddRuleX: %s", err) g.out.Warn("ZTrust.AddRuleX: %s", err)
} }
} }
func (g *ZGuest) DelRuleX(rule cn.IpRule) { func (g *ZGuest) delRuleX(rule cn.IpRule) {
if err := g.chain.DelRuleX(rule); err != nil { if err := g.chain.DelRuleX(rule); err != nil {
g.out.Warn("ZTrust.DelRuleX: %s", err) g.out.Warn("ZTrust.DelRuleX: %s", err)
} }
} }
func (g *ZGuest) AddRule(rule *KnockRule) { func (g *ZGuest) AddRule(rule *KnockRule) {
g.lock.Lock()
defer g.lock.Unlock()
if dst, ok := g.rules[rule.Id()]; !ok {
g.addRuleX(rule.Rule())
g.rules[rule.Id()] = rule g.rules[rule.Id()] = rule
g.AddRuleX(cn.IpRule{ } else {
Dest: rule.destination, dst.age = rule.age
DstPort: rule.port, dst.createAt = rule.createAt
Proto: rule.protocol, }
Jump: "ACCEPT",
Comment: "Knock at " + time.Now().UTC().String(),
})
} }
func (g *ZGuest) DelRule(rule *KnockRule) { func (g *ZGuest) DelRule(rule *KnockRule) {
g.lock.Lock()
defer g.lock.Unlock()
if _, ok := g.rules[rule.Id()]; ok { if _, ok := g.rules[rule.Id()]; ok {
g.delRuleX(rule.Rule())
delete(g.rules, rule.Id()) delete(g.rules, rule.Id())
} }
g.DelRuleX(cn.IpRule{
Proto: rule.protocol,
Dest: rule.destination,
DstPort: rule.port,
Jump: "ACCEPT",
Comment: "Knock at " + time.Now().Local().String(),
})
} }
func (g *ZGuest) Stop() { func (g *ZGuest) Stop() {
g.chain.Cancel() g.chain.Cancel()
} }
func (g *ZGuest) Clear() {
g.lock.Lock()
defer g.lock.Unlock()
removed := make([]*KnockRule, 0, 32)
for _, rule := range g.rules {
if rule.Expire() {
removed = append(removed, rule)
}
}
for _, rule := range removed {
g.out.Info("ZTrust.Clear: %s", rule.Id())
delete(g.rules, rule.Id())
g.delRuleX(rule.Rule())
}
}
type ZTrust struct { type ZTrust struct {
network string network string
expire int expire int
@@ -128,7 +183,7 @@ func (z *ZTrust) Chain() string {
func (z *ZTrust) Initialize() { func (z *ZTrust) Initialize() {
z.chain = cn.NewFireWallChain(z.Chain(), network.TMangle, "") z.chain = cn.NewFireWallChain(z.Chain(), network.TMangle, "")
z.chain.AddRule(cn.IpRule{ z.chain.AddRule(cn.IpRule{
Comment: "ZTrust Default", Comment: "ZTrust Deny All",
Jump: "DROP", Jump: "DROP",
}) })
} }
@@ -142,23 +197,29 @@ func (z *ZTrust) Knock(name string, protocol, dest, port string, age int) error
protocol: protocol, protocol: protocol,
destination: dest, destination: dest,
port: port, port: port,
createAt: time.Now().Unix(), createAt: time.Now(),
age: age, age: int64(age),
}) })
return nil return nil
} }
func (z *ZTrust) Update() { func (z *ZTrust) Update() {
//TODO expire knock rules. for {
for _, guest := range z.guests {
guest.Clear()
}
time.Sleep(time.Second * 3)
}
} }
func (z *ZTrust) AddRuleX(rule cn.IpRule) { func (z *ZTrust) addRuleX(rule cn.IpRule) {
if err := z.chain.AddRuleX(rule); err != nil { if err := z.chain.AddRuleX(rule); err != nil {
z.out.Warn("ZTrust.AddRuleX: %s", err) z.out.Warn("ZTrust.AddRuleX: %s", err)
} }
} }
func (z *ZTrust) DelRuleX(rule cn.IpRule) { func (z *ZTrust) delRuleX(rule cn.IpRule) {
if err := z.chain.DelRuleX(rule); err != nil { if err := z.chain.DelRuleX(rule); err != nil {
z.out.Warn("ZTrust.DelRuleX: %s", err) z.out.Warn("ZTrust.DelRuleX: %s", err)
} }
@@ -176,13 +237,15 @@ func (z *ZTrust) AddGuest(name, source string) error {
guest.Start() guest.Start()
z.guests[name] = guest z.guests[name] = guest
} }
guest.AddSource(source) if !guest.HasSource(source) {
z.AddRuleX(cn.IpRule{ z.addRuleX(cn.IpRule{
Source: source, Source: source,
Comment: "User " + guest.username + "@" + guest.network, Comment: "User " + guest.username + "@" + guest.network,
Jump: guest.Chain(), Jump: guest.Chain(),
Order: "-I", Order: "-I",
}) })
}
guest.AddSource(source)
return nil return nil
} }
@@ -195,11 +258,13 @@ func (z *ZTrust) DelGuest(name, source string) error {
return libol.NewErr("DelGuest: not found %s", name) return libol.NewErr("DelGuest: not found %s", name)
} }
z.out.Info("ZTrust.DelGuest: %s %s", name, source) z.out.Info("ZTrust.DelGuest: %s %s", name, source)
z.DelRuleX(cn.IpRule{ if guest.HasSource(source) {
z.delRuleX(cn.IpRule{
Source: source, Source: source,
Comment: guest.username + "." + guest.network, Comment: guest.username + "." + guest.network,
Jump: guest.Chain(), Jump: guest.Chain(),
}) })
}
guest.DelSource(source) guest.DelSource(source)
return nil return nil
} }
@@ -207,6 +272,7 @@ func (z *ZTrust) DelGuest(name, source string) error {
func (z *ZTrust) Start() { func (z *ZTrust) Start() {
z.out.Info("ZTrust.Start") z.out.Info("ZTrust.Start")
z.chain.Install() z.chain.Install()
libol.Go(z.Update)
} }
func (z *ZTrust) Stop() { func (z *ZTrust) Stop() {
@@ -236,14 +302,17 @@ func (z *ZTrust) ListKnock(name string, call func(obj schema.KnockRule)) {
return return
} }
now := time.Now()
for _, rule := range guest.rules { for _, rule := range guest.rules {
createAt := rule.createAt
obj := schema.KnockRule{ obj := schema.KnockRule{
Name: name, Name: name,
Network: z.network, Network: z.network,
Protocol: rule.protocol, Protocol: rule.protocol,
Dest: rule.destination, Dest: rule.destination,
Port: rule.port, Port: rule.port,
CreateAt: rule.createAt, CreateAt: createAt.Unix(),
Age: int(rule.age + createAt.Unix() - now.Unix()),
} }
call(obj) call(obj)
} }