Files
openlan/pkg/switch/ztrust.go
2024-03-27 13:17:41 +08:00

304 lines
5.7 KiB
Go

package cswitch
import (
"fmt"
"sync"
"time"
"github.com/luscis/openlan/pkg/libol"
cn "github.com/luscis/openlan/pkg/network"
"github.com/luscis/openlan/pkg/schema"
)
// KnockRule is one protocol/destination/port opening granted to a guest by a
// knock. It stays in the guest's rule table until it is older than age
// seconds (see Expire).
type KnockRule struct {
	// createAt is when the knock was received; AddRule refreshes it when the
	// same rule is knocked again.
	createAt time.Time
	// age is the rule's lifetime in seconds, counted from createAt.
	age int64
	// protocol, destination and port are the match fields of the rule.
	protocol, destination, port string
	// rule caches the firewall rule built by Rule on first use.
	rule *cn.IPRule
}
// Id returns the key under which this rule is stored in a guest's rule table,
// in the form "protocol:destination:port".
func (r *KnockRule) Id() string {
	// fmt.Sprint inserts no separators between string operands, so this is a
	// plain concatenation of the pieces.
	return fmt.Sprint(r.protocol, ":", r.destination, ":", r.port)
}
// Expire reports whether the rule has outlived its age, comparing whole Unix
// seconds. A rule is still valid during the second its lifetime ends.
func (r *KnockRule) Expire() bool {
	deadline := r.createAt.Unix() + r.age
	return deadline < time.Now().Unix()
}
// Rule returns the firewall rule for this knock. The rule is built once and
// cached, so its Comment keeps the knock time of the first call even if
// createAt is refreshed later by AddRule.
func (r *KnockRule) Rule() cn.IPRule {
	if r.rule != nil {
		return *r.rule
	}
	built := cn.IPRule{
		Proto:   r.protocol,
		Dest:    r.destination,
		DstPort: r.port,
		Comment: "Knock at " + r.createAt.UTC().String(),
	}
	r.rule = &built
	return built
}
// ZGuest is one user of a zero-trust network, identified by source address,
// together with the knock rules currently open for it.
type ZGuest struct {
	// network and username name the guest; source is its address, used as
	// the match for the jump into this guest's chain.
	network, username, source string
	// rules maps KnockRule.Id to the open rule; guarded by lock.
	rules map[string]*KnockRule
	// chain is the guest's own firewall chain, created by Start.
	chain *cn.FireWallChain
	out   *libol.SubLogger
	lock  sync.Mutex
}
// NewZGuest returns a guest named name on network with source address source.
// Its firewall chain is not created until Start is called.
func NewZGuest(network, name, source string) *ZGuest {
	guest := &ZGuest{network: network, username: name, source: source}
	guest.rules = make(map[string]*KnockRule, 1024)
	guest.out = libol.NewSubLogger(name + "@" + network)
	return guest
}
// Chain returns the name of this guest's firewall chain,
// "ZTT_<network>-<username>".
// NOTE(review): iptables limits chain names to 28 characters; long network or
// user names may make Install fail — confirm how cn handles this.
func (g *ZGuest) Chain() string {
	const prefix = "ZTT_"
	return prefix + g.network + "-" + g.username
}
// Start creates this guest's chain in the mangle table and installs it.
func (g *ZGuest) Start() {
	chain := cn.NewFireWallChain(g.Chain(), cn.TMangle, "")
	g.chain = chain
	chain.Install()
}
// addRuleX adds rule to the guest's chain. A failure is logged as a warning
// and otherwise ignored.
func (g *ZGuest) addRuleX(rule cn.IPRule) {
	err := g.chain.AddRuleX(rule)
	if err == nil {
		return
	}
	g.out.Warn("ZTrust.AddRuleX: %s", err)
}
// delRuleX deletes rule from the guest's chain. A failure is logged as a
// warning and otherwise ignored.
func (g *ZGuest) delRuleX(rule cn.IPRule) {
	err := g.chain.DelRuleX(rule)
	if err == nil {
		return
	}
	g.out.Warn("ZTrust.DelRuleX: %s", err)
}
// AddRule opens rule for this guest. If a rule with the same Id is already
// open, it is not re-installed. Instead its age and creation time are copied
// from rule, which restarts its expiry countdown.
func (g *ZGuest) AddRule(rule *KnockRule) {
	g.lock.Lock()
	defer g.lock.Unlock()
	key := rule.Id()
	if current, exists := g.rules[key]; exists {
		current.age = rule.age
		current.createAt = rule.createAt
		return
	}
	g.addRuleX(rule.Rule())
	g.rules[key] = rule
}
// DelRule closes the open rule that has the same Id (protocol/destination/
// port) as rule, if there is one.
//
// The firewall rule removed is the one cached on the stored KnockRule, i.e.
// exactly what AddRule installed. Rebuilding it from the argument would embed
// the argument's own createAt in the Comment. That likely would not match the
// installed rule (assuming DelRuleX matches on the comment), and the firewall
// rule would leak while the table entry was dropped.
func (g *ZGuest) DelRule(rule *KnockRule) {
	g.lock.Lock()
	defer g.lock.Unlock()
	id := rule.Id()
	existing, ok := g.rules[id]
	if !ok {
		return
	}
	g.delRuleX(existing.Rule())
	delete(g.rules, id)
}
// Stop removes every knock rule from the guest's chain and then cancels the
// chain itself, all under the guest lock.
func (g *ZGuest) Stop() {
	g.lock.Lock()
	defer g.lock.Unlock()

	// Individual rules go first; the chain is cancelled afterwards.
	g.flush()
	g.chain.Cancel()
}
// Clear removes every knock rule whose lifetime has run out, from both the
// rule table and the firewall chain. Expiry is evaluated for all rules before
// any of them is removed.
func (g *ZGuest) Clear() {
	g.lock.Lock()
	defer g.lock.Unlock()

	var expired []*KnockRule
	for _, candidate := range g.rules {
		if !candidate.Expire() {
			continue
		}
		expired = append(expired, candidate)
	}

	for _, stale := range expired {
		id := stale.Id()
		g.out.Info("ZTrust.Clear: %s", id)
		delete(g.rules, id)
		g.delRuleX(stale.Rule())
	}
}
// flush deletes every knock rule from the guest's firewall chain and also
// drops it from the rule table.
//
// Dropping the entries matters for two reasons. A later AddRule for the same
// Id must re-install the firewall rule, rather than just refresh a stale
// entry whose rule is already gone. ListKnock must also stop reporting rules
// that are no longer installed. Callers hold g.lock.
func (g *ZGuest) flush() {
	for id, rule := range g.rules {
		g.delRuleX(rule.Rule())
		// Deleting the current key while ranging over a map is allowed by
		// the Go spec.
		delete(g.rules, id)
	}
}
// Flush removes all of the guest's knock rules from its firewall chain. It
// is the locked entry point to flush.
func (g *ZGuest) Flush() {
	g.lock.Lock()
	defer g.lock.Unlock()

	g.flush()
}
// ZTrust is the zero-trust gate for one network. Initialize gives its chain a
// "ZTrust Deny All" DROP rule. Every guest added with AddGuest gets a jump,
// inserted ("-I") in front of that rule, from its source address into a
// per-guest chain holding the rules the guest has knocked open.
type ZTrust struct {
	network string
	expire  int
	guests  map[string]*ZGuest
	chain   *cn.FireWallChain
	out     *libol.SubLogger
	// lock guards guests and done. The expiry loop started by Start walks
	// guests while API calls add and remove entries, and a Go map must not
	// be iterated or read while it is being written.
	lock sync.Mutex
	// done is closed by Stop to end the expiry loop started by Start.
	done chan struct{}
}

// NewZTrust returns a ZTrust for network. Call Initialize before Start.
func NewZTrust(network string, expire int) *ZTrust {
	return &ZTrust{
		network: network,
		expire:  expire,
		out:     libol.NewSubLogger(network),
		guests:  make(map[string]*ZGuest, 32),
	}
}

// Chain returns the name of the network-wide chain, "ZTT_<network>".
func (z *ZTrust) Chain() string {
	return "ZTT_" + z.network
}

// Initialize creates the network chain in the mangle table with a deny-all
// DROP rule. It must run before Start, which installs z.chain.
func (z *ZTrust) Initialize() {
	z.chain = cn.NewFireWallChain(z.Chain(), cn.TMangle, "")
	z.chain.AddRule(cn.IPRule{
		Comment: "ZTrust Deny All",
		Jump:    "DROP",
	})
}

// lookupGuest returns the guest registered under name, or nil if none.
func (z *ZTrust) lookupGuest(name string) *ZGuest {
	z.lock.Lock()
	defer z.lock.Unlock()
	return z.guests[name]
}

// snapshotGuests returns the current guests so callers can work on them
// without holding z.lock; ZGuest methods take their own lock.
func (z *ZTrust) snapshotGuests() []*ZGuest {
	z.lock.Lock()
	defer z.lock.Unlock()
	guests := make([]*ZGuest, 0, len(z.guests))
	for _, guest := range z.guests {
		guests = append(guests, guest)
	}
	return guests
}

// Knock opens protocol/dest/port for the guest registered as name, for age
// seconds. Knocking a rule that is already open refreshes its lifetime. It
// returns an error if name is not a known guest.
func (z *ZTrust) Knock(name string, protocol, dest, port string, age int) error {
	guest := z.lookupGuest(name)
	if guest == nil {
		return libol.NewErr("Knock: not found %s", name)
	}
	rule := &KnockRule{
		protocol:    protocol,
		destination: dest,
		port:        port,
		createAt:    time.Now(),
		age:         int64(age),
	}
	z.out.Info("Knock: %s %s", name, rule.Id())
	guest.AddRule(rule)
	return nil
}

// Update removes expired knock rules from every guest every 3 seconds and
// never returns. Start runs the same loop in a form that Stop can end.
func (z *ZTrust) Update() {
	z.expireLoop(nil)
}

// expireLoop runs one expiry pass, then repeats every 3 seconds until done
// is closed. A nil done never fires, so the loop then runs forever.
func (z *ZTrust) expireLoop(done <-chan struct{}) {
	ticker := time.NewTicker(3 * time.Second)
	defer ticker.Stop()
	for {
		for _, guest := range z.snapshotGuests() {
			guest.Clear()
		}
		select {
		case <-done:
			return
		case <-ticker.C:
		}
	}
}

// addRuleX adds rule to the network chain, logging (not returning) failures.
func (z *ZTrust) addRuleX(rule cn.IPRule) {
	if err := z.chain.AddRuleX(rule); err != nil {
		z.out.Warn("ZTrust.AddRuleX: %s", err)
	}
}

// delRuleX deletes rule from the network chain, logging (not returning)
// failures.
func (z *ZTrust) delRuleX(rule cn.IPRule) {
	if err := z.chain.DelRuleX(rule); err != nil {
		z.out.Warn("ZTrust.DelRuleX: %s", err)
	}
}

// AddGuest registers name with source address source. It creates the guest's
// chain and inserts a jump to it at the head of the network chain. Adding a
// name that is already registered is a no-op, even if source differs. An
// empty source is rejected.
func (z *ZTrust) AddGuest(name, source string) error {
	z.out.Info("ZTrust.AddGuest: %s %s", name, source)
	if source == "" {
		return libol.NewErr("AddGuest: invalid source")
	}
	z.lock.Lock()
	defer z.lock.Unlock()
	if _, ok := z.guests[name]; ok {
		return nil
	}
	guest := NewZGuest(z.network, name, source)
	guest.Start()
	z.addRuleX(cn.IPRule{
		Source:  guest.source,
		Comment: "User " + guest.username + "@" + guest.network,
		Jump:    guest.Chain(),
		Order:   "-I",
	})
	z.guests[name] = guest
	return nil
}

// DelGuest removes the jump to the guest's chain, stops the guest, and
// forgets it. Unknown names are ignored; source is used only for logging.
func (z *ZTrust) DelGuest(name, source string) error {
	z.lock.Lock()
	defer z.lock.Unlock()
	guest, ok := z.guests[name]
	if !ok {
		return nil
	}
	z.out.Info("ZTrust.DelGuest: %s %s", name, source)
	z.delRuleX(cn.IPRule{
		Source:  guest.source,
		Comment: "User " + guest.username + "@" + guest.network,
		Jump:    guest.Chain(),
	})
	guest.Stop()
	delete(z.guests, name)
	return nil
}

// Start installs the network chain and launches the background expiry loop.
// Initialize must have been called first.
func (z *ZTrust) Start() {
	z.out.Info("ZTrust.Start")
	z.chain.Install()
	done := make(chan struct{})
	z.lock.Lock()
	if z.done != nil {
		// Start without an intervening Stop: retire the previous loop.
		close(z.done)
	}
	z.done = done
	z.lock.Unlock()
	libol.Go(func() { z.expireLoop(done) })
}

// Stop ends the expiry loop, cancels the network chain and stops every
// guest. Guests remain registered.
func (z *ZTrust) Stop() {
	z.out.Info("ZTrust.Stop")
	z.lock.Lock()
	defer z.lock.Unlock()
	if z.done != nil {
		close(z.done)
		z.done = nil
	}
	z.chain.Cancel()
	for _, guest := range z.guests {
		guest.Stop()
	}
}

// ListGuest calls call once per registered guest, in no particular order.
// call runs without any ZTrust lock held, so it may call back into z.
func (z *ZTrust) ListGuest(call func(obj schema.ZGuest)) {
	for _, guest := range z.snapshotGuests() {
		call(schema.ZGuest{
			Name:    guest.username,
			Network: guest.network,
			Address: guest.source,
		})
	}
}

// ListKnock calls call once per open knock rule of guest name, in no
// particular order; unknown names produce no calls.
//
// Age in the result is the remaining lifetime in seconds. It can be zero or
// negative for a rule that has expired but has not yet been removed by the
// expiry loop. The rules are copied under the guest lock, and call runs
// after that lock is released.
func (z *ZTrust) ListKnock(name string, call func(obj schema.KnockRule)) {
	guest := z.lookupGuest(name)
	if guest == nil {
		return
	}
	now := time.Now()
	guest.lock.Lock()
	objs := make([]schema.KnockRule, 0, len(guest.rules))
	for _, rule := range guest.rules {
		createAt := rule.createAt
		objs = append(objs, schema.KnockRule{
			Name:     name,
			Network:  z.network,
			Protocol: rule.protocol,
			Dest:     rule.destination,
			Port:     rule.port,
			CreateAt: createAt.Unix(),
			Age:      int(rule.age + createAt.Unix() - now.Unix()),
		})
	}
	guest.lock.Unlock()
	for _, obj := range objs {
		call(obj)
	}
}