mirror of
https://github.com/MirageNetwork/MirageServer.git
synced 2025-09-26 20:41:34 +08:00

何川完成的autogroup:self以及autogroup:owner的处理 * 租户的aclPolicy的默认配置为nil * 修改:aclPolicy判空需要对Acls字段进行判断,空则浅复制之后,替换为全通再生成rules * 新增:ACLPolicy-autogroup:self * 使用map来判断autogroup * 处理autogroup self和owner * 修改:减少updateACLRules的无效次数 * 添加一点关于autogroup替换的注释 * 减少updateAclRules的引用处 * 在aclRules的生成函数中加入userID,以便于可以获取到和请求用户相关的信息&&调整autogroup:self的src acl生成 * autogroup:self 配置后,src只包含self解析出来的地址,并不会包含dest的所有地址 * 获取peers:添加peerCacheMap(同步HS修改)以及快速判断autogroup:self * 添加节点更新推送 * 租户内节点更新,通知其他节点进行更新netmap;获取LastStateChange不必排序,只需取最晚time * 新用户登录时候查询组织不存在的错误码替换为组织不存在,以便可以新建用户 * autogroup:self bug fix * merge main * 修改peerCache的生成和使用方式,不再遍历CIDR内所有ip * 将UpdateAclRule操作从getPeer中提出到getMapResponse中 * fix bug: updateAclRules之后没有同步更新到对应的machine上 * 抽取出关于autogroup:self的修改 * fix bug:self情况下peer加入要判断uid * acl expand alias: 调整autogroup到前面 * 租户建立时,默认添加一条全通ACL规则 * 租户初始化默认ACL添加一条全通 --------- Co-authored-by: chuanh <chuanh@opera.com> Co-authored-by: chuanhe <chuanhe.u90@gmail.com> Co-authored-by: Chenyang Gao <gps949@outlook.com>
792 lines
23 KiB
Go
792 lines
23 KiB
Go
package controller
|
||
|
||
import (
|
||
"context"
|
||
"embed"
|
||
"fmt"
|
||
"io"
|
||
"io/fs"
|
||
"net"
|
||
"net/http"
|
||
"sort"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/coreos/go-oidc/v3/oidc"
|
||
"github.com/dexidp/dex/server"
|
||
"github.com/gorilla/mux"
|
||
"github.com/patrickmn/go-cache"
|
||
"github.com/puzpuzpuz/xsync/v2"
|
||
"github.com/rs/zerolog/log"
|
||
"golang.org/x/oauth2"
|
||
"golang.org/x/sync/errgroup"
|
||
"gorm.io/gorm"
|
||
"tailscale.com/control/controlclient"
|
||
"tailscale.com/tailcfg"
|
||
"tailscale.com/types/key"
|
||
)
|
||
|
||
// Storage and authentication constants.
const (
	// DatabasePath names the SQLite database file.
	DatabasePath = "db.sqlite"
	// AuthPrefix is the "Bearer " token prefix string.
	AuthPrefix = "Bearer "
)

// Timing constants for background workers and the HTTP server.
const (
	// EphemeralNodeInactivityTimeout is how long an ephemeral machine may go
	// unseen before expireEphemeralNodesWorker deletes it. Must not be lower
	// than 65s.
	EphemeralNodeInactivityTimeout = 5 * time.Minute
	// NodeUpdateCheckInterval must not exceed 60s.
	NodeUpdateCheckInterval = 10 * time.Second
	// updateInterval is the background worker period in milliseconds. It is
	// deliberately untyped so it can scale time.Millisecond directly.
	updateInterval = 5000
	// HTTPReadTimeout bounds how long the HTTP server waits to read a request.
	HTTPReadTimeout = 30 * time.Second
	// HTTPShutdownTimeout bounds a graceful HTTP server shutdown.
	HTTPShutdownTimeout = 3 * time.Second
)

// privateKeyFileMode is the permission for an on-disk private key file; only
// the disabled file-based key loader refers to it.
const privateKeyFileMode = 0o600

// SMS verification code cache timing.
const (
	smsCacheExpiration = 5 * time.Minute
	smsCacheCleanup    = 5 * time.Minute
)
|
||
|
||
// Mirage represents the base app of the service.
|
||
type Mirage struct {
|
||
cfg *Config
|
||
db *gorm.DB
|
||
ctx context.Context
|
||
cancel context.CancelFunc
|
||
|
||
noisePrivateKey *key.MachinePrivate
|
||
// DERPMap *tailcfg.DERPMap
|
||
DERPNCs map[string]*controlclient.NoiseClient
|
||
DERPseqnum map[string]int
|
||
|
||
aclPolicy *ACLPolicy
|
||
aclRules []tailcfg.FilterRule
|
||
sshPolicy *tailcfg.SSHPolicy
|
||
|
||
lastStateChange *xsync.MapOf[string, time.Time]
|
||
|
||
oidcProvider *oidc.Provider
|
||
oauth2Config *oauth2.Config
|
||
|
||
smsCodeCache *cache.Cache
|
||
|
||
aCodeCache *cache.Cache
|
||
stateCodeCache *cache.Cache
|
||
controlCodeCache *cache.Cache
|
||
machineControlCodeCache *cache.Cache
|
||
//organizationCache *cache.Cache
|
||
|
||
tcdCache *cache.Cache
|
||
|
||
longPollChanPool map[string]chan string
|
||
|
||
ipAllocationMutex sync.Mutex
|
||
|
||
shutdownChan chan struct{}
|
||
pollNetMapStreamWG sync.WaitGroup
|
||
}
|
||
|
||
func NewMirage(cfg *Config, db *gorm.DB) (*Mirage, error) {
|
||
// noisePrivateKey, err := readOrCreatePrivateKey(AbsolutePathFromConfigPath(NoiseKeyPath))
|
||
noisePrivateKey, err := getServerPrivateKey(db)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
||
}
|
||
|
||
//cgao6: 注册机制探索
|
||
aCodeCache := cache.New(0, 0)
|
||
stateCodeCache := cache.New(0, 0)
|
||
controlCodeCache := cache.New(0, 0)
|
||
machineControlCodeCache := cache.New(0, 0)
|
||
|
||
smsCodeCache := cache.New(
|
||
smsCacheExpiration,
|
||
smsCacheCleanup,
|
||
)
|
||
longPollChanPool := make(map[string]chan string, 0)
|
||
|
||
InitESLogger(cfg)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
|
||
app := Mirage{
|
||
cfg: cfg,
|
||
db: db,
|
||
ctx: ctx,
|
||
cancel: cancel,
|
||
|
||
noisePrivateKey: noisePrivateKey,
|
||
DERPNCs: make(map[string]*controlclient.NoiseClient),
|
||
DERPseqnum: make(map[string]int),
|
||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||
|
||
aCodeCache: aCodeCache,
|
||
stateCodeCache: stateCodeCache,
|
||
controlCodeCache: controlCodeCache,
|
||
machineControlCodeCache: machineControlCodeCache,
|
||
tcdCache: cache.New(0, 0),
|
||
longPollChanPool: longPollChanPool,
|
||
smsCodeCache: smsCodeCache,
|
||
shutdownChan: make(chan struct{}),
|
||
pollNetMapStreamWG: sync.WaitGroup{},
|
||
lastStateChange: xsync.NewMapOf[time.Time](),
|
||
}
|
||
|
||
nrs := app.ListNaviRegions()
|
||
for _, nr := range nrs {
|
||
nns := app.ListNaviNodes(nr.ID)
|
||
for _, nn := range nns {
|
||
nckey := key.MachinePublic{}
|
||
nckey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(nn.NaviKey)))
|
||
nc, err := app.GetNaviNoiseClient(nckey, nn.HostName, nn.DERPPort)
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("GetNaviNoiseClient Error: " + err.Error())
|
||
}
|
||
app.DERPNCs[nn.ID] = nc
|
||
app.DERPseqnum[nn.ID] = 0
|
||
}
|
||
}
|
||
|
||
/* 由于可能我们会使用内建的dex,所以这里可能并不能正确初始化OIDC
|
||
if cfg.OIDC.Issuer != "" {
|
||
err = app.initOIDC()
|
||
if err != nil {
|
||
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
|
||
}
|
||
}
|
||
*/
|
||
|
||
return &app, nil
|
||
}
|
||
|
||
// expireEphemeralNodes deletes ephemeral machine records that have not been
|
||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
|
||
func (h *Mirage) expireEphemeralNodes(ticker *time.Ticker) { //milliSeconds int64) {
|
||
//ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||
for range ticker.C {
|
||
h.expireEphemeralNodesWorker()
|
||
}
|
||
}
|
||
|
||
// expireExpiredMachines expires machines that have an explicit expiry set
|
||
// after that expiry time has passed.
|
||
func (h *Mirage) expireExpiredMachines(ticker *time.Ticker) { //milliSeconds int64) {
|
||
//ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||
for range ticker.C {
|
||
h.expireExpiredMachinesWorker()
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) failoverSubnetRoutes(ticker *time.Ticker) { //milliSeconds int64) {
|
||
//ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||
for range ticker.C {
|
||
err := h.handlePrimarySubnetFailover()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
||
}
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) expireEphemeralNodesWorker() {
|
||
users, err := h.ListUsers()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Error listing users")
|
||
|
||
return
|
||
}
|
||
|
||
for _, user := range users {
|
||
machines, err := h.ListMachinesByUser(user.ID)
|
||
if err != nil {
|
||
log.Error().
|
||
Err(err).
|
||
Str("user", user.Name).
|
||
Msg("Error listing machines in user")
|
||
|
||
return
|
||
}
|
||
|
||
expiredFound := false
|
||
for _, machine := range machines {
|
||
if machine.isEphemeral() && machine.LastSeen != nil &&
|
||
time.Now().
|
||
After(machine.LastSeen.Add(EphemeralNodeInactivityTimeout)) {
|
||
expiredFound = true
|
||
log.Info().
|
||
Str("machine", machine.Hostname).
|
||
Msg("Ephemeral client removed from database")
|
||
|
||
err = h.db.Unscoped().Delete(machine).Error
|
||
if err != nil {
|
||
log.Error().
|
||
Err(err).
|
||
Str("machine", machine.Hostname).
|
||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||
}
|
||
|
||
// TODO: 并不好的处理
|
||
h.NotifyNaviOrgNodesChange(machine.User.OrganizationID, "", machine.NodeKey)
|
||
}
|
||
}
|
||
|
||
if expiredFound {
|
||
h.setOrgLastStateChangeToNow(user.OrganizationID)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) expireExpiredMachinesWorker() {
|
||
orgs, err := h.ListOrgnaizations()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Error listing organizations")
|
||
|
||
return
|
||
}
|
||
orgChangeSet := NewUtilsSet[int64]()
|
||
for _, org := range orgs {
|
||
users, err := h.ListOrgUsers(org.ID)
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Error listing users")
|
||
|
||
return
|
||
}
|
||
for _, user := range users {
|
||
machines, err := h.ListMachinesByUser(user.ID)
|
||
if err != nil {
|
||
log.Error().
|
||
Err(err).
|
||
Str("user", user.Name).
|
||
Msg("Error listing machines in user")
|
||
|
||
return
|
||
}
|
||
|
||
for index, machine := range machines {
|
||
if machine.isExpired() &&
|
||
machine.Expiry.After(h.getOrgLastStateChange(user.OrganizationID)) {
|
||
orgChangeSet.SetKey(org.ID)
|
||
|
||
err := h.ExpireMachine(&machines[index])
|
||
if err != nil {
|
||
log.Error().
|
||
Err(err).
|
||
Str("machine", machine.Hostname).
|
||
Str("name", machine.GivenName).
|
||
Msg("🤮 Cannot expire machine")
|
||
} else {
|
||
log.Info().
|
||
Str("machine", machine.Hostname).
|
||
Str("name", machine.GivenName).
|
||
Msg("Machine successfully expired")
|
||
}
|
||
|
||
h.NotifyNaviOrgNodesChange(machine.User.OrganizationID, "", machine.NodeKey)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
changedOrgList := orgChangeSet.GetKeys()
|
||
if len(changedOrgList) > 0 {
|
||
h.setOrgLastStateChangeToNow(changedOrgList...)
|
||
}
|
||
}
|
||
|
||
//go:embed console_html/admin
|
||
var adminFS embed.FS
|
||
|
||
//go:embed console_html
|
||
var mainpageFS embed.FS
|
||
|
||
//go:embed console_html/login
|
||
var loginFS embed.FS
|
||
|
||
////go:embed console_html/downloads
|
||
//var downloadsFS embed.FS
|
||
|
||
func (h *Mirage) initRouter(router *mux.Router) {
|
||
|
||
adminDir, err := fs.Sub(adminFS, "console_html/admin")
|
||
if err != nil {
|
||
log.Fatal().Msg(err.Error())
|
||
}
|
||
loginDir, err := fs.Sub(loginFS, "console_html/login")
|
||
if err != nil {
|
||
log.Fatal().Msg(err.Error())
|
||
}
|
||
mainpageDir, err := fs.Sub(mainpageFS, "console_html")
|
||
if err != nil {
|
||
log.Fatal().Msg(err.Error())
|
||
}
|
||
|
||
// downloadsDir, err := fs.Sub(downloadsFS, "console_html/downloads")
|
||
// if err != nil {
|
||
// log.Fatal().Msg(err.Error())
|
||
// }
|
||
// router.PathPrefix("/downloads").Handler(http.StripPrefix("/downloads", http.FileServer(http.FS(downloadsDir))))
|
||
router.PathPrefix("/downloads").HandlerFunc(h.sendDownloadsPage).Methods(http.MethodGet)
|
||
|
||
router.PathPrefix("/download").Handler(http.StripPrefix("/download", http.FileServer(http.Dir("download"))))
|
||
|
||
router.HandleFunc("/logout", h.ConsoleLogout).Methods(http.MethodGet)
|
||
//注册
|
||
router.PathPrefix("/api/register").HandlerFunc(h.RegisterUserAPI).Methods(http.MethodPost)
|
||
router.PathPrefix("/api/idps").HandlerFunc(h.ListIdps).Methods(http.MethodGet)
|
||
|
||
//登录
|
||
router.PathPrefix("/login").HandlerFunc(h.doLogin).Methods(http.MethodPost)
|
||
router.PathPrefix("/wxmini").HandlerFunc(h.checkWXMini).Methods(http.MethodPost)
|
||
login_router := router.PathPrefix("/login").Subrouter()
|
||
login_router.Use(h.loginMidware)
|
||
login_router.PathPrefix("").Handler(http.StripPrefix("/login", http.FileServer(http.FS(loginDir))))
|
||
|
||
//cgao6: APage也算是控制台中的一环,逻辑类似
|
||
//对于特殊路径"/a/oauth_response"是login到第三方后验证通过的回写token逻辑
|
||
router.HandleFunc("/a/oauth_response", h.oauthResponse).Methods(http.MethodGet)
|
||
router.HandleFunc("/a/{aCode}", h.deviceRegPortal).Methods(http.MethodGet)
|
||
router.HandleFunc("/a/oauth_response", h.selectOrgForLogin).Methods(http.MethodPost)
|
||
router.HandleFunc("/a/{aCode}", h.deviceReg).Methods(http.MethodPost)
|
||
|
||
// 控制台所需的全部API接口(由APIAuth身份验证放行)
|
||
api_router := router.PathPrefix("/admin/api").Subrouter()
|
||
api_router.Use(h.APIAuth)
|
||
|
||
// 控制台页面全部路由(由ConsoleAuth身份验证放行),与前一个和最后面的juanfont的API接口之后要进行统一
|
||
console_router := router.PathPrefix("/admin").Subrouter()
|
||
console_router.Use(h.ConsoleAuth)
|
||
|
||
// GET(查询类)API
|
||
console_router.HandleFunc("/api/self", h.ConsoleSelfAPI).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/users", h.CAPIGetUsers).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/machines", h.ConsoleMachinesAPI).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/machine-debug", h.ConsoleMachineDebugAPI).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/dns", h.CAPIGetDNS).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/tcd/offers", h.CAPIGetTCDOffers).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/netsettings", h.getNetSettingAPI).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/keys", h.CAPIGetKeys).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/acls/tags", h.CAPIGetTags).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/subscription", h.CAPIGetSubscription).Methods(http.MethodGet)
|
||
console_router.HandleFunc("/api/derp/query", h.CAPIQueryDERP).Methods(http.MethodGet)
|
||
|
||
// POST(更新类)API
|
||
console_router.HandleFunc("/api/users", h.CAPIPostUsers).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/machines", h.ConsoleMachinesUpdateAPI).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/machine/remove", h.ConsoleRemoveMachineAPI).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/netsetting/updatekeyexpiry", h.ConsoleUpdateKeyExpiryAPI).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/keys", h.CAPIPostKeys).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/acls/tags", h.CAPIPostTags).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/dns", h.CAPIPostDNS).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/tcd", h.CAPIPostTCD).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/derp/add", h.CAPIAddDERP).Methods(http.MethodPost)
|
||
console_router.HandleFunc("/api/derp/ban/{id}", h.CAPISwitchRegionBan).Methods(http.MethodPost)
|
||
|
||
// DELETE(删除类)API
|
||
console_router.PathPrefix("/api/keys/").HandlerFunc(h.CAPIDelKeys).Methods(http.MethodDelete)
|
||
console_router.PathPrefix("/api/acls/tags/").HandlerFunc(h.CAPIDelTags).Methods(http.MethodDelete)
|
||
console_router.PathPrefix("/api/derp/{id}").HandlerFunc(h.CAPIDelNaviNode).Methods(http.MethodDelete)
|
||
|
||
// TODO: 登出及页面转至VUE,要考虑logout是否有必要发消息给服务端
|
||
//cgao6: 改成不需检查登录信息 console_router.HandleFunc("/logout", h.ConsoleLogout).Methods(http.MethodGet)
|
||
console_router.PathPrefix("").Handler(http.StripPrefix("/admin", http.FileServer(http.FS(adminDir))))
|
||
|
||
// 核心与客户端通信协议,不动
|
||
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost)
|
||
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
||
|
||
// dex错误前端处理页面
|
||
router.HandleFunc("/dexerr", h.DexErrHandler).Methods(http.MethodGet)
|
||
|
||
// 资源目录们
|
||
router.PathPrefix("/img/").Handler(http.StripPrefix("/", http.FileServer(http.FS(mainpageDir))))
|
||
router.PathPrefix("/assets/").Handler(http.StripPrefix("/", http.FileServer(http.FS(mainpageDir))))
|
||
// 其余全部默认返回主页
|
||
router.Path("/").Handler(http.FileServer(http.FS(mainpageDir)))
|
||
|
||
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
h.ErrMessage(w, r, 404, "你迷失在蜃境中了吗?这里什么都没有")
|
||
})
|
||
}
|
||
|
||
// Serve launches a GIN server with the Mirage API.
|
||
func (h *Mirage) Serve(ctrlChn chan CtrlMsg) error {
|
||
var err error
|
||
|
||
// Fetch an initial DERP Map before we start serving
|
||
// h.DERPMap, err = h.LoadDERPMapFromURL(h.cfg.DERPURL)
|
||
// if err != nil {
|
||
// return err
|
||
// }
|
||
|
||
ticker := time.NewTicker(time.Millisecond * updateInterval)
|
||
defer ticker.Stop()
|
||
longTicker := time.NewTicker(time.Millisecond * updateInterval * 6)
|
||
defer longTicker.Stop()
|
||
|
||
go h.expireEphemeralNodes(ticker) //updateInterval)
|
||
go h.expireExpiredMachines(ticker) //updateInterval)
|
||
go h.failoverSubnetRoutes(ticker) //updateInterval)
|
||
go h.refreshNaviStatusPoller(longTicker)
|
||
|
||
// Prepare group for running listeners
|
||
errorGroup := new(errgroup.Group)
|
||
|
||
//
|
||
//
|
||
// HTTP setup
|
||
//
|
||
// This is the regular router that we expose
|
||
// over our main Addr. It also serves the legacy Tailcale API
|
||
router := mux.NewRouter()
|
||
|
||
_, err = server.InitDexServer(h.ctx, *h.cfg.DexConfig, router) //cgao6: 这里是dex的初始化
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer h.cfg.DexConfig.Storage.Close()
|
||
|
||
h.initRouter(router)
|
||
|
||
httpServer := &http.Server{
|
||
Addr: h.cfg.Addr,
|
||
Handler: router,
|
||
ReadTimeout: HTTPReadTimeout,
|
||
// Go does not handle timeouts in HTTP very well, and there is
|
||
// no good way to handle streaming timeouts, therefore we need to
|
||
// keep this at unlimited and be careful to clean up connections
|
||
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
|
||
WriteTimeout: 0,
|
||
}
|
||
|
||
var httpListener net.Listener
|
||
httpListener, err = net.Listen("tcp", h.cfg.Addr)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||
}
|
||
|
||
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||
|
||
log.Info().
|
||
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||
}
|
||
|
||
ctrlFunc := func(c chan CtrlMsg) {
|
||
for {
|
||
msg := <-c
|
||
switch msg.Msg {
|
||
case "stop":
|
||
log.Info().Msg("Received stop message, shutting down")
|
||
close(h.shutdownChan)
|
||
h.pollNetMapStreamWG.Wait()
|
||
// Gracefully shut down servers
|
||
ctx, cancel := context.WithTimeout(
|
||
context.Background(),
|
||
HTTPShutdownTimeout,
|
||
)
|
||
// Shutdown http server
|
||
if err := httpServer.Shutdown(ctx); err != nil {
|
||
log.Error().Err(err).Msg("Failed to shutdown http")
|
||
}
|
||
// Close network listeners
|
||
err = httpListener.Close()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to close http listener")
|
||
}
|
||
|
||
h.cancel() // ??
|
||
/*
|
||
// Close db connections
|
||
db, err := h.db.DB()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to get db handle")
|
||
}
|
||
err = db.Close()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to close db")
|
||
}
|
||
*/
|
||
log.Info().Msg("Mirage stopped")
|
||
cancel()
|
||
return
|
||
case "update-config":
|
||
log.Info().Msg("Received update-config message, updating config")
|
||
h.cfg = msg.SysCfg
|
||
case "set-last-update":
|
||
log.Info().Msg("Received set-last-update message, updating last update time")
|
||
h.setLastStateChangeToNow()
|
||
}
|
||
}
|
||
}
|
||
errorGroup.Go(func() error {
|
||
ctrlFunc(ctrlChn)
|
||
return nil
|
||
})
|
||
|
||
/*
|
||
// Handle common process-killing signals so we can gracefully shut down:
|
||
h.shutdownChan = make(chan struct{})
|
||
sigc := make(chan os.Signal, 1)
|
||
signal.Notify(sigc,
|
||
syscall.SIGHUP,
|
||
syscall.SIGINT,
|
||
syscall.SIGTERM,
|
||
syscall.SIGQUIT,
|
||
syscall.SIGHUP)
|
||
sigFunc := func(c chan os.Signal) {
|
||
// Wait for a SIGINT or SIGKILL:
|
||
for {
|
||
sig := <-c
|
||
switch sig {
|
||
case syscall.SIGHUP:
|
||
log.Info().
|
||
Str("signal", sig.String()).
|
||
Msg("Received SIGHUP, reloading ACL and Config")
|
||
|
||
// TODO(kradalby): Reload config on SIGHUP
|
||
|
||
aclPath := AbsolutePathFromConfigPath(ACLPath)
|
||
err := h.LoadACLPolicy(aclPath)
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
||
}
|
||
log.Info().
|
||
Str("path", aclPath).
|
||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||
|
||
h.setLastStateChangeToNow()
|
||
default:
|
||
log.Info().
|
||
Str("signal", sig.String()).
|
||
Msg("Received signal to stop, shutting down gracefully")
|
||
|
||
close(h.shutdownChan)
|
||
h.pollNetMapStreamWG.Wait()
|
||
|
||
// Gracefully shut down servers
|
||
ctx, cancel := context.WithTimeout(
|
||
context.Background(),
|
||
HTTPShutdownTimeout,
|
||
)
|
||
if err := httpServer.Shutdown(ctx); err != nil {
|
||
log.Error().Err(err).Msg("Failed to shutdown http")
|
||
}
|
||
|
||
// Close network listeners
|
||
httpListener.Close()
|
||
|
||
// Close db connections
|
||
db, err := h.db.DB()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to get db handle")
|
||
}
|
||
err = db.Close()
|
||
if err != nil {
|
||
log.Error().Err(err).Msg("Failed to close db")
|
||
}
|
||
|
||
log.Info().
|
||
Msg("Mirage stopped")
|
||
|
||
// And we're done:
|
||
cancel()
|
||
os.Exit(0)
|
||
}
|
||
}
|
||
}
|
||
errorGroup.Go(func() error {
|
||
sigFunc(sigc)
|
||
|
||
return nil
|
||
})
|
||
*/
|
||
|
||
return errorGroup.Wait()
|
||
}
|
||
|
||
func (h *Mirage) setLastStateChangeToNow() {
|
||
var err error
|
||
|
||
now := time.Now().UTC()
|
||
|
||
users, err := h.ListUsers()
|
||
if err != nil {
|
||
log.Error().
|
||
Caller().
|
||
Err(err).
|
||
Msg("failed to fetch all users, failing to update last changed state.")
|
||
}
|
||
|
||
for _, user := range users {
|
||
if h.lastStateChange == nil {
|
||
h.lastStateChange = xsync.NewMapOf[time.Time]()
|
||
}
|
||
h.lastStateChange.Store(user.StableID, now)
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) setOrgLastStateChangeToNow(orgId ...int64) {
|
||
var err error
|
||
var users []User
|
||
now := time.Now().UTC()
|
||
|
||
if len(orgId) == 1 {
|
||
users, err = h.ListOrgUsers(orgId[0])
|
||
} else {
|
||
users, err = h.ListUsersInOrgs(orgId)
|
||
}
|
||
if err != nil {
|
||
log.Error().
|
||
Caller().
|
||
Err(err).
|
||
Msg("failed to fetch organization users, failing to update last changed state.")
|
||
}
|
||
|
||
for _, user := range users {
|
||
if h.lastStateChange == nil {
|
||
h.lastStateChange = xsync.NewMapOf[time.Time]()
|
||
}
|
||
h.lastStateChange.Store(user.StableID, now)
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) getLastStateChange(users ...User) time.Time {
|
||
times := []time.Time{}
|
||
|
||
// getLastStateChange takes a list of users as a "filter", if no users
|
||
// are past, then use the entier list of users and look for the last update
|
||
if len(users) > 0 {
|
||
for _, user := range users {
|
||
if lastChange, ok := h.lastStateChange.Load(user.StableID); ok {
|
||
times = append(times, lastChange)
|
||
}
|
||
}
|
||
} else {
|
||
h.lastStateChange.Range(func(key string, value time.Time) bool {
|
||
times = append(times, value)
|
||
|
||
return true
|
||
})
|
||
}
|
||
|
||
sort.Slice(times, func(i, j int) bool {
|
||
return times[i].After(times[j])
|
||
})
|
||
|
||
if len(times) == 0 {
|
||
return time.Now().UTC()
|
||
} else {
|
||
return times[0]
|
||
}
|
||
}
|
||
|
||
func (h *Mirage) getOrgLastStateChange(orgId int64) time.Time {
|
||
times := []time.Time{}
|
||
|
||
// getLastStateChange takes a list of users as a "filter", if no users
|
||
// are past, then use the entier list of users and look for the last update
|
||
users, err := h.ListOrgUsers(orgId)
|
||
if err != nil {
|
||
log.Error().
|
||
Caller().
|
||
Err(err).
|
||
Msg("failed to fetch organization users, failing to get last changed state.")
|
||
}
|
||
for _, user := range users {
|
||
if lastChange, ok := h.lastStateChange.Load(user.StableID); ok {
|
||
if len(times) != 0 {
|
||
if times[0].Before(lastChange) {
|
||
times[0] = lastChange
|
||
}
|
||
} else {
|
||
times = append(times, lastChange)
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(times) == 0 {
|
||
return time.Now().UTC()
|
||
}
|
||
/*
|
||
sort.Slice(times, func(i, j int) bool {
|
||
return times[i].After(times[j])
|
||
})
|
||
*/
|
||
|
||
return times[0]
|
||
}
|
||
|
||
func stdoutHandler(
|
||
writer http.ResponseWriter,
|
||
req *http.Request,
|
||
) {
|
||
body, _ := io.ReadAll(req.Body)
|
||
|
||
log.Trace().
|
||
Interface("header", req.Header).
|
||
Interface("proto", req.Proto).
|
||
Interface("url", req.URL).
|
||
Bytes("body", body).
|
||
Msg("Request did not match")
|
||
}
|
||
|
||
func getServerPrivateKey(db *gorm.DB) (*key.MachinePrivate, error) {
|
||
var sysCfg SysConfig
|
||
err := db.First(&sysCfg).Error
|
||
if err != nil || sysCfg.ServerKey == "" {
|
||
return nil, fmt.Errorf("failed to get server private key: %w", err)
|
||
}
|
||
|
||
var machineKey key.MachinePrivate
|
||
if err = machineKey.UnmarshalText([]byte(sysCfg.ServerKey)); err != nil {
|
||
log.Error().
|
||
Caller().
|
||
Msg("Convert db server key string to machinekey failed: " + err.Error())
|
||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||
}
|
||
return &machineKey, nil
|
||
}
|
||
|
||
/*
|
||
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||
privateKey, err := os.ReadFile(path)
|
||
if errors.Is(err, os.ErrNotExist) {
|
||
log.Info().Str("path", path).Msg("No private key file at path, creating...")
|
||
|
||
machineKey := key.NewMachine()
|
||
|
||
machineKeyStr, err := machineKey.MarshalText()
|
||
if err != nil {
|
||
return nil, fmt.Errorf(
|
||
"failed to convert private key to string for saving: %w",
|
||
err,
|
||
)
|
||
}
|
||
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
||
if err != nil {
|
||
return nil, fmt.Errorf(
|
||
"failed to save private key to disk: %w",
|
||
err,
|
||
)
|
||
}
|
||
|
||
return &machineKey, nil
|
||
} else if err != nil {
|
||
return nil, fmt.Errorf("failed to read private key file: %w", err)
|
||
}
|
||
|
||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||
privateKeyEnsurePrefix := PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
||
|
||
var machineKey key.MachinePrivate
|
||
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
|
||
log.Info().
|
||
Str("path", path).
|
||
Msg("This might be due to a legacy (mirage pre-0.12) private key. " +
|
||
"If the key is in WireGuard format, delete the key and restart mirage. " +
|
||
"A new key will automatically be generated. All Tailscale clients will have to be restarted")
|
||
|
||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||
}
|
||
|
||
return &machineKey, nil
|
||
}
|
||
*/
|