diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 71e9313d..dc6669bd 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -49,12 +49,12 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { ret := []*models.EnrollmentKey{} for _, key := range keys { key := key - if err = logic.Tokenize(key, servercfg.GetAPIHost()); err != nil { + if err = logic.Tokenize(&key, servercfg.GetAPIHost()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - ret = append(ret, key) + ret = append(ret, &key) } // return JSON/API formatted keys logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") diff --git a/controllers/hosts.go b/controllers/hosts.go index ebcd8059..7a5d3912 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -604,7 +604,7 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) { errorResponse.Code = http.StatusBadRequest errorResponse.Message = err.Error() logger.Log(0, request.Header.Get("user"), - "error retrieving host: ", err.Error()) + "error retrieving host: ", authRequest.ID, err.Error()) logic.ReturnErrorResponse(response, request, errorResponse) return } diff --git a/controllers/network.go b/controllers/network.go index a7050771..2c2f2b43 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -408,6 +408,8 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype)) return } + //delete network from allocated ip map + go logic.RemoveNetworkFromAllocatedIpMap(network) logger.Log(1, r.Header.Get("user"), "deleted network", network) w.WriteHeader(http.StatusOK) @@ -480,6 +482,10 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + + //add new network to allocated ip map + go logic.AddNetworkToAllocatedIpMap(network.NetID) + go func() { defaultHosts := logic.GetDefaultHosts() for i := range defaultHosts { diff --git a/controllers/network_test.go b/controllers/network_test.go index 8678f408..aed9e288 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -46,8 +46,8 @@ func TestCreateNetwork(t *testing.T) { deleteAllNetworks() var network models.Network - network.NetID = "skynet" - network.AddressRange = "10.0.0.1/24" + network.NetID = "skynet1" + network.AddressRange = "10.10.0.1/24" // if tests break - check here (removed displayname) //network.DisplayName = "mynetwork" diff --git a/controllers/node.go b/controllers/node.go index 739dee46..09cce98f 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -93,7 +93,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Code = http.StatusBadRequest errorResponse.Message = err.Error() logger.Log(0, request.Header.Get("user"), - "error retrieving host: ", err.Error()) + "error retrieving host: ", result.HostID.String(), err.Error()) logic.ReturnErrorResponse(response, request, errorResponse) return } diff --git a/controllers/user_test.go b/controllers/user_test.go index c99b8c47..b1517ac4 100644 --- a/controllers/user_test.go +++ b/controllers/user_test.go @@ -137,6 +137,7 @@ func TestCreateUser(t *testing.T) { func TestCreateSuperAdmin(t *testing.T) { deleteAllUsers(t) + logic.ClearSuperUserCache() var user models.User t.Run("NoSuperAdmin", func(t *testing.T) { user.UserName = "admin" diff --git a/logic/auth.go b/logic/auth.go index e2a660db..87fc3095 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -20,9 +20,21 @@ const ( auth_key = "netmaker_auth" ) +var ( + superUser = models.User{} +) + +func ClearSuperUserCache() { + superUser = models.User{} +} + // HasSuperAdmin - checks if server has an superadmin/owner func HasSuperAdmin() (bool, error) { + if superUser.IsSuperAdmin { + return true, nil + } + collection, err := database.FetchRecords(database.USERS_TABLE_NAME) if err != nil { if database.IsEmptyRecord(err) { @@ -38,6 +50,7 @@ func HasSuperAdmin() (bool, error) { continue } if user.IsSuperAdmin { + superUser = user return true, nil } } @@ -116,7 +129,7 @@ func CreateUser(user *models.User) error { tokenString, _ := CreateUserJWT(user.UserName, user.IsSuperAdmin, user.IsAdmin) if tokenString == "" { - logger.Log(0, "failed to generate token", err.Error()) + logger.Log(0, "failed to generate token") return err } @@ -204,6 +217,9 @@ func UpsertUser(user models.User) error { slog.Error("error inserting user", "user", user.UserName, "error", err.Error()) return err } + if user.IsSuperAdmin { + superUser = user + } return nil } diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index ae5d01d5..d3c48a01 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "sync" "time" "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slices" ) @@ -29,6 +31,10 @@ var EnrollmentErrors = struct { FailedToTokenize: fmt.Errorf("failed to tokenize"), FailedToDeTokenize: fmt.Errorf("failed to detokenize"), } +var ( + enrollmentkeyCacheMutex = &sync.RWMutex{} + enrollmentkeyCacheMap = make(map[string]models.EnrollmentKey) +) // CreateEnrollmentKey - creates a new enrollment key in db func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { @@ -104,21 +110,21 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey key.Relay = relayId - if err = upsertEnrollmentKey(key); err != nil { + if err = upsertEnrollmentKey(&key); err != nil { return nil, err } - return key, nil + return &key, nil } // GetAllEnrollmentKeys - fetches all enrollment keys from DB // TODO drop double pointer -func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { +func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) { currentKeys, err := getEnrollmentKeysMap() if err != nil { return nil, err } - var currentKeysList = []*models.EnrollmentKey{} + var currentKeysList = []models.EnrollmentKey{} for k := range currentKeys { currentKeysList = append(currentKeysList, currentKeys[k]) } @@ -127,15 +133,21 @@ func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { // GetEnrollmentKey - fetches a single enrollment key // returns nil and error if not found -func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { +func GetEnrollmentKey(value string) (key models.EnrollmentKey, err error) { currentKeys, err := getEnrollmentKeysMap() if err != nil { - return nil, err + return key, err } if key, ok := currentKeys[value]; ok { return key, nil } - return nil, EnrollmentErrors.NoKeyFound + return key, EnrollmentErrors.NoKeyFound +} + +func deleteEnrollmentkeyFromCache(key string) { + enrollmentkeyCacheMutex.Lock() + delete(enrollmentkeyCacheMap, key) + enrollmentkeyCacheMutex.Unlock() } // DeleteEnrollmentKey - delete's a given enrollment key by value @@ -144,7 +156,13 @@ func DeleteEnrollmentKey(value string) error { if err != nil { return err } - return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) + err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) + if err == nil { + if servercfg.CacheEnabled() { + deleteEnrollmentkeyFromCache(value) + } + } + return err } // TryToUseEnrollmentKey - checks first if key can be decremented @@ -200,7 +218,7 @@ func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { if err != nil { return nil, err } - return k, nil + return &k, nil } // == private == @@ -215,11 +233,11 @@ func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { return nil, EnrollmentErrors.NoUsesRemaining } k.UsesRemaining = k.UsesRemaining - 1 - if err = upsertEnrollmentKey(k); err != nil { + if err = upsertEnrollmentKey(&k); err != nil { return nil, err } - return k, nil + return &k, nil } func upsertEnrollmentKey(k *models.EnrollmentKey) error { @@ -230,7 +248,13 @@ func upsertEnrollmentKey(k *models.EnrollmentKey) error { if err != nil { return err } - return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) + err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) + if err == nil { + if servercfg.CacheEnabled() { + storeEnrollmentkeyInCache(k.Value, *k) + } + } + return nil } func getUniqueEnrollmentID() (string, error) { @@ -245,7 +269,23 @@ func getUniqueEnrollmentID() (string, error) { return newID, nil } -func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { +func getEnrollmentkeysFromCache() map[string]models.EnrollmentKey { + return enrollmentkeyCacheMap +} + +func storeEnrollmentkeyInCache(key string, enrollmentkey models.EnrollmentKey) { + enrollmentkeyCacheMutex.Lock() + enrollmentkeyCacheMap[key] = enrollmentkey + enrollmentkeyCacheMutex.Unlock() +} + +func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) { + if servercfg.CacheEnabled() { + keys := getEnrollmentkeysFromCache() + if len(keys) != 0 { + return keys, nil + } + } records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) if err != nil { if !database.IsEmptyRecord(err) { @@ -255,14 +295,17 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { if records == nil { records = make(map[string]string) } - currentKeys := make(map[string]*models.EnrollmentKey, 0) + currentKeys := make(map[string]models.EnrollmentKey, 0) if len(records) > 0 { for k := range records { var currentKey models.EnrollmentKey if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil { continue } - currentKeys[k] = ¤tKey + currentKeys[k] = currentKey + if servercfg.CacheEnabled() { + storeEnrollmentkeyInCache(currentKey.Value, currentKey) + } } } return currentKeys, nil diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index f91469ad..677c4714 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -68,7 +68,7 @@ func TestDelete_EnrollmentKey(t *testing.T) { err := DeleteEnrollmentKey(newKey.Value) assert.Nil(t, err) oldKey, err := GetEnrollmentKey(newKey.Value) - assert.Nil(t, oldKey) + assert.Equal(t, oldKey, models.EnrollmentKey{}) assert.NotNil(t, err) assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) diff --git a/logic/extpeers.go b/logic/extpeers.go index f874f1a8..c619dde9 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -86,10 +86,21 @@ func DeleteExtClient(network string, clientid string) error { if err != nil { return err } + extClient, err := GetExtClient(clientid, network) + if err != nil { + return err + } err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key) if err != nil { return err } + //recycle ip address + if extClient.Address != "" { + RemoveIpFromAllocatedIpMap(network, extClient.Address) + } + if extClient.Address6 != "" { + RemoveIpFromAllocatedIpMap(network, extClient.Address6) + } if servercfg.CacheEnabled() { deleteExtClientFromCache(key) } @@ -287,6 +298,14 @@ func SaveExtClient(extclient *models.ExtClient) error { if servercfg.CacheEnabled() { storeExtClientInCache(key, *extclient) } + if _, ok := allocatedIpMap[extclient.Network]; ok { + if extclient.Address != "" { + AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address)) + } + if extclient.Address6 != "" { + AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address6)) + } + } return SetNetworkNodesLastModified(extclient.Network) } diff --git a/logic/jwts.go b/logic/jwts.go index a2b95049..b435dcaf 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -31,7 +31,7 @@ func SetJWTSecret() { // CreateJWT func will used to create the JWT while signing in and signing out func CreateJWT(uuid string, macAddress string, network string) (response string, err error) { - expirationTime := time.Now().Add(5 * time.Minute) + expirationTime := time.Now().Add(15 * time.Minute) claims := &models.Claims{ ID: uuid, Network: network, diff --git a/logic/networks.go b/logic/networks.go index 22d59543..b87c6db8 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -15,13 +15,123 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/validation" + "golang.org/x/exp/slog" ) +var ( + networkCacheMutex = &sync.RWMutex{} + networkCacheMap = make(map[string]models.Network) + allocatedIpMap = make(map[string]map[string]net.IP) +) + +// SetAllocatedIpMap - set allocated ip map for networks +func SetAllocatedIpMap() error { + logger.Log(0, "start setting up allocated ip map") + if allocatedIpMap == nil { + allocatedIpMap = map[string]map[string]net.IP{} + } + + currentNetworks, err := GetNetworks() + if err != nil { + return err + } + + for _, v := range currentNetworks { + pMap := map[string]net.IP{} + netName := v.NetID + + nodes, err := GetNetworkNodes(netName) + if err != nil { + slog.Error("could not load node for network", netName, "error", err.Error()) + continue + } + + for _, n := range nodes { + + if n.Address.IP != nil { + pMap[n.Address.IP.String()] = n.Address.IP + } + if n.Address6.IP != nil { + pMap[n.Address6.IP.String()] = n.Address6.IP + } + } + + allocatedIpMap[netName] = pMap + } + logger.Log(0, "setting up allocated ip map done") + return nil +} + +// ClearAllocatedIpMap - set allocatedIpMap to nil +func ClearAllocatedIpMap() { + allocatedIpMap = nil +} + +func AddIpToAllocatedIpMap(networkName string, ip net.IP) { + networkCacheMutex.Lock() + allocatedIpMap[networkName][ip.String()] = ip + networkCacheMutex.Unlock() +} + +func RemoveIpFromAllocatedIpMap(networkName string, ip string) { + networkCacheMutex.Lock() + delete(allocatedIpMap[networkName], ip) + networkCacheMutex.Unlock() +} + +// AddNetworkToAllocatedIpMap - add network to allocated ip map when network is added +func AddNetworkToAllocatedIpMap(networkName string) { + networkCacheMutex.Lock() + allocatedIpMap[networkName] = map[string]net.IP{} + networkCacheMutex.Unlock() +} + +// RemoveNetworkFromAllocatedIpMap - remove network from allocated ip map when network is deleted +func RemoveNetworkFromAllocatedIpMap(networkName string) { + networkCacheMutex.Lock() + delete(allocatedIpMap, networkName) + networkCacheMutex.Unlock() +} + +func getNetworksFromCache() (networks []models.Network) { + networkCacheMutex.RLock() + for _, network := range networkCacheMap { + networks = append(networks, network) + } + networkCacheMutex.RUnlock() + return +} + +func deleteNetworkFromCache(key string) { + networkCacheMutex.Lock() + delete(networkCacheMap, key) + networkCacheMutex.Unlock() +} + +func getNetworkFromCache(key string) (network models.Network, ok bool) { + networkCacheMutex.RLock() + network, ok = networkCacheMap[key] + networkCacheMutex.RUnlock() + return +} + +func storeNetworkInCache(key string, network models.Network) { + networkCacheMutex.Lock() + networkCacheMap[key] = network + networkCacheMutex.Unlock() +} + // GetNetworks - returns all networks from database func GetNetworks() ([]models.Network, error) { var networks []models.Network - + if servercfg.CacheEnabled() { + networks := getNetworksFromCache() + if len(networks) != 0 { + return networks, nil + } + } collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME) if err != nil { return networks, err @@ -34,6 +144,9 @@ func GetNetworks() ([]models.Network, error) { } // add network our array networks = append(networks, network) + if servercfg.CacheEnabled() { + storeNetworkInCache(network.NetID, network) + } } return networks, err @@ -49,7 +162,14 @@ func DeleteNetwork(network string) error { nodeCount, err := GetNetworkNonServerNodeCount(network) if nodeCount == 0 || database.IsEmptyRecord(err) { // delete server nodes first then db records - return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network) + err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network) + if err != nil { + return err + } + if servercfg.CacheEnabled() { + deleteNetworkFromCache(network) + } + return nil } return errors.New("node check failed. All nodes must be deleted before deleting network") } @@ -93,6 +213,9 @@ func CreateNetwork(network models.Network) (models.Network, error) { if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { return models.Network{}, err } + if servercfg.CacheEnabled() { + storeNetworkInCache(network.NetID, network) + } return network, nil } @@ -128,6 +251,11 @@ func intersect(n1, n2 *net.IPNet) bool { func GetParentNetwork(networkname string) (models.Network, error) { var network models.Network + if servercfg.CacheEnabled() { + if network, ok := getNetworkFromCache(networkname); ok { + return network, nil + } + } networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) if err != nil { return network, err @@ -142,6 +270,11 @@ func GetParentNetwork(networkname string) (models.Network, error) { func GetNetworkSettings(networkname string) (models.Network, error) { var network models.Network + if servercfg.CacheEnabled() { + if network, ok := getNetworkFromCache(networkname); ok { + return network, nil + } + } networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) if err != nil { return network, err @@ -177,9 +310,9 @@ func UniqueAddress(networkName string, reverse bool) (net.IP, error) { newAddrs = net4.LastAddress() } + ipAllocated := allocatedIpMap[networkName] for { - if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, false) && - IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, false) { + if _, ok := ipAllocated[newAddrs.String()]; !ok { return newAddrs, nil } if reverse { @@ -266,9 +399,9 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) { return add, err } + ipAllocated := allocatedIpMap[networkName] for { - if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, true) && - IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, true) { + if _, ok := ipAllocated[newAddrs.String()]; !ok { return newAddrs, nil } if reverse { @@ -320,6 +453,11 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) ( } newNetwork.SetNetworkLastModified() err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) + if err == nil { + if servercfg.CacheEnabled() { + storeNetworkInCache(newNetwork.NetID, *newNetwork) + } + } return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err } // copy values @@ -330,6 +468,11 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) ( func GetNetwork(networkname string) (models.Network, error) { var network models.Network + if servercfg.CacheEnabled() { + if network, ok := getNetworkFromCache(networkname); ok { + return network, nil + } + } networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) if err != nil { return network, err @@ -394,6 +537,9 @@ func SaveNetwork(network *models.Network) error { if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { return err } + if servercfg.CacheEnabled() { + storeNetworkInCache(network.NetID, *network) + } return nil } @@ -402,6 +548,11 @@ func NetworkExists(name string) (bool, error) { var network string var err error + if servercfg.CacheEnabled() { + if _, ok := getNetworkFromCache(name); ok { + return ok, nil + } + } if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil { return false, err } diff --git a/logic/nodes.go b/logic/nodes.go index 62f49557..7510cdbd 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -116,6 +116,7 @@ func UpdateNodeCheckin(node *models.Node) error { if err != nil { return err } + err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) if err != nil { return err @@ -300,6 +301,13 @@ func DeleteNodeByID(node *models.Node) error { if err = DeleteMetrics(node.ID.String()); err != nil { logger.Log(1, "unable to remove metrics from DB for node", node.ID.String(), err.Error()) } + //recycle ip address + if node.Address.IP != nil { + RemoveIpFromAllocatedIpMap(node.Network, node.Address.IP.String()) + } + if node.Address6.IP != nil { + RemoveIpFromAllocatedIpMap(node.Network, node.Address6.IP.String()) + } return nil } @@ -585,6 +593,14 @@ func createNode(node *models.Node) error { if servercfg.CacheEnabled() { storeNodeInCache(*node) } + if _, ok := allocatedIpMap[node.Network]; ok { + if node.Address.IP != nil { + AddIpToAllocatedIpMap(node.Network, node.Address.IP) + } + if node.Address6.IP != nil { + AddIpToAllocatedIpMap(node.Network, node.Address6.IP) + } + } _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal) if err != nil { logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error()) diff --git a/logic/peers.go b/logic/peers.go index 4f449fda..77e76a2a 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -436,7 +436,7 @@ func GetEgressIPs(peer *models.Node) []net.IPNet { peerHost, err := GetHost(peer.HostID.String()) if err != nil { - logger.Log(0, "error retrieving host for peer", peer.ID.String(), err.Error()) + logger.Log(0, "error retrieving host for peer", peer.ID.String(), "host id", peer.HostID.String(), err.Error()) } // check for internet gateway diff --git a/logic/telemetry.go b/logic/telemetry.go index 5902c048..c0a41bab 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -13,8 +13,11 @@ import ( "golang.org/x/exp/slog" ) -// flags to keep for telemetry -var isFreeTier bool +var ( + // flags to keep for telemetry + isFreeTier bool + telServerRecord = models.Telemetry{} +) // posthog_pub_key - Key for sending data to PostHog const posthog_pub_key = "phc_1vEXhPOA1P7HP5jP2dVU9xDTUqXHAelmtravyZ1vvES" @@ -125,6 +128,9 @@ func setTelemetryTimestamp(telRecord *models.Telemetry) error { return err } err = database.Insert(database.SERVER_UUID_RECORD_KEY, string(jsonObj), database.SERVER_UUID_TABLE_NAME) + if err == nil { + telServerRecord = serverTelData + } return err } @@ -152,6 +158,9 @@ func getClientCount(nodes []models.Node) clientCount { // FetchTelemetryRecord - get the existing UUID and Timestamp from the DB func FetchTelemetryRecord() (models.Telemetry, error) { + if telServerRecord.TrafficKeyPub != nil { + return telServerRecord, nil + } var rawData string var telObj models.Telemetry var err error @@ -160,6 +169,9 @@ func FetchTelemetryRecord() (models.Telemetry, error) { return telObj, err } err = json.Unmarshal([]byte(rawData), &telObj) + if err == nil { + telServerRecord = telObj + } return telObj, err } diff --git a/main.go b/main.go index 14d32661..ec5a6c3d 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,8 @@ func main() { servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo initialize() // initial db and acls + logic.SetAllocatedIpMap() + defer logic.ClearAllocatedIpMap() setGarbageCollection() setVerbosity() if servercfg.DeployedByOperator() && !servercfg.IsPro { diff --git a/mq/emqx_on_prem.go b/mq/emqx_on_prem.go index d69067f3..b9cd690c 100644 --- a/mq/emqx_on_prem.go +++ b/mq/emqx_on_prem.go @@ -63,6 +63,7 @@ func getEmqxAuthToken() (string, error) { if err != nil { return "", err } + defer resp.Body.Close() msg, err := io.ReadAll(resp.Body) if err != nil { return "", err @@ -206,7 +207,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthenticator() error { if err != nil { return err } - return fmt.Errorf("error creating default EMQX authenticator %v", string(msg)) + if !strings.ContainsAny(string(msg), "ALREADY_EXISTS") { + return fmt.Errorf("error creating default EMQX authenticator %v", string(msg)) + } } return nil } @@ -240,7 +243,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthorizer() error { if err != nil { return err } - return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg)) + if !strings.ContainsAny(string(msg), "duplicated_authz_source_type") { + return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg)) + } } return nil } diff --git a/mq/mq.go b/mq/mq.go index a143b1ab..bb3a7d53 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -34,9 +34,9 @@ func setMqOptions(user, password string, opts *mqtt.ClientOptions) { opts.SetAutoReconnect(true) opts.SetConnectRetry(true) opts.SetCleanSession(true) - opts.SetConnectRetryInterval(time.Second * 4) - opts.SetKeepAlive(time.Minute) - opts.SetCleanSession(true) + opts.SetConnectRetryInterval(time.Second * 1) + opts.SetKeepAlive(time.Second * 10) + opts.SetOrderMatters(false) opts.SetWriteTimeout(time.Minute) } @@ -75,19 +75,15 @@ func SetupMQTT(fatal bool) { opts.SetOnConnectHandler(func(client mqtt.Client) { serverName := servercfg.GetServer() if token := client.Subscribe(fmt.Sprintf("update/%s/#", serverName), 0, mqtt.MessageHandler(UpdateNode)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { - client.Disconnect(240) logger.Log(0, "node update subscription failed") } if token := client.Subscribe(fmt.Sprintf("host/serverupdate/%s/#", serverName), 0, mqtt.MessageHandler(UpdateHost)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { - client.Disconnect(240) logger.Log(0, "host update subscription failed") } if token := client.Subscribe(fmt.Sprintf("signal/%s/#", serverName), 0, mqtt.MessageHandler(ClientPeerUpdate)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { - client.Disconnect(240) logger.Log(0, "node client subscription failed") } if token := client.Subscribe(fmt.Sprintf("metrics/%s/#", serverName), 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { - client.Disconnect(240) logger.Log(0, "node metrics subscription failed") } diff --git a/mq/publishers.go b/mq/publishers.go index b3f3efda..71bb96af 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -93,7 +93,6 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error { // PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool) error { - peerUpdate, err := logic.GetPeerUpdateForHost("", host, allNodes, deletedNode, deletedClients) if err != nil { return err @@ -211,12 +210,6 @@ func PushMetricsToExporter(metrics models.Metrics) error { // sendPeers - retrieve networks, send peer ports to all peers func sendPeers() { - - hosts, err := logic.GetAllHosts() - if err != nil && len(hosts) > 0 { - logger.Log(1, "error retrieving networks for keepalive", err.Error()) - } - peer_force_send++ if peer_force_send == 5 { servercfg.SetHost() diff --git a/mq/util.go b/mq/util.go index d9989109..72cf7160 100644 --- a/mq/util.go +++ b/mq/util.go @@ -9,6 +9,7 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" + "golang.org/x/exp/slog" ) func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) { @@ -87,6 +88,7 @@ func publish(host *models.Host, dest string, msg []byte) error { if token.Error() == nil { err = errors.New("connection timeout") } else { + slog.Error("publish to mq error", "error", token.Error().Error()) err = token.Error() } return err diff --git a/pro/controllers/failover.go b/pro/controllers/failover.go index 10ae6a02..946e753e 100644 --- a/pro/controllers/failover.go +++ b/pro/controllers/failover.go @@ -44,7 +44,7 @@ func getfailOver(w http.ResponseWriter, r *http.Request) { // confirm host exists node, err := logic.GetNodeByID(nodeid) if err != nil { - slog.Error("failed to get node:", "error", err.Error()) + slog.Error("failed to get node:", "node", nodeid, "error", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } diff --git a/pro/initialize.go b/pro/initialize.go index ffd28c6c..8da33bfc 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -88,7 +88,7 @@ func InitPro() { } else { slog.Error("no OAuth provider found or not configured, continuing without OAuth") } - + proLogic.LoadNodeMetricsToCache() }) logic.ResetFailOver = proLogic.ResetFailOver logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer diff --git a/pro/logic/metrics.go b/pro/logic/metrics.go index 995470a2..dfcd4d6a 100644 --- a/pro/logic/metrics.go +++ b/pro/logic/metrics.go @@ -3,6 +3,7 @@ package logic import ( "encoding/json" "math" + "sync" "time" mqtt "github.com/eclipse/paho.mqtt.golang" @@ -15,9 +16,64 @@ import ( "golang.org/x/exp/slog" ) +var ( + metricsCacheMutex = &sync.RWMutex{} + metricsCacheMap map[string]models.Metrics +) + +func getMetricsFromCache(key string) (metrics models.Metrics, ok bool) { + metricsCacheMutex.RLock() + metrics, ok = metricsCacheMap[key] + metricsCacheMutex.RUnlock() + return +} + +func storeMetricsInCache(key string, metrics models.Metrics) { + metricsCacheMutex.Lock() + metricsCacheMap[key] = metrics + metricsCacheMutex.Unlock() +} + +func deleteNetworkFromCache(key string) { + metricsCacheMutex.Lock() + delete(metricsCacheMap, key) + metricsCacheMutex.Unlock() +} + +func LoadNodeMetricsToCache() error { + slog.Info("loading metrics to cache") + if metricsCacheMap == nil { + metricsCacheMap = map[string]models.Metrics{} + } + + collection, err := database.FetchRecords(database.METRICS_TABLE_NAME) + if err != nil { + return err + } + + for key, value := range collection { + var metrics models.Metrics + if err := json.Unmarshal([]byte(value), &metrics); err != nil { + slog.Error("parse metric record error", "error", err.Error()) + continue + } + if servercfg.CacheEnabled() { + storeMetricsInCache(key, metrics) + } + } + + slog.Info("metrics loading done") + return nil +} + // GetMetrics - gets the metrics func GetMetrics(nodeid string) (*models.Metrics, error) { var metrics models.Metrics + if servercfg.CacheEnabled() { + if metrics, ok := getMetricsFromCache(nodeid); ok { + return &metrics, nil + } + } record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid) if err != nil { if database.IsEmptyRecord(err) { @@ -29,6 +85,9 @@ func GetMetrics(nodeid string) (*models.Metrics, error) { if err != nil { return &metrics, err } + if servercfg.CacheEnabled() { + storeMetricsInCache(nodeid, metrics) + } return &metrics, nil } @@ -38,12 +97,26 @@ func UpdateMetrics(nodeid string, metrics *models.Metrics) error { if err != nil { return err } - return database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME) + err = database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME) + if err != nil { + return err + } + if servercfg.CacheEnabled() { + storeMetricsInCache(nodeid, *metrics) + } + return nil } // DeleteMetrics - deletes metrics of a given node func DeleteMetrics(nodeid string) error { - return database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid) + err := database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid) + if err != nil { + return err + } + if servercfg.CacheEnabled() { + deleteNetworkFromCache(nodeid) + } + return nil } // MQUpdateMetricsFallBack - called when mq fallback thread is triggered on client