Files
MirageServer/controller/organization.go
ivloli 9bff601218 Autogroup self (#24)
何川完成的autogroup:self以及autogroup:owner的处理

* 租户的aclPolicy的默认配置为nil

* 修改:aclPolicy判空需要对Acls字段进行判断,空则浅复制之后,替换为全通再生成rules

* 新增:ACLPolicy-autogroup:self

* 使用map来判断autogroup

* 处理autogroup self和owner

* 修改:减少updateACLRules的无效次数

* 添加一点关于autogroup替换的注释

* 减少updateAclRules的引用处

* 在aclRules的生成函数中加入userID,以便于可以获取到和请求用户相关的信息&&调整autogroup:self的src acl生成

* autogroup:self 配置后,src只包含self解析出来的地址,并不会包含dest的所有地址

* 获取peers:添加peerCacheMap(同步HS修改)以及快速判断autogroup:self

* 添加节点更新推送

* 租户内节点更新,通知其他节点进行更新netmap;获取LastStateChange不必排序,只需取最晚time

* 新用户登录时候查询组织不存在的错误码替换为组织不存在,以便可以新建用户

* autogroup:self bug fix

* merge main

* 修改peerCache的生成和使用方式,不再遍历CIDR内所有ip

* 将UpdateAclRule操作从getPeer中提出到getMapResponse中

* fix bug: updateAclRules之后没有同步更新到对应的machine上

* 抽取出关于autogroup:self的修改

* fix bug:self情况下peer加入要判断uid

* acl expand alias: 调整autogroup到前面

* 租户建立时,默认添加一条全通ACL规则

* 租户初始化默认ACL添加一条全通

---------

Co-authored-by: chuanh <chuanh@opera.com>
Co-authored-by: chuanhe <chuanhe.u90@gmail.com>
Co-authored-by: Chenyang Gao <gps949@outlook.com>
2023-06-14 13:29:04 +08:00

307 lines
8.2 KiB
Go

package controller
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/bwmarrin/snowflake"
"github.com/rs/zerolog/log"
"github.com/sethvargo/go-diceware/diceware"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
// Sentinel errors returned by the organization helpers in this file, together
// with the default expiry value assigned to newly created organizations.
const (
// ErrOrgExists is returned by CreateOrgnaizationInTx when an organization with
// the same name and provider is already stored.
ErrOrgExists = Error("Organization already exists")
// ErrOrgNotFound is returned when no organization matches the lookup.
ErrOrgNotFound = Error("Organization not found")
// ErrCreateOrgParams is returned by CreateOrgnaizationInTx when the name or
// provider is empty.
// NOTE(review): the message text contains the typo "paramters"; it is left
// untouched here because callers may match on the exact string.
ErrCreateOrgParams = Error("Invalid create organization paramters")
// ErrGetOrgParams is returned by GetOrgnaizationByNameInTx when the name or
// provider is empty. The message carries the same "paramters" typo.
ErrGetOrgParams = Error("Invalid get organization paramters")
// ErrDeleteOrgFailed is joined with the underlying database error when the
// organization row itself cannot be deleted.
ErrDeleteOrgFailed = Error("Delete Organization Failed")
// ErrDeleteOrgCancelled is returned by DestroyOrgnaization when users still
// remain in the organization, so the organization row is kept.
ErrDeleteOrgCancelled = Error("Delete Organization Cancelled")
// DefaultExpireTime is the ExpiryDuration given to a new organization.
// NOTE(review): the unit is not visible in this file (presumably days) — confirm.
DefaultExpireTime = 180
)
// Organization is the GORM model for a tenant. It carries the tenant's
// identity, DNS settings and ACL policy.
type Organization struct {
// ID is the primary key; BeforeCreate fills it with a snowflake ID when zero.
ID int64 `gorm:"primary_key;unique;not null"`
// StableID is derived from ID via GetShortId in BeforeCreate.
StableID string `gorm:"unique"`
// Name and Provider together form the unique index idx_name_provider.
Name string `gorm:"uniqueIndex:idx_name_provider"`
Provider string `gorm:"uniqueIndex:idx_name_provider"`
// ExpiryDuration defaults to 180 (same value as DefaultExpireTime).
// NOTE(review): unit not visible here (presumably days) — confirm.
ExpiryDuration uint `gorm:"default:180"`
// EnableMagic mirrors DNSData.MagicDNS (see UpdateOrgDNSConfig).
EnableMagic bool `gorm:"default:false"`
// MagicDnsDomain is the generated "<word>-<word>.<base domain>" name.
MagicDnsDomain string
// OverrideLocal is set to true when global resolvers are configured and to
// false when only fallback resolvers are configured (see UpdateOrgDNSConfig).
OverrideLocal bool `gorm:"default:false"`
// Nameservers holds either the resolvers or the fallback resolvers.
Nameservers StringList
// SplitDns holds the per-domain nameserver routes.
SplitDns SplitDNS
// AclPolicy is the persisted ACL policy; new organizations start with a
// single accept-all rule (see CreateOrgnaizationInTx).
AclPolicy *ACLPolicy
// AclRules and SshPolicy are excluded from persistence (gorm:"-").
AclRules []tailcfg.FilterRule `gorm:"-"`
SshPolicy *tailcfg.SSHPolicy `gorm:"-"`
// NaviBanList is stored as JSON through its Value/Scan methods.
NaviBanList NaviBanList
// NOTE(review): the meaning of the Navi deploy key pair is not visible in
// this file — confirm against its users.
NaviDeployKey string
NaviDeployPub string
// CreatedAt and UpdatedAt are presumably maintained by GORM's timestamp
// convention — confirm.
CreatedAt time.Time
UpdatedAt time.Time
}
// NaviBanList is a set of integer IDs that is persisted as a JSON object in a
// single database column.
// NOTE(review): what the integer keys identify is not visible in this file.
type NaviBanList map[int]struct{}

// Value implements driver.Valuer by encoding the set as a JSON string.
// A nil set is stored as the JSON literal "null".
func (nbl NaviBanList) Value() (driver.Value, error) {
	b, err := json.Marshal(nbl)
	return string(b), err
}

// Scan implements sql.Scanner.
//
// It accepts the JSON document as []byte or string and replaces the receiver
// with the decoded set. A SQL NULL (nil value) resets the receiver to a nil
// set instead of failing, so rows written before the column existed can still
// be loaded. Any other type yields an error; a JSON decoding error is returned
// unchanged and leaves the receiver untouched.
func (nbl *NaviBanList) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*nbl = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("cannot parse navi ban list: unexpected data type %T", value)
	}

	// Decode into a fresh map so that keys from a previous scan never leak
	// into the new value.
	var decoded NaviBanList
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*nbl = decoded
	return nil
}
func (o *Organization) BeforeCreate(tx *gorm.DB) error {
if o.ID == 0 {
flakeID, err := snowflake.NewNode(1)
if err != nil {
return err
}
id := flakeID.Generate().Int64()
o.ID = id
}
o.StableID = GetShortId(o.ID)
return nil
}
// GenNewMagicDNSDomain rolls a MagicDNS domain of the form
// "<word>-<word>.<BaseDomain>" that no stored organization uses yet.
//
// It keeps generating two-word diceware names until the lookup in tx reports
// gorm.ErrRecordNotFound for the candidate. It returns an error when the word
// generation fails or when the uniqueness lookup fails for any other reason;
// previously such a lookup failure made the loop spin forever.
func (m *Mirage) GenNewMagicDNSDomain(tx *gorm.DB) (string, error) {
	for {
		words, err := diceware.Generate(2)
		if err != nil {
			log.Error().Err(err).Msg("Could not generate passphrase")
			return "", err
		}
		candidate := strings.Join(words, "-") + "." + m.cfg.BaseDomain

		err = tx.First(&Organization{}, "magic_dns_domain = ?", candidate).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			// Nobody owns this domain yet.
			return candidate, nil
		case err != nil:
			// A real database failure: retrying would never terminate.
			log.Error().Err(err).Msg("Could not check magic dns domain uniqueness")
			return "", err
		}
		// The candidate is already taken; roll again.
	}
}
// UpdateMagicDNSDomain stores netMagicDomain as the MagicDNS domain of the
// organization identified by orgID.
//
// The organization is loaded, changed and saved; once the save succeeds the
// organization's last-state-change marker is set to now. Lookup and save
// errors are returned unchanged.
func (m *Mirage) UpdateMagicDNSDomain(orgID int64, netMagicDomain string) error {
	org, err := m.GetOrgnaizationByID(orgID)
	if err != nil {
		return err
	}

	org.MagicDnsDomain = netMagicDomain
	if saveErr := m.db.Save(org).Error; saveErr != nil {
		return saveErr
	}

	m.setOrgLastStateChangeToNow(orgID)
	return nil
}
// CreateOrgnaizationInTx inserts a new organization inside the transaction tx.
//
// It rejects an empty name or provider with ErrCreateOrgParams and an already
// stored (name, provider) pair with ErrOrgExists. The new organization gets
// the default expiry duration, a freshly generated MagicDNS domain and an ACL
// policy whose only rule accepts all traffic. Database and generation errors
// are returned unchanged.
func (m *Mirage) CreateOrgnaizationInTx(tx *gorm.DB, name, provider string) (*Organization, error) {
	if name == "" || provider == "" {
		return nil, ErrCreateOrgParams
	}

	var existing int64
	countErr := tx.Model(&Organization{}).
		Where("name = ? and provider = ?", name, provider).
		Count(&existing).Error
	if countErr != nil {
		return nil, countErr
	}
	if existing > 0 {
		return nil, ErrOrgExists
	}

	// cgao6: roll a random MagicDNS domain for the organization.
	magicDomain, genErr := m.GenNewMagicDNSDomain(tx)
	if genErr != nil {
		log.Error().
			Str("func", "CreateOrgnaization").
			Err(genErr).
			Msg("Could not generate magic dns domain")
		return nil, genErr
	}

	org := &Organization{
		Name:           name,
		Provider:       provider,
		ExpiryDuration: DefaultExpireTime,
		MagicDnsDomain: magicDomain,
		AclPolicy: &ACLPolicy{
			Groups:    make(Groups, 0),
			Hosts:     make(Hosts, 0),
			TagOwners: make(TagOwners, 0),
			// A new tenant starts with one rule that lets everything through.
			ACLs: []ACL{
				{
					Action:       "accept",
					Protocol:     "",
					Sources:      []string{"*"},
					Destinations: []string{"*:*"},
				},
			},
			Tests: make([]ACLTest, 0),
			AutoApprovers: AutoApprovers{
				Routes:   make(map[string][]string, 0),
				ExitNode: make([]string, 0),
			},
			SSHs: make([]SSH, 0),
		},
	}

	if createErr := tx.Create(org).Error; createErr != nil {
		log.Error().
			Str("func", "CreateOrgnaization").
			Err(createErr).
			Msg("Could not create row")
		return nil, createErr
	}
	return org, nil
}
// GetOrgnaizationRecordByName loads the stored organization that matches name
// and provider; it only reads the row and does not generate ACL rules.
//
// The returned pointer is never nil: on failure it points at a zero-valued
// Organization and the database error is returned next to it.
// NOTE(review): the filter is a struct condition, for which GORM skips
// zero-valued fields, so an empty name or provider does not restrict the
// query — confirm callers never pass empty strings.
func (m *Mirage) GetOrgnaizationRecordByName(name, provider string) (*Organization, error) {
	filter := &Organization{Name: name, Provider: provider}
	record := new(Organization)
	err := m.db.Model(&Organization{}).Where(filter).Take(record).Error
	return record, err
}
// GetOrgnaizationIDByName returns the ID (the primary key of the database
// table) of the organization that matches name and provider.
//
// Only the ID column is selected, so the single-column result can be scanned
// into the int64 destination; this matches the lookup used by
// DestroyOrgnaization. On failure the database error is returned and the ID
// is zero.
// NOTE(review): the filter is a struct condition, for which GORM skips
// zero-valued fields, so an empty name or provider does not restrict the
// query — confirm callers never pass empty strings.
func (m *Mirage) GetOrgnaizationIDByName(name, provider string) (int64, error) {
	var id int64
	err := m.db.Model(&Organization{}).
		Select("ID").
		Where(&Organization{
			Name:     name,
			Provider: provider,
		}).
		Take(&id).Error
	return id, err
}
// GetOrgnaizationByID loads the stored organization whose primary key is id.
// It only reads the row; it does not regenerate the organization's ACL rules.
//
// The ID is matched with an explicit "id = ?" condition. A struct condition
// would drop a zero ID and silently match an arbitrary organization; with the
// explicit condition id == 0 yields the normal not-found error instead.
// On failure the returned organization is nil and the database error is
// returned unchanged.
func (m *Mirage) GetOrgnaizationByID(id int64) (*Organization, error) {
	org := &Organization{}
	if err := m.db.Where("id = ?", id).Take(org).Error; err != nil {
		return nil, err
	}
	return org, nil
}
// ListOrgnaizations returns every organization stored in the database. The
// rows are returned as loaded; no ACL rules are generated for them.
// On a database error the slice is nil and the error is returned unchanged.
func (m *Mirage) ListOrgnaizations() ([]Organization, error) {
	var all []Organization
	if findErr := m.db.Find(&all).Error; findErr != nil {
		return nil, findErr
	}
	return all, nil
}
// GetOrgnaizationByNameInTx loads the organization matching name and provider
// using the transaction tx.
//
// An empty name or provider is rejected with ErrGetOrgParams. A missing row is
// reported as ErrOrgNotFound; any other database error is returned unchanged.
// Every lookup failure is logged at debug level.
func GetOrgnaizationByNameInTx(tx *gorm.DB, name, provider string) (*Organization, error) {
	if name == "" || provider == "" {
		return nil, ErrGetOrgParams
	}

	found := new(Organization)
	err := tx.Where("name = ? and provider = ?", name, provider).Take(found).Error
	if err == nil {
		return found, nil
	}

	// Translate GORM's not-found error into the package-level sentinel.
	if errors.Is(err, gorm.ErrRecordNotFound) {
		err = ErrOrgNotFound
	}
	log.Debug().
		Str("func", "GetOrgnaizationByName").
		Err(err).
		Msg("Could not get row")
	return nil, err
}
// DestroyOrgnaization deletes an organization together with its users.
//
// Parameters: the organization name, the provider name, and force, which
// selects ForceDestroyUserByID instead of DestroyUserByID for removing users
// (per the original note, forcing also deletes the machines under each user).
//
// Returns: the number of users before deletion, the number of users left
// after deletion, and an error.
//
// Users are deleted one by one; a failed deletion is logged and the loop moves
// on to the next user. The organization row is only removed when no user is
// left, otherwise ErrDeleteOrgCancelled is returned. ErrOrgNotFound is
// returned when no organization matches, and database errors from the lookups
// are returned instead of being ignored. When the recount after deletion
// fails, the remaining count is reported as the initial count.
func (m *Mirage) DestroyOrgnaization(orgName, provider string, force bool) (int, int, error) {
	var orgID int64
	err := m.db.Model(&Organization{}).
		Select("ID").
		Where("name = ? and provider = ?", orgName, provider).
		Find(&orgID).Error
	if err != nil {
		return 0, 0, err
	}
	if orgID == 0 {
		return 0, 0, ErrOrgNotFound
	}

	var deleteFn func(int64) error
	if force {
		deleteFn = m.ForceDestroyUserByID
	} else {
		deleteFn = m.DestroyUserByID
	}

	var userIDs []int64
	err = m.db.Model(&User{}).
		Select("ID").
		Where(&User{OrganizationID: orgID}).
		Find(&userIDs).Error
	if err != nil {
		return 0, 0, err
	}
	// Count only after the list is loaded; counting earlier always gave 0 and
	// made the recount below unreachable.
	before := len(userIDs)

	// Delete users one by one and keep going when a single deletion fails.
	for _, uid := range userIDs {
		if delErr := deleteFn(uid); delErr != nil {
			log.Warn().
				Str("func", "DestroyOrgnaization").
				Int64("user_id", uid).
				Err(delErr).
				Msg("Could not delete user")
		}
	}

	// Check whether any user survived the deletion pass.
	after := 0
	if before > 0 {
		var remaining []int64
		err = m.db.Model(&User{}).
			Select("ID").
			Where(&User{OrganizationID: orgID}).
			Find(&remaining).Error
		if err != nil {
			return before, before, err
		}
		after = len(remaining)
	}

	// Only an organization without users may be removed.
	if after != 0 {
		return before, after, ErrDeleteOrgCancelled
	}
	if err = m.db.Unscoped().Delete(&Organization{}, orgID).Error; err != nil {
		return before, after, errors.Join(ErrDeleteOrgFailed, err)
	}
	return before, after, nil
}
// UpdateOrgExpiry writes newDuration into the expiry_duration column of the
// organization that user belongs to. Only that column is selected for the
// update, so a zero duration is written as well. The database error, if any,
// is returned unchanged.
func (m *Mirage) UpdateOrgExpiry(user *User, newDuration uint) error {
	patch := &Organization{
		ID:             user.OrganizationID,
		ExpiryDuration: newDuration,
	}
	return m.db.Select("expiry_duration").Updates(patch).Error
}
// UpdateOrgDNSConfig applies newDNSCfg to org and persists the DNS-related
// columns (EnableMagic, Nameservers, OverrideLocal, SplitDns).
//
// Nameserver selection: when Resolvers is non-empty they become the
// nameservers and OverrideLocal is set to true; otherwise, when
// FallbackResolvers is non-empty they become the nameservers and OverrideLocal
// is set to false. When both are empty the nameserver list is emptied and
// OverrideLocal keeps its current value.
//
// Split DNS: one entry is built per domain in Domains that has a matching key
// in Routes; domains without a route are skipped.
//
// org is modified in place even when the database update fails; the update
// error is returned unchanged.
func (m *Mirage) UpdateOrgDNSConfig(org *Organization, newDNSCfg DNSData) error {
	org.EnableMagic = newDNSCfg.MagicDNS

	org.Nameservers = make([]string, 0)
	switch {
	case len(newDNSCfg.Resolvers) > 0:
		org.OverrideLocal = true
		org.Nameservers = newDNSCfg.Resolvers
	case len(newDNSCfg.FallbackResolvers) > 0:
		org.OverrideLocal = false
		org.Nameservers = newDNSCfg.FallbackResolvers
	}

	splitDNS := SplitDNS{}
	for _, domain := range newDNSCfg.Domains {
		ns, ok := newDNSCfg.Routes[domain]
		if !ok {
			continue
		}
		splitDNS = append(splitDNS, SplitDNSItem{
			Domain: domain,
			NS:     ns,
		})
	}
	org.SplitDns = splitDNS

	// Each column is listed once; the earlier list repeated "Nameservers" and
	// "OverrideLocal" without any effect.
	return m.db.Select("EnableMagic", "Nameservers", "OverrideLocal", "SplitDns").Updates(org).Error
}