diff --git a/auth/host_session.go b/auth/host_session.go index 4de8d663..d364ec5f 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -67,7 +67,7 @@ func SessionHandler(conn *websocket.Conn) { if len(registerMessage.User) > 0 { // handle basic auth logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "user:", registerMessage.User) - if !servercfg.IsBasicAuthEnabled() { + if !logic.IsBasicAuthEnabled() { err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logger.Log(0, "error during message writing:", err.Error()) @@ -207,7 +207,7 @@ func SessionHandler(conn *websocket.Conn) { netsToAdd = append(netsToAdd, newNet) } } - server := servercfg.GetServerInfo() + server := logic.GetServerInfo() server.TrafficKey = key result.Host.HostPass = "" response := models.RegisterResponse{ diff --git a/config/config.go b/config/config.go index 3a625a2d..0c42227b 100644 --- a/config/config.go +++ b/config/config.go @@ -75,7 +75,6 @@ type ServerConfig struct { NetmakerTenantID string `yaml:"netmaker_tenant_id"` IsPro string `yaml:"is_ee" json:"IsEE"` StunPort int `yaml:"stun_port"` - StunList string `yaml:"stun_list"` TurnServer string `yaml:"turn_server"` TurnApiServer string `yaml:"turn_api_server"` TurnPort int `yaml:"turn_port"` diff --git a/controllers/controller.go b/controllers/controller.go index 93e8bba3..5c97505c 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -3,6 +3,7 @@ package controller import ( "context" "fmt" + "github.com/gravitl/netmaker/db" "net/http" "os" "strings" @@ -18,6 +19,7 @@ import ( // HttpMiddlewares - middleware functions for REST interactions var HttpMiddlewares = []mux.MiddlewareFunc{ + db.Middleware, userMiddleWare, } diff --git a/controllers/dns.go b/controllers/dns.go index c6e08deb..257a683a 100644 --- a/controllers/dns.go +++ b/controllers/dns.go @@ -164,9 +164,9 @@ func createDNS(w http.ResponseWriter, r *http.Request) { return } // check if default domain is appended if not append - if servercfg.GetDefaultDomain() != "" && - !strings.HasSuffix(entry.Name, servercfg.GetDefaultDomain()) { - entry.Name += "." + servercfg.GetDefaultDomain() + if logic.GetDefaultDomain() != "" && + !strings.HasSuffix(entry.Name, logic.GetDefaultDomain()) { + entry.Name += "." + logic.GetDefaultDomain() } entry, err = logic.CreateDNS(entry) if err != nil { @@ -185,7 +185,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) { } } - if servercfg.GetManageDNS() { + if logic.GetManageDNS() { mq.SendDNSSyncByNetwork(netID) } @@ -230,7 +230,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) { } } - if servercfg.GetManageDNS() { + if logic.GetManageDNS() { mq.SendDNSSyncByNetwork(netID) } @@ -293,7 +293,7 @@ func pushDNS(w http.ResponseWriter, r *http.Request) { func syncDNS(w http.ResponseWriter, r *http.Request) { // Set header w.Header().Set("Content-Type", "application/json") - if !servercfg.GetManageDNS() { + if !logic.GetManageDNS() { logic.ReturnErrorResponse( w, r, diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 1da667f3..bcf6fb7f 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -349,7 +349,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { } } // ready the response - server := servercfg.GetServerInfo() + server := logic.GetServerInfo() server.TrafficKey = key response := models.RegisterResponse{ ServerConf: server, diff --git a/controllers/hosts.go b/controllers/hosts.go index e389a364..05bf27be 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -209,7 +209,7 @@ func pull(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - serverConf := servercfg.GetServerInfo() + serverConf := logic.GetServerInfo() key, keyErr := logic.RetrievePublicTrafficKey() if keyErr != nil { logger.Log(0, "error retrieving key:", keyErr.Error()) @@ -230,7 +230,7 @@ func pull(w http.ResponseWriter, r *http.Request) { ChangeDefaultGw: hPU.ChangeDefaultGw, DefaultGwIp: hPU.DefaultGwIp, IsInternetGw: hPU.IsInternetGw, - EndpointDetection: servercfg.IsEndpointDetectionEnabled(), + EndpointDetection: logic.IsEndpointDetectionEnabled(), } logger.Log(1, hostID, "completed a pull") diff --git a/controllers/migrate.go b/controllers/migrate.go index 53bbbc6d..4a836a39 100644 --- a/controllers/migrate.go +++ b/controllers/migrate.go @@ -70,7 +70,7 @@ func migrate(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - server = servercfg.GetServerInfo() + server = logic.GetServerInfo() key, keyErr := logic.RetrievePublicTrafficKey() if keyErr != nil { slog.Error("retrieving traffickey", "error", err) @@ -134,7 +134,7 @@ func convertLegacyHostNode(legacy models.LegacyNode) (models.Host, models.Node) host := models.Host{} host.ID = uuid.New() host.IPForwarding = models.ParseBool(legacy.IPForwarding) - host.AutoUpdate = servercfg.AutoUpdateEnabled() + host.AutoUpdate = logic.AutoUpdateEnabled() host.Interface = "netmaker" host.ListenPort = int(legacy.ListenPort) if host.ListenPort == 0 { diff --git a/controllers/network.go b/controllers/network.go index 3e0c9fde..64109054 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -588,8 +588,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(network.NetID)) logic.CreateDefaultAclNetworkPolicies(models.NetworkID(network.NetID)) logic.CreateDefaultTags(models.NetworkID(network.NetID)) - - go logic.AddNetworkToAllocatedIpMap(network.NetID) + logic.AddNetworkToAllocatedIpMap(network.NetID) go func() { defaultHosts := logic.GetDefaultHosts() diff --git a/controllers/node.go b/controllers/node.go index 13c8874a..ee98165a 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -477,7 +477,7 @@ func getNode(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - server := servercfg.GetServerInfo() + server := logic.GetServerInfo() response := models.NodeGet{ Node: node, Host: *host, diff --git a/controllers/server.go b/controllers/server.go index d41937f5..2b5f7583 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -2,6 +2,7 @@ package controller import ( "encoding/json" + "errors" "net/http" "os" "strings" @@ -12,6 +13,7 @@ import ( "golang.org/x/exp/slog" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" @@ -41,6 +43,10 @@ func serverHandlers(r *mux.Router) { ).Methods(http.MethodPost) r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))). Methods(http.MethodGet) + r.HandleFunc("/api/server/settings", allowUsers(http.HandlerFunc(getSettings))). + Methods(http.MethodGet) + r.HandleFunc("/api/server/settings", logic.SecurityCheck(true, http.HandlerFunc(updateSettings))). + Methods(http.MethodPut) r.HandleFunc("/api/server/getserverinfo", logic.SecurityCheck(true, http.HandlerFunc(getServerInfo))). Methods(http.MethodGet) r.HandleFunc("/api/server/status", getStatus).Methods(http.MethodGet) @@ -207,7 +213,7 @@ func getServerInfo(w http.ResponseWriter, r *http.Request) { // get params - json.NewEncoder(w).Encode(servercfg.GetServerInfo()) + json.NewEncoder(w).Encode(logic.GetServerInfo()) // w.WriteHeader(http.StatusOK) } @@ -222,7 +228,7 @@ func getConfig(w http.ResponseWriter, r *http.Request) { // get params - scfg := servercfg.GetServerConfig() + scfg := logic.GetServerConfig() scfg.IsPro = "no" if servercfg.IsPro { scfg.IsPro = "yes" @@ -230,3 +236,66 @@ func getConfig(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(scfg) // w.WriteHeader(http.StatusOK) } + +// @Summary Get the server settings +// @Router /api/server/settings [get] +// @Tags Server +// @Security oauth2 +// @Success 200 {object} config.ServerSettings +func getSettings(w http.ResponseWriter, r *http.Request) { + scfg := logic.GetServerSettings() + scfg.ClientSecret = logic.Mask() + logic.ReturnSuccessResponseWithJson(w, r, scfg, "fetched server settings successfully") +} + +// @Summary Update the server settings +// @Router /api/server/settings [put] +// @Tags Server +// @Security oauth2 +// @Success 200 {object} config.ServerSettings +func updateSettings(w http.ResponseWriter, r *http.Request) { + var req models.ServerSettings + force := r.URL.Query().Get("force") + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + if !logic.ValidateNewSettings(req) { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid settings"), "badrequest")) + return + } + currSettings := logic.GetServerSettings() + err := logic.UpsertServerSettings(req) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to udpate server settings "+err.Error()), "internal")) + return + } + go reInit(currSettings, req, force == "true") + logic.ReturnSuccessResponseWithJson(w, r, req, "updated server settings successfully") +} + +func reInit(curr, new models.ServerSettings, force bool) { + logic.SettingsMutex.Lock() + defer logic.SettingsMutex.Unlock() + logic.InitializeAuthProvider() + logic.EmailInit() + logic.SetVerbosity(int(logic.GetServerSettings().Verbosity)) + // check if auto update is changed + if force { + if curr.NetclientAutoUpdate != new.NetclientAutoUpdate { + // update all hosts + hosts, _ := logic.GetAllHosts() + for _, host := range hosts { + host.AutoUpdate = new.NetclientAutoUpdate + logic.UpsertHost(&host) + mq.HostUpdate(&models.HostUpdate{ + Action: models.UpdateHost, + Host: host, + }) + } + } + } + go mq.PublishPeerUpdate(false) + +} diff --git a/controllers/user.go b/controllers/user.go index 55b79699..8ee3c08a 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -217,16 +217,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { var errorResponse = models.ErrorResponse{ Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", } - - if !servercfg.IsBasicAuthEnabled() { - logic.ReturnErrorResponse( - response, - request, - logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"), - ) - return - } - decoder := json.NewDecoder(request.Body) decoderErr := decoder.Decode(&authRequest) defer request.Body.Close() @@ -236,15 +226,29 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { logic.ReturnErrorResponse(response, request, errorResponse) return } + user, err := logic.GetUser(authRequest.UserName) + if err != nil { + logger.Log(0, authRequest.UserName, "user validation failed: ", + err.Error()) + logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized")) + return + } + if logic.IsOauthUser(user) == nil { + logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("user is registered via SSO"), "badrequest")) + return + } + if !user.IsSuperAdmin && !logic.IsBasicAuthEnabled() { + logic.ReturnErrorResponse( + response, + request, + logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"), + ) + return + } + if val := request.Header.Get("From-Ui"); val == "true" { // request came from UI, if normal user block Login - user, err := logic.GetUser(authRequest.UserName) - if err != nil { - logger.Log(0, authRequest.UserName, "user validation failed: ", - err.Error()) - logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized")) - return - } + role, err := logic.GetRole(user.PlatformRoleID) if err != nil { logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("access denied to dashboard"), "unauthorized")) @@ -255,15 +259,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { return } } - user, err := logic.GetUser(authRequest.UserName) - if err != nil { - logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized")) - return - } - if logic.IsOauthUser(user) == nil { - logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("user is registered via SSO"), "badrequest")) - return - } + username := authRequest.UserName jwt, err := logic.VerifyAuthRequest(authRequest) if err != nil { @@ -305,7 +301,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { response.Write(successJSONResponse) go func() { - if servercfg.IsPro && servercfg.GetRacAutoDisable() { + if servercfg.IsPro && logic.GetRacAutoDisable() { // enable all associeated clients for the user clients, err := logic.GetAllExtClients() if err != nil { @@ -479,7 +475,7 @@ func createSuperAdmin(w http.ResponseWriter, r *http.Request) { return } - if !servercfg.IsBasicAuthEnabled() { + if !logic.IsBasicAuthEnabled() { logic.ReturnErrorResponse( w, r, @@ -527,7 +523,7 @@ func transferSuperAdmin(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only admins can be promoted to superadmin role"), "forbidden")) return } - if !servercfg.IsBasicAuthEnabled() { + if !logic.IsBasicAuthEnabled() { logic.ReturnErrorResponse( w, r, diff --git a/database/database.go b/database/database.go index 86111b98..483eb35f 100644 --- a/database/database.go +++ b/database/database.go @@ -69,6 +69,8 @@ const ( TAG_TABLE_NAME = "tags" // PEER_ACK_TABLE - table for failover peer ack PEER_ACK_TABLE = "peer_ack" + // SERVER_SETTINGS - table for server settings + SERVER_SETTINGS = "server_settings" // == ERROR CONSTS == // NO_RECORD - no singular result found NO_RECORD = "no result found" @@ -125,7 +127,7 @@ var Tables = []string{ TAG_TABLE_NAME, ACLS_TABLE_NAME, PEER_ACK_TABLE, - // ACCESS_TOKENS_TABLE_NAME, + SERVER_SETTINGS, } func getCurrentDB() map[string]interface{} { diff --git a/go.mod b/go.mod index d53aa885..60e8a9c1 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,9 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.8.1 gopkg.in/mail.v2 v2.3.1 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 + gorm.io/gorm v1.25.12 ) require ( @@ -63,12 +66,10 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/kr/text v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/seancfoley/bintree v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect - gorm.io/driver/postgres v1.5.11 // indirect - gorm.io/driver/sqlite v1.5.7 // indirect - gorm.io/gorm v1.25.12 // indirect ) require ( diff --git a/go.sum b/go.sum index c99f9eb5..f086e123 100644 --- a/go.sum +++ b/go.sum @@ -61,9 +61,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -89,6 +88,8 @@ github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa h1:hxMLFbj+F444JAS5nUQxTDZwUxwCRqg3WkNqhiDzXrM= github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa/go.mod h1:xF/KoXmrRyahPfo5L7Szb5cAAUl53dMWBh9cMruGEZg= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -115,8 +116,6 @@ go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= @@ -135,8 +134,6 @@ golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -147,8 +144,6 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -162,8 +157,6 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -176,8 +169,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20221104135756-97bc4ad4a1cb/go.mod h1:m gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/mail.v2 v2.3.1 h1:WYFn/oANrAGP2C0dcV6/pbkPzv8yGzqTjPmTeO7qoXk= gopkg.in/mail.v2 v2.3.1/go.mod h1:htwXN1Qh09vZJ1NVKxQqHPBaCBbzKhp5GzuJEA4VJWw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logic/auth.go b/logic/auth.go index 11d23ce0..611cb0f0 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -31,6 +31,8 @@ func ClearSuperUserCache() { superUser = models.User{} } +var InitializeAuthProvider = func() string { return "" } + // HasSuperAdmin - checks if server has an superadmin/owner func HasSuperAdmin() (bool, error) { diff --git a/logic/dns.go b/logic/dns.go index 7e0241d2..23bfed15 100644 --- a/logic/dns.go +++ b/logic/dns.go @@ -12,7 +12,6 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/servercfg" "github.com/txn2/txeh" ) @@ -106,7 +105,7 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) { if err != nil { return dns, err } - defaultDomain := servercfg.GetDefaultDomain() + defaultDomain := GetDefaultDomain() for _, node := range nodes { if node.Network != network { continue diff --git a/logic/hosts.go b/logic/hosts.go index 4d1112c1..683bf06c 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -228,7 +228,7 @@ func CreateHost(h *models.Host) error { return err } h.HostPass = string(hash) - h.AutoUpdate = servercfg.AutoUpdateEnabled() + h.AutoUpdate = AutoUpdateEnabled() checkForZombieHosts(h) return UpsertHost(h) } diff --git a/logic/jwts.go b/logic/jwts.go index 4c8df6ea..2dc1cf82 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -61,8 +61,8 @@ func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Ti UserName: username, Role: role, TokenType: models.AccessTokenType, - Api: servercfg.ServerInfo.APIHost, - RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole), + Api: servercfg.GetAPIHost(), + RacAutoDisable: GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole), RegisteredClaims: jwt.RegisteredClaims{ Issuer: "Netmaker", Subject: fmt.Sprintf("user|%s", username), @@ -82,12 +82,13 @@ func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Ti // CreateUserJWT - creates a user jwt token func CreateUserJWT(username string, role models.UserRoleID) (response string, err error) { - expirationTime := time.Now().Add(servercfg.GetServerConfig().JwtValidityDuration) + settings := GetServerSettings() + expirationTime := time.Now().Add(time.Duration(settings.JwtValidityDuration) * time.Minute) claims := &models.UserClaims{ UserName: username, Role: role, TokenType: models.UserIDTokenType, - RacAutoDisable: servercfg.GetRacAutoDisable() && (role != models.SuperAdminRole && role != models.AdminRole), + RacAutoDisable: settings.RacAutoDisable && (role != models.SuperAdminRole && role != models.AdminRole), RegisteredClaims: jwt.RegisteredClaims{ Issuer: "Netmaker", Subject: fmt.Sprintf("user|%s", username), diff --git a/logic/peers.go b/logic/peers.go index e41d0592..b7d1c452 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -157,7 +157,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N Peers: []wgtypes.PeerConfig{}, NodePeers: []wgtypes.PeerConfig{}, HostNetworkInfo: models.HostInfoMap{}, - ServerConfig: servercfg.ServerInfo, + ServerConfig: GetServerInfo(), } defer func() { if !hostPeerUpdate.FwUpdate.AllowAll { diff --git a/logic/settings.go b/logic/settings.go new file mode 100644 index 00000000..4b062743 --- /dev/null +++ b/logic/settings.go @@ -0,0 +1,335 @@ +package logic + +import ( + "encoding/json" + "os" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/gravitl/netmaker/config" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" +) + +var serverSettingsDBKey = "server_cfg" +var SettingsMutex = &sync.RWMutex{} + +func GetServerSettings() (s models.ServerSettings) { + data, err := database.FetchRecord(database.SERVER_SETTINGS, serverSettingsDBKey) + if err != nil { + return + } + json.Unmarshal([]byte(data), &s) + return +} + +func UpsertServerSettings(s models.ServerSettings) error { + // get curr settings + currSettings := GetServerSettings() + if s.ClientSecret == Mask() { + s.ClientSecret = currSettings.ClientSecret + } + data, err := json.Marshal(s) + if err != nil { + return err + } + err = database.Insert(serverSettingsDBKey, string(data), database.SERVER_SETTINGS) + if err != nil { + return err + } + return nil +} + +func ValidateNewSettings(req models.ServerSettings) bool { + // TODO: add checks for different fields + return true +} + +func GetServerSettingsFromEnv() (s models.ServerSettings) { + + s = models.ServerSettings{ + NetclientAutoUpdate: servercfg.AutoUpdateEnabled(), + Verbosity: servercfg.GetVerbosity(), + AuthProvider: os.Getenv("AUTH_PROVIDER"), + OIDCIssuer: os.Getenv("OIDC_ISSUER"), + ClientID: os.Getenv("CLIENT_ID"), + ClientSecret: os.Getenv("CLIENT_SECRET"), + AzureTenant: servercfg.GetAzureTenant(), + Telemetry: servercfg.Telemetry(), + BasicAuth: servercfg.IsBasicAuthEnabled(), + JwtValidityDuration: servercfg.GetJwtValidityDurationFromEnv() / 60, + RacAutoDisable: servercfg.GetRacAutoDisable(), + RacRestrictToSingleNetwork: servercfg.GetRacRestrictToSingleNetwork(), + EndpointDetection: servercfg.IsEndpointDetectionEnabled(), + AllowedEmailDomains: servercfg.GetAllowedEmailDomains(), + EmailSenderAddr: servercfg.GetSenderEmail(), + EmailSenderUser: servercfg.GetSenderUser(), + EmailSenderPassword: servercfg.GetEmaiSenderPassword(), + SmtpHost: servercfg.GetSmtpHost(), + SmtpPort: servercfg.GetSmtpPort(), + MetricInterval: servercfg.GetMetricInterval(), + MetricsPort: servercfg.GetMetricsPort(), + ManageDNS: servercfg.GetManageDNS(), + DefaultDomain: servercfg.GetDefaultDomain(), + Stun: servercfg.IsStunEnabled(), + StunServers: servercfg.GetStunServers(), + TextSize: "16", + Theme: models.Dark, + ReducedMotion: false, + } + + return +} + +// GetServerConfig - gets the server config into memory from file or env +func GetServerConfig() config.ServerConfig { + var cfg config.ServerConfig + settings := GetServerSettings() + cfg.APIConnString = servercfg.GetAPIConnString() + cfg.CoreDNSAddr = servercfg.GetCoreDNSAddr() + cfg.APIHost = servercfg.GetAPIHost() + cfg.APIPort = servercfg.GetAPIPort() + cfg.MasterKey = "(hidden)" + cfg.DNSKey = "(hidden)" + cfg.AllowedOrigin = servercfg.GetAllowedOrigin() + cfg.RestBackend = "off" + cfg.NodeID = servercfg.GetNodeID() + cfg.BrokerType = servercfg.GetBrokerType() + cfg.EmqxRestEndpoint = servercfg.GetEmqxRestEndpoint() + if settings.NetclientAutoUpdate { + cfg.NetclientAutoUpdate = "enabled" + } else { + cfg.NetclientAutoUpdate = "disabled" + } + if servercfg.IsRestBackend() { + cfg.RestBackend = "on" + } + cfg.DNSMode = "off" + if servercfg.IsDNSMode() { + cfg.DNSMode = "on" + } + cfg.DisplayKeys = "off" + if servercfg.IsDisplayKeys() { + cfg.DisplayKeys = "on" + } + cfg.DisableRemoteIPCheck = "off" + if servercfg.DisableRemoteIPCheck() { + cfg.DisableRemoteIPCheck = "on" + } + cfg.Database = servercfg.GetDB() + cfg.Platform = servercfg.GetPlatform() + cfg.Version = servercfg.GetVersion() + cfg.PublicIp = servercfg.GetServerHostIP() + + // == auth config == + var authInfo = GetAuthProviderInfo(settings) + cfg.AuthProvider = authInfo[0] + cfg.ClientID = authInfo[1] + cfg.ClientSecret = authInfo[2] + cfg.FrontendURL = servercfg.GetFrontendURL() + cfg.AzureTenant = settings.AzureTenant + cfg.Telemetry = settings.Telemetry + cfg.Server = servercfg.GetServer() + cfg.Verbosity = settings.Verbosity + cfg.IsPro = "no" + if servercfg.IsPro { + cfg.IsPro = "yes" + } + cfg.JwtValidityDuration = time.Duration(settings.JwtValidityDuration) * time.Minute + cfg.RacAutoDisable = settings.RacAutoDisable + cfg.RacRestrictToSingleNetwork = settings.RacRestrictToSingleNetwork + cfg.MetricInterval = settings.MetricInterval + cfg.ManageDNS = settings.ManageDNS + cfg.Stun = settings.Stun + cfg.StunServers = settings.StunServers + cfg.DefaultDomain = settings.DefaultDomain + return cfg +} + +// GetServerInfo - gets the server config into memory from file or env +func GetServerInfo() models.ServerConfig { + var cfg models.ServerConfig + serverSettings := GetServerSettings() + cfg.Server = servercfg.GetServer() + if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { + cfg.MQUserName = "HOST_ID" + cfg.MQPassword = "HOST_PASS" + } else { + cfg.MQUserName = servercfg.GetMqUserName() + cfg.MQPassword = servercfg.GetMqPassword() + } + cfg.API = servercfg.GetAPIConnString() + cfg.CoreDNSAddr = servercfg.GetCoreDNSAddr() + cfg.APIPort = servercfg.GetAPIPort() + cfg.DNSMode = "off" + cfg.Broker = servercfg.GetPublicBrokerEndpoint() + cfg.BrokerType = servercfg.GetBrokerType() + if servercfg.IsDNSMode() { + cfg.DNSMode = "on" + } + cfg.Version = servercfg.GetVersion() + cfg.IsPro = servercfg.IsPro + cfg.MetricInterval = serverSettings.MetricInterval + cfg.MetricsPort = serverSettings.MetricsPort + cfg.ManageDNS = serverSettings.ManageDNS + cfg.Stun = serverSettings.Stun + cfg.StunServers = serverSettings.StunServers + cfg.DefaultDomain = serverSettings.DefaultDomain + cfg.EndpointDetection = serverSettings.EndpointDetection + return cfg +} + +// GetDefaultDomain - get the default domain +func GetDefaultDomain() string { + return GetServerSettings().DefaultDomain +} + +func ValidateDomain(domain string) bool { + domainPattern := `[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}(\.[a-zA-Z0-9][a-zA-Z0-9_-]{0,62})*(\.[a-zA-Z][a-zA-Z0-9]{0,10}){1}` + + exp := regexp.MustCompile("^" + domainPattern + "$") + + return exp.MatchString(domain) +} + +// Telemetry - checks if telemetry data should be sent +func Telemetry() string { + return GetServerSettings().Telemetry +} + +// GetJwtValidityDuration - returns the JWT validity duration in minutes +func GetJwtValidityDuration() time.Duration { + return GetServerConfig().JwtValidityDuration +} + +// GetRacAutoDisable - returns whether the feature to autodisable RAC is enabled +func GetRacAutoDisable() bool { + return GetServerSettings().RacAutoDisable +} + +// GetRacRestrictToSingleNetwork - returns whether the feature to allow simultaneous network connections via RAC is enabled +func GetRacRestrictToSingleNetwork() bool { + return GetServerSettings().RacRestrictToSingleNetwork +} + +func GetSmtpHost() string { + return GetServerSettings().SmtpHost +} + +func GetSmtpPort() int { + return GetServerSettings().SmtpPort +} + +func GetSenderEmail() string { + return GetServerSettings().EmailSenderAddr +} + +func GetSenderUser() string { + return GetServerSettings().EmailSenderUser +} + +func GetEmaiSenderPassword() string { + return GetServerSettings().EmailSenderPassword +} + +// AutoUpdateEnabled returns a boolean indicating whether netclient auto update is enabled or disabled +// default is enabled +func AutoUpdateEnabled() bool { + return GetServerSettings().NetclientAutoUpdate +} + +// GetAuthProviderInfo = gets the oauth provider info +func GetAuthProviderInfo(settings models.ServerSettings) (pi []string) { + var authProvider = "" + + defer func() { + if authProvider == "oidc" { + if settings.OIDCIssuer != "" { + pi = append(pi, settings.OIDCIssuer) + } else { + pi = []string{"", "", ""} + } + } + }() + + if settings.AuthProvider != "" && settings.ClientID != "" && settings.ClientSecret != "" { + authProvider = strings.ToLower(settings.AuthProvider) + if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" || authProvider == "oidc" { + return []string{authProvider, settings.ClientID, settings.ClientSecret} + } else { + authProvider = "" + } + } + return []string{"", "", ""} +} + +// GetAzureTenant - retrieve the azure tenant ID from env variable or config file +func GetAzureTenant() string { + return GetServerSettings().AzureTenant +} + +// GetMetricsPort - get metrics port +func GetMetricsPort() int { + return GetServerSettings().MetricsPort +} + +// GetMetricInterval - get the publish metric interval +func GetMetricIntervalInMinutes() time.Duration { + //default 15 minutes + mi := "15" + if os.Getenv("PUBLISH_METRIC_INTERVAL") != "" { + mi = os.Getenv("PUBLISH_METRIC_INTERVAL") + } + interval, err := strconv.Atoi(mi) + if err != nil { + interval = 15 + } + + return time.Duration(interval) * time.Minute +} + +// GetMetricInterval - get the publish metric interval +func GetMetricInterval() string { + return GetServerSettings().MetricInterval +} + +// GetManageDNS - if manage DNS enabled or not +func GetManageDNS() bool { + return GetServerSettings().ManageDNS +} + +// IsBasicAuthEnabled - checks if basic auth has been configured to be turned off +func IsBasicAuthEnabled() bool { + return GetServerSettings().BasicAuth +} + +// IsEndpointDetectionEnabled - returns true if endpoint detection enabled +func IsEndpointDetectionEnabled() bool { + return GetServerSettings().EndpointDetection +} + +// IsStunEnabled - returns true if STUN set to on +func IsStunEnabled() bool { + return GetServerSettings().Stun +} + +func GetStunServers() string { + return GetServerSettings().StunServers +} + +// GetAllowedEmailDomains - gets the allowed email domains for oauth signup +func GetAllowedEmailDomains() string { + return GetServerSettings().AllowedEmailDomains +} + +func GetVerbosity() int32 { + return GetServerSettings().Verbosity +} + +func Mask() string { + return ("..................") +} diff --git a/logic/telemetry.go b/logic/telemetry.go index c0a41bab..de9f9088 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -7,6 +7,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" "github.com/posthog/posthog-go" @@ -33,7 +34,7 @@ func SetFreeTierForTelemetry(freeTierFlag bool) { // sendTelemetry - gathers telemetry data and sends to posthog func sendTelemetry() error { - if servercfg.Telemetry() == "off" { + if Telemetry() == "off" { return nil } diff --git a/logic/user_mgmt.go b/logic/user_mgmt.go index 3f7bea9a..7eb3de7b 100644 --- a/logic/user_mgmt.go +++ b/logic/user_mgmt.go @@ -62,6 +62,7 @@ var CreateDefaultUserPolicies = func(netID models.NetworkID) {} var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return } var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return } var AddGlobalNetRolesToAdmins = func(u *models.User) {} +var EmailInit = func() {} // GetRole - fetches role template by id func GetRole(roleID models.UserRoleID) (models.UserRolePermissionTemplate, error) { diff --git a/migrate/migrate.go b/migrate/migrate.go index 94ee6d6d..1b1f8b65 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -20,6 +20,7 @@ import ( // Run - runs all migrations func Run() { + settings() updateEnrollmentKeys() assignSuperAdmin() createDefaultTagsAndPolicies() @@ -498,3 +499,10 @@ func migrateToGws() { logic.DeleteTag(models.TagID(fmt.Sprintf("%s.%s", netI.NetID, models.OldRemoteAccessTagName)), true) } } + +func settings() { + _, err := database.FetchRecords(database.SERVER_SETTINGS) + if database.IsEmptyRecord(err) { + logic.UpsertServerSettings(logic.GetServerSettingsFromEnv()) + } +} diff --git a/models/settings.go b/models/settings.go new file mode 100644 index 00000000..c7aa394c --- /dev/null +++ b/models/settings.go @@ -0,0 +1,40 @@ +package models + +type Theme string + +const ( + Dark Theme = "dark" + Light Theme = "light" + System Theme = "system" +) + +type ServerSettings struct { + NetclientAutoUpdate bool `json:"netclientautoupdate"` + Verbosity int32 `json:"verbosity"` + AuthProvider string `json:"authprovider"` + OIDCIssuer string `json:"oidcissuer"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + AzureTenant string `json:"azure_tenant"` + Telemetry string `json:"telemetry"` + BasicAuth bool `json:"basic_auth"` + JwtValidityDuration int `json:"jwt_validity_duration"` + RacAutoDisable bool `json:"rac_auto_disable"` + RacRestrictToSingleNetwork bool `json:"rac_restrict_to_single_network"` + EndpointDetection bool `json:"endpoint_detection"` + AllowedEmailDomains string `json:"allowed_email_domains"` + EmailSenderAddr string `json:"email_sender_addr"` + EmailSenderUser string `json:"email_sender_user"` + EmailSenderPassword string `json:"email_sender_password"` + SmtpHost string `json:"smtp_host"` + SmtpPort int `json:"smtp_port"` + MetricInterval string `json:"metric_interval"` + MetricsPort int `json:"metrics_port"` + ManageDNS bool `json:"manage_dns"` + DefaultDomain string `json:"default_domain"` + Stun bool `json:"stun"` + StunServers string `json:"stun_servers"` + Theme Theme `json:"theme"` + TextSize string `json:"text_size"` + ReducedMotion bool `json:"reduced_motion"` +} diff --git a/mq/publishers.go b/mq/publishers.go index cbebf60f..3542ef8c 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -21,7 +21,7 @@ func PublishPeerUpdate(replacePeers bool) error { return nil } - if servercfg.GetManageDNS() { + if logic.GetManageDNS() { sendDNSSync() } diff --git a/pro/auth/auth.go b/pro/auth/auth.go index 162a6ce2..215a6263 100644 --- a/pro/auth/auth.go +++ b/pro/auth/auth.go @@ -47,7 +47,7 @@ var ( ) func getCurrentAuthFunctions() map[string]interface{} { - var authInfo = servercfg.GetAuthProviderInfo() + var authInfo = logic.GetAuthProviderInfo(logic.GetServerSettings()) var authProvider = authInfo[0] switch authProvider { case google_provider_name: @@ -74,7 +74,7 @@ func InitializeAuthProvider() string { if err != nil { logger.FatalLog("failed to set auth_secret", err.Error()) } - var authInfo = servercfg.GetAuthProviderInfo() + var authInfo = logic.GetAuthProviderInfo(logic.GetServerSettings()) var serverConn = servercfg.GetAPIHost() if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") { serverConn = "http://" + serverConn @@ -275,7 +275,7 @@ func isStateCached(state string) bool { // isEmailAllowed - checks if email is allowed to signup func isEmailAllowed(email string) bool { - allowedDomains := servercfg.GetAllowedEmailDomains() + allowedDomains := logic.GetAllowedEmailDomains() domains := strings.Split(allowedDomains, ",") if len(domains) == 1 && domains[0] == "*" { return true diff --git a/pro/auth/azure-ad.go b/pro/auth/azure-ad.go index b27b486d..dbcaae5a 100644 --- a/pro/auth/azure-ad.go +++ b/pro/auth/azure-ad.go @@ -35,7 +35,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) { ClientID: clientID, ClientSecret: clientSecret, Scopes: []string{"User.Read", "email", "profile", "openid"}, - Endpoint: microsoft.AzureADEndpoint(servercfg.GetAzureTenant()), + Endpoint: microsoft.AzureADEndpoint(logic.GetAzureTenant()), } } diff --git a/pro/email/email.go b/pro/email/email.go index cde69826..12870c6a 100644 --- a/pro/email/email.go +++ b/pro/email/email.go @@ -4,7 +4,7 @@ import ( "context" "regexp" - "github.com/gravitl/netmaker/servercfg" + "github.com/gravitl/netmaker/logic" ) type EmailSenderType string @@ -16,14 +16,14 @@ const ( Resend EmailSenderType = "resend" ) -func init() { +func Init() { smtpSender := &SmtpSender{ - SmtpHost: servercfg.GetSmtpHost(), - SmtpPort: servercfg.GetSmtpPort(), - SenderEmail: servercfg.GetSenderEmail(), - SendUser: servercfg.GetSenderUser(), - SenderPass: servercfg.GetEmaiSenderPassword(), + SmtpHost: logic.GetSmtpHost(), + SmtpPort: logic.GetSmtpPort(), + SenderEmail: logic.GetSenderEmail(), + SendUser: logic.GetSenderUser(), + SenderPass: logic.GetEmaiSenderPassword(), } if smtpSender.SendUser == "" { smtpSender.SendUser = smtpSender.SenderEmail diff --git a/pro/initialize.go b/pro/initialize.go index f54cde61..85ebeae5 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -13,6 +13,7 @@ import ( "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/pro/auth" proControllers "github.com/gravitl/netmaker/pro/controllers" + "github.com/gravitl/netmaker/pro/email" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" @@ -79,7 +80,7 @@ func InitPro() { addTrialLicenseHook() } - if servercfg.GetServerConfig().RacAutoDisable { + if logic.GetRacAutoDisable() { AddRacHooks() } @@ -91,6 +92,7 @@ func InitPro() { } proLogic.LoadNodeMetricsToCache() proLogic.InitFailOverCache() + email.Init() }) logic.ResetFailOver = proLogic.ResetFailOver logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer @@ -135,6 +137,8 @@ func InitPro() { logic.GetUserGroupsInNetwork = proLogic.GetUserGroupsInNetwork logic.GetUserGroup = proLogic.GetUserGroup logic.GetNodeStatus = proLogic.GetNodeStatus + logic.InitializeAuthProvider = auth.InitializeAuthProvider + logic.EmailInit = email.Init } func retrieveProLogo() string { diff --git a/pro/license.go b/pro/license.go index c623a240..aaa16259 100644 --- a/pro/license.go +++ b/pro/license.go @@ -9,11 +9,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gravitl/netmaker/utils" "io" "net/http" "time" + "github.com/gravitl/netmaker/utils" + "golang.org/x/crypto/nacl/box" "golang.org/x/exp/slog" diff --git a/pro/remote_access_client.go b/pro/remote_access_client.go index 16a38a85..b9266c84 100644 --- a/pro/remote_access_client.go +++ b/pro/remote_access_client.go @@ -10,7 +10,6 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" - "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" ) @@ -41,7 +40,7 @@ func racAutoDisableHook() error { } currentTime := time.Now() - validityDuration := servercfg.GetJwtValidityDuration() + validityDuration := logic.GetJwtValidityDuration() for _, user := range users { if user.PlatformRoleID == models.AdminRole || user.PlatformRoleID == models.SuperAdminRole { diff --git a/schema/jobs.go b/schema/job.go similarity index 64% rename from schema/jobs.go rename to schema/job.go index 1e9e13f6..2731ef8d 100644 --- a/schema/jobs.go +++ b/schema/job.go @@ -16,21 +16,16 @@ import ( // that it is easier to prevent a task from // being executed again. type Job struct { - ID string `gorm:"id;primary_key"` - CreatedAt time.Time `gorm:"created_at"` -} - -// TableName returns the name of the jobs table. -func (j *Job) TableName() string { - return "jobs" + ID string `gorm:"primaryKey"` + CreatedAt time.Time } // Create creates a job record in the jobs table. func (j *Job) Create(ctx context.Context) error { - return db.FromContext(ctx).Table(j.TableName()).Create(j).Error + return db.FromContext(ctx).Model(&Job{}).Create(j).Error } // Get returns a job record with the given Job.ID. func (j *Job) Get(ctx context.Context) error { - return db.FromContext(ctx).Table(j.TableName()).Where("id = ?", j.ID).First(j).Error + return db.FromContext(ctx).Model(&Job{}).Where("id = ?", j.ID).First(j).Error } diff --git a/schema/accessToken.go b/schema/user_access_token.go similarity index 100% rename from schema/accessToken.go rename to schema/user_access_token.go diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index cc3fc814..e00dcf85 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -2,6 +2,7 @@ package servercfg import ( "errors" + "fmt" "io" "net/http" "os" @@ -11,11 +12,8 @@ import ( "time" "github.com/gravitl/netmaker/config" - "github.com/gravitl/netmaker/models" ) -var ServerInfo = GetServerInfo() - // EmqxBrokerType denotes the broker type for EMQX MQTT const EmqxBrokerType = "emqx" @@ -116,6 +114,18 @@ func GetJwtValidityDuration() time.Duration { return defaultDuration } +// GetJwtValidityDuration - returns the JWT validity duration in seconds +func GetJwtValidityDurationFromEnv() int { + var defaultDuration = 43200 + if os.Getenv("JWT_VALIDITY_DURATION") != "" { + t, err := strconv.Atoi(os.Getenv("JWT_VALIDITY_DURATION")) + if err == nil { + return t + } + } + return defaultDuration +} + // GetRacAutoDisable - returns whether the feature to autodisable RAC is enabled func GetRacAutoDisable() bool { return os.Getenv("RAC_AUTO_DISABLE") == "true" @@ -126,39 +136,6 @@ func GetRacRestrictToSingleNetwork() bool { return os.Getenv("RAC_RESTRICT_TO_SINGLE_NETWORK") == "true" } -// GetServerInfo - gets the server config into memory from file or env -func GetServerInfo() models.ServerConfig { - var cfg models.ServerConfig - cfg.Server = GetServer() - if GetBrokerType() == EmqxBrokerType { - cfg.MQUserName = "HOST_ID" - cfg.MQPassword = "HOST_PASS" - } else { - cfg.MQUserName = GetMqUserName() - cfg.MQPassword = GetMqPassword() - } - cfg.APIHost = GetAPIHost() - cfg.API = GetAPIConnString() - cfg.CoreDNSAddr = GetCoreDNSAddr() - cfg.APIPort = GetAPIPort() - cfg.DNSMode = "off" - cfg.Broker = GetPublicBrokerEndpoint() - cfg.BrokerType = GetBrokerType() - if IsDNSMode() { - cfg.DNSMode = "on" - } - cfg.Version = GetVersion() - cfg.IsPro = IsPro - cfg.MetricInterval = GetMetricInterval() - cfg.MetricsPort = GetMetricsPort() - cfg.ManageDNS = GetManageDNS() - cfg.Stun = IsStunEnabled() - cfg.StunServers = GetStunServers() - cfg.DefaultDomain = GetDefaultDomain() - cfg.EndpointDetection = IsEndpointDetectionEnabled() - return cfg -} - // GetFrontendURL - gets the frontend url func GetFrontendURL() string { var frontend = "" @@ -167,6 +144,9 @@ func GetFrontendURL() string { } else if config.Config.Server.FrontendURL != "" { frontend = config.Config.Server.FrontendURL } + if frontend == "" { + return fmt.Sprintf("https://dashboard.%s", GetNmBaseDomain()) + } return frontend }