NET-1440: scale test changes (#3014)

* NET-1440 scale test changes

* fix UT error and add error info

* load metric data into cacha in startup

* remove debug info for metric

* add server telemetry and hasSuperAdmin to cache

* fix user UT case

* update sqlite connection string for performance

* update check-in TS in cache only if cache enabled

* update metric data in cache only if cache enabled and write to DB once in stop

* update server status in mq topic

* add failover existed to server status update

* only send mq messsage when there is server status change

* batch peerUpdate

* code changes for scale for review

* update UT case

* update mq client check

* mq connection code change

* revert server status update changes

* revert batch peerUpdate

* remove server status update info

* code changes based on review and setupmqtt in keepalive

* set the mq message order to false for PIN

* remove setupmqtt in keepalive

* recycle ip in node deletion

* update ip allocation logic

* remove ip addr cap

* remove ippool file

* update get extClient func

* remove ip from cache map when extClient is removed
This commit is contained in:
Max Ma
2024-08-15 08:29:00 +02:00
committed by GitHub
parent c551c487ca
commit 46b8fd21c8
23 changed files with 388 additions and 53 deletions

View File

@@ -49,12 +49,12 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
ret := []*models.EnrollmentKey{} ret := []*models.EnrollmentKey{}
for _, key := range keys { for _, key := range keys {
key := key key := key
if err = logic.Tokenize(key, servercfg.GetAPIHost()); err != nil { if err = logic.Tokenize(&key, servercfg.GetAPIHost()); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
ret = append(ret, key) ret = append(ret, &key)
} }
// return JSON/API formatted keys // return JSON/API formatted keys
logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") logger.Log(2, r.Header.Get("user"), "fetched enrollment keys")

View File

@@ -604,7 +604,7 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) {
errorResponse.Code = http.StatusBadRequest errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error retrieving host: ", err.Error()) "error retrieving host: ", authRequest.ID, err.Error())
logic.ReturnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }

View File

@@ -408,6 +408,8 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype)) logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
return return
} }
//delete network from allocated ip map
go logic.RemoveNetworkFromAllocatedIpMap(network)
logger.Log(1, r.Header.Get("user"), "deleted network", network) logger.Log(1, r.Header.Get("user"), "deleted network", network)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -480,6 +482,10 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
//add new network to allocated ip map
go logic.AddNetworkToAllocatedIpMap(network.NetID)
go func() { go func() {
defaultHosts := logic.GetDefaultHosts() defaultHosts := logic.GetDefaultHosts()
for i := range defaultHosts { for i := range defaultHosts {

View File

@@ -46,8 +46,8 @@ func TestCreateNetwork(t *testing.T) {
deleteAllNetworks() deleteAllNetworks()
var network models.Network var network models.Network
network.NetID = "skynet" network.NetID = "skynet1"
network.AddressRange = "10.0.0.1/24" network.AddressRange = "10.10.0.1/24"
// if tests break - check here (removed displayname) // if tests break - check here (removed displayname)
//network.DisplayName = "mynetwork" //network.DisplayName = "mynetwork"

View File

@@ -93,7 +93,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Code = http.StatusBadRequest errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error retrieving host: ", err.Error()) "error retrieving host: ", result.HostID.String(), err.Error())
logic.ReturnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }

View File

@@ -137,6 +137,7 @@ func TestCreateUser(t *testing.T) {
func TestCreateSuperAdmin(t *testing.T) { func TestCreateSuperAdmin(t *testing.T) {
deleteAllUsers(t) deleteAllUsers(t)
logic.ClearSuperUserCache()
var user models.User var user models.User
t.Run("NoSuperAdmin", func(t *testing.T) { t.Run("NoSuperAdmin", func(t *testing.T) {
user.UserName = "admin" user.UserName = "admin"

View File

@@ -20,9 +20,21 @@ const (
auth_key = "netmaker_auth" auth_key = "netmaker_auth"
) )
var (
superUser = models.User{}
)
func ClearSuperUserCache() {
superUser = models.User{}
}
// HasSuperAdmin - checks if server has an superadmin/owner // HasSuperAdmin - checks if server has an superadmin/owner
func HasSuperAdmin() (bool, error) { func HasSuperAdmin() (bool, error) {
if superUser.IsSuperAdmin {
return true, nil
}
collection, err := database.FetchRecords(database.USERS_TABLE_NAME) collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil { if err != nil {
if database.IsEmptyRecord(err) { if database.IsEmptyRecord(err) {
@@ -38,6 +50,7 @@ func HasSuperAdmin() (bool, error) {
continue continue
} }
if user.IsSuperAdmin { if user.IsSuperAdmin {
superUser = user
return true, nil return true, nil
} }
} }
@@ -116,7 +129,7 @@ func CreateUser(user *models.User) error {
tokenString, _ := CreateUserJWT(user.UserName, user.IsSuperAdmin, user.IsAdmin) tokenString, _ := CreateUserJWT(user.UserName, user.IsSuperAdmin, user.IsAdmin)
if tokenString == "" { if tokenString == "" {
logger.Log(0, "failed to generate token", err.Error()) logger.Log(0, "failed to generate token")
return err return err
} }
@@ -204,6 +217,9 @@ func UpsertUser(user models.User) error {
slog.Error("error inserting user", "user", user.UserName, "error", err.Error()) slog.Error("error inserting user", "user", user.UserName, "error", err.Error())
return err return err
} }
if user.IsSuperAdmin {
superUser = user
}
return nil return nil
} }

View File

@@ -5,11 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@@ -29,6 +31,10 @@ var EnrollmentErrors = struct {
FailedToTokenize: fmt.Errorf("failed to tokenize"), FailedToTokenize: fmt.Errorf("failed to tokenize"),
FailedToDeTokenize: fmt.Errorf("failed to detokenize"), FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
} }
var (
enrollmentkeyCacheMutex = &sync.RWMutex{}
enrollmentkeyCacheMap = make(map[string]models.EnrollmentKey)
)
// CreateEnrollmentKey - creates a new enrollment key in db // CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
@@ -104,21 +110,21 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey
key.Relay = relayId key.Relay = relayId
if err = upsertEnrollmentKey(key); err != nil { if err = upsertEnrollmentKey(&key); err != nil {
return nil, err return nil, err
} }
return key, nil return &key, nil
} }
// GetAllEnrollmentKeys - fetches all enrollment keys from DB // GetAllEnrollmentKeys - fetches all enrollment keys from DB
// TODO drop double pointer // TODO drop double pointer
func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) {
currentKeys, err := getEnrollmentKeysMap() currentKeys, err := getEnrollmentKeysMap()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var currentKeysList = []*models.EnrollmentKey{} var currentKeysList = []models.EnrollmentKey{}
for k := range currentKeys { for k := range currentKeys {
currentKeysList = append(currentKeysList, currentKeys[k]) currentKeysList = append(currentKeysList, currentKeys[k])
} }
@@ -127,15 +133,21 @@ func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
// GetEnrollmentKey - fetches a single enrollment key // GetEnrollmentKey - fetches a single enrollment key
// returns nil and error if not found // returns nil and error if not found
func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { func GetEnrollmentKey(value string) (key models.EnrollmentKey, err error) {
currentKeys, err := getEnrollmentKeysMap() currentKeys, err := getEnrollmentKeysMap()
if err != nil { if err != nil {
return nil, err return key, err
} }
if key, ok := currentKeys[value]; ok { if key, ok := currentKeys[value]; ok {
return key, nil return key, nil
} }
return nil, EnrollmentErrors.NoKeyFound return key, EnrollmentErrors.NoKeyFound
}
func deleteEnrollmentkeyFromCache(key string) {
enrollmentkeyCacheMutex.Lock()
delete(enrollmentkeyCacheMap, key)
enrollmentkeyCacheMutex.Unlock()
} }
// DeleteEnrollmentKey - delete's a given enrollment key by value // DeleteEnrollmentKey - delete's a given enrollment key by value
@@ -144,7 +156,13 @@ func DeleteEnrollmentKey(value string) error {
if err != nil { if err != nil {
return err return err
} }
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
if err == nil {
if servercfg.CacheEnabled() {
deleteEnrollmentkeyFromCache(value)
}
}
return err
} }
// TryToUseEnrollmentKey - checks first if key can be decremented // TryToUseEnrollmentKey - checks first if key can be decremented
@@ -200,7 +218,7 @@ func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return k, nil return &k, nil
} }
// == private == // == private ==
@@ -215,11 +233,11 @@ func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
return nil, EnrollmentErrors.NoUsesRemaining return nil, EnrollmentErrors.NoUsesRemaining
} }
k.UsesRemaining = k.UsesRemaining - 1 k.UsesRemaining = k.UsesRemaining - 1
if err = upsertEnrollmentKey(k); err != nil { if err = upsertEnrollmentKey(&k); err != nil {
return nil, err return nil, err
} }
return k, nil return &k, nil
} }
func upsertEnrollmentKey(k *models.EnrollmentKey) error { func upsertEnrollmentKey(k *models.EnrollmentKey) error {
@@ -230,7 +248,13 @@ func upsertEnrollmentKey(k *models.EnrollmentKey) error {
if err != nil { if err != nil {
return err return err
} }
return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
if err == nil {
if servercfg.CacheEnabled() {
storeEnrollmentkeyInCache(k.Value, *k)
}
}
return nil
} }
func getUniqueEnrollmentID() (string, error) { func getUniqueEnrollmentID() (string, error) {
@@ -245,7 +269,23 @@ func getUniqueEnrollmentID() (string, error) {
return newID, nil return newID, nil
} }
func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { func getEnrollmentkeysFromCache() map[string]models.EnrollmentKey {
return enrollmentkeyCacheMap
}
func storeEnrollmentkeyInCache(key string, enrollmentkey models.EnrollmentKey) {
enrollmentkeyCacheMutex.Lock()
enrollmentkeyCacheMap[key] = enrollmentkey
enrollmentkeyCacheMutex.Unlock()
}
func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) {
if servercfg.CacheEnabled() {
keys := getEnrollmentkeysFromCache()
if len(keys) != 0 {
return keys, nil
}
}
records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
if err != nil { if err != nil {
if !database.IsEmptyRecord(err) { if !database.IsEmptyRecord(err) {
@@ -255,14 +295,17 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
if records == nil { if records == nil {
records = make(map[string]string) records = make(map[string]string)
} }
currentKeys := make(map[string]*models.EnrollmentKey, 0) currentKeys := make(map[string]models.EnrollmentKey, 0)
if len(records) > 0 { if len(records) > 0 {
for k := range records { for k := range records {
var currentKey models.EnrollmentKey var currentKey models.EnrollmentKey
if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil { if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
continue continue
} }
currentKeys[k] = &currentKey currentKeys[k] = currentKey
if servercfg.CacheEnabled() {
storeEnrollmentkeyInCache(currentKey.Value, currentKey)
}
} }
} }
return currentKeys, nil return currentKeys, nil

View File

@@ -68,7 +68,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
err := DeleteEnrollmentKey(newKey.Value) err := DeleteEnrollmentKey(newKey.Value)
assert.Nil(t, err) assert.Nil(t, err)
oldKey, err := GetEnrollmentKey(newKey.Value) oldKey, err := GetEnrollmentKey(newKey.Value)
assert.Nil(t, oldKey) assert.Equal(t, oldKey, models.EnrollmentKey{})
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound) assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
}) })

View File

@@ -86,10 +86,21 @@ func DeleteExtClient(network string, clientid string) error {
if err != nil { if err != nil {
return err return err
} }
extClient, err := GetExtClient(clientid, network)
if err != nil {
return err
}
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key) err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
if err != nil { if err != nil {
return err return err
} }
//recycle ip address
if extClient.Address != "" {
RemoveIpFromAllocatedIpMap(network, extClient.Address)
}
if extClient.Address6 != "" {
RemoveIpFromAllocatedIpMap(network, extClient.Address6)
}
if servercfg.CacheEnabled() { if servercfg.CacheEnabled() {
deleteExtClientFromCache(key) deleteExtClientFromCache(key)
} }
@@ -287,6 +298,14 @@ func SaveExtClient(extclient *models.ExtClient) error {
if servercfg.CacheEnabled() { if servercfg.CacheEnabled() {
storeExtClientInCache(key, *extclient) storeExtClientInCache(key, *extclient)
} }
if _, ok := allocatedIpMap[extclient.Network]; ok {
if extclient.Address != "" {
AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address))
}
if extclient.Address6 != "" {
AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address6))
}
}
return SetNetworkNodesLastModified(extclient.Network) return SetNetworkNodesLastModified(extclient.Network)
} }

View File

@@ -31,7 +31,7 @@ func SetJWTSecret() {
// CreateJWT func will used to create the JWT while signing in and signing out // CreateJWT func will used to create the JWT while signing in and signing out
func CreateJWT(uuid string, macAddress string, network string) (response string, err error) { func CreateJWT(uuid string, macAddress string, network string) (response string, err error) {
expirationTime := time.Now().Add(5 * time.Minute) expirationTime := time.Now().Add(15 * time.Minute)
claims := &models.Claims{ claims := &models.Claims{
ID: uuid, ID: uuid,
Network: network, Network: network,

View File

@@ -15,13 +15,123 @@ import (
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/validation" "github.com/gravitl/netmaker/validation"
"golang.org/x/exp/slog"
) )
var (
networkCacheMutex = &sync.RWMutex{}
networkCacheMap = make(map[string]models.Network)
allocatedIpMap = make(map[string]map[string]net.IP)
)
// SetAllocatedIpMap - set allocated ip map for networks
func SetAllocatedIpMap() error {
logger.Log(0, "start setting up allocated ip map")
if allocatedIpMap == nil {
allocatedIpMap = map[string]map[string]net.IP{}
}
currentNetworks, err := GetNetworks()
if err != nil {
return err
}
for _, v := range currentNetworks {
pMap := map[string]net.IP{}
netName := v.NetID
nodes, err := GetNetworkNodes(netName)
if err != nil {
slog.Error("could not load node for network", netName, "error", err.Error())
continue
}
for _, n := range nodes {
if n.Address.IP != nil {
pMap[n.Address.IP.String()] = n.Address.IP
}
if n.Address6.IP != nil {
pMap[n.Address6.IP.String()] = n.Address6.IP
}
}
allocatedIpMap[netName] = pMap
}
logger.Log(0, "setting up allocated ip map done")
return nil
}
// ClearAllocatedIpMap - set allocatedIpMap to nil
func ClearAllocatedIpMap() {
allocatedIpMap = nil
}
func AddIpToAllocatedIpMap(networkName string, ip net.IP) {
networkCacheMutex.Lock()
allocatedIpMap[networkName][ip.String()] = ip
networkCacheMutex.Unlock()
}
func RemoveIpFromAllocatedIpMap(networkName string, ip string) {
networkCacheMutex.Lock()
delete(allocatedIpMap[networkName], ip)
networkCacheMutex.Unlock()
}
// AddNetworkToAllocatedIpMap - add network to allocated ip map when network is added
func AddNetworkToAllocatedIpMap(networkName string) {
networkCacheMutex.Lock()
allocatedIpMap[networkName] = map[string]net.IP{}
networkCacheMutex.Unlock()
}
// RemoveNetworkFromAllocatedIpMap - remove network from allocated ip map when network is deleted
func RemoveNetworkFromAllocatedIpMap(networkName string) {
networkCacheMutex.Lock()
delete(allocatedIpMap, networkName)
networkCacheMutex.Unlock()
}
func getNetworksFromCache() (networks []models.Network) {
networkCacheMutex.RLock()
for _, network := range networkCacheMap {
networks = append(networks, network)
}
networkCacheMutex.RUnlock()
return
}
func deleteNetworkFromCache(key string) {
networkCacheMutex.Lock()
delete(networkCacheMap, key)
networkCacheMutex.Unlock()
}
func getNetworkFromCache(key string) (network models.Network, ok bool) {
networkCacheMutex.RLock()
network, ok = networkCacheMap[key]
networkCacheMutex.RUnlock()
return
}
func storeNetworkInCache(key string, network models.Network) {
networkCacheMutex.Lock()
networkCacheMap[key] = network
networkCacheMutex.Unlock()
}
// GetNetworks - returns all networks from database // GetNetworks - returns all networks from database
func GetNetworks() ([]models.Network, error) { func GetNetworks() ([]models.Network, error) {
var networks []models.Network var networks []models.Network
if servercfg.CacheEnabled() {
networks := getNetworksFromCache()
if len(networks) != 0 {
return networks, nil
}
}
collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME) collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
if err != nil { if err != nil {
return networks, err return networks, err
@@ -34,6 +144,9 @@ func GetNetworks() ([]models.Network, error) {
} }
// add network our array // add network our array
networks = append(networks, network) networks = append(networks, network)
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, network)
}
} }
return networks, err return networks, err
@@ -49,7 +162,14 @@ func DeleteNetwork(network string) error {
nodeCount, err := GetNetworkNonServerNodeCount(network) nodeCount, err := GetNetworkNonServerNodeCount(network)
if nodeCount == 0 || database.IsEmptyRecord(err) { if nodeCount == 0 || database.IsEmptyRecord(err) {
// delete server nodes first then db records // delete server nodes first then db records
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network) err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteNetworkFromCache(network)
}
return nil
} }
return errors.New("node check failed. All nodes must be deleted before deleting network") return errors.New("node check failed. All nodes must be deleted before deleting network")
} }
@@ -93,6 +213,9 @@ func CreateNetwork(network models.Network) (models.Network, error) {
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return models.Network{}, err return models.Network{}, err
} }
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, network)
}
return network, nil return network, nil
} }
@@ -128,6 +251,11 @@ func intersect(n1, n2 *net.IPNet) bool {
func GetParentNetwork(networkname string) (models.Network, error) { func GetParentNetwork(networkname string) (models.Network, error) {
var network models.Network var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil { if err != nil {
return network, err return network, err
@@ -142,6 +270,11 @@ func GetParentNetwork(networkname string) (models.Network, error) {
func GetNetworkSettings(networkname string) (models.Network, error) { func GetNetworkSettings(networkname string) (models.Network, error) {
var network models.Network var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil { if err != nil {
return network, err return network, err
@@ -177,9 +310,9 @@ func UniqueAddress(networkName string, reverse bool) (net.IP, error) {
newAddrs = net4.LastAddress() newAddrs = net4.LastAddress()
} }
ipAllocated := allocatedIpMap[networkName]
for { for {
if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, false) && if _, ok := ipAllocated[newAddrs.String()]; !ok {
IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, false) {
return newAddrs, nil return newAddrs, nil
} }
if reverse { if reverse {
@@ -266,9 +399,9 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) {
return add, err return add, err
} }
ipAllocated := allocatedIpMap[networkName]
for { for {
if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, true) && if _, ok := ipAllocated[newAddrs.String()]; !ok {
IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, true) {
return newAddrs, nil return newAddrs, nil
} }
if reverse { if reverse {
@@ -320,6 +453,11 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
} }
newNetwork.SetNetworkLastModified() newNetwork.SetNetworkLastModified()
err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
if err == nil {
if servercfg.CacheEnabled() {
storeNetworkInCache(newNetwork.NetID, *newNetwork)
}
}
return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err
} }
// copy values // copy values
@@ -330,6 +468,11 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
func GetNetwork(networkname string) (models.Network, error) { func GetNetwork(networkname string) (models.Network, error) {
var network models.Network var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil { if err != nil {
return network, err return network, err
@@ -394,6 +537,9 @@ func SaveNetwork(network *models.Network) error {
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err return err
} }
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, *network)
}
return nil return nil
} }
@@ -402,6 +548,11 @@ func NetworkExists(name string) (bool, error) {
var network string var network string
var err error var err error
if servercfg.CacheEnabled() {
if _, ok := getNetworkFromCache(name); ok {
return ok, nil
}
}
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil { if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err return false, err
} }

View File

@@ -116,6 +116,7 @@ func UpdateNodeCheckin(node *models.Node) error {
if err != nil { if err != nil {
return err return err
} }
err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
if err != nil { if err != nil {
return err return err
@@ -300,6 +301,13 @@ func DeleteNodeByID(node *models.Node) error {
if err = DeleteMetrics(node.ID.String()); err != nil { if err = DeleteMetrics(node.ID.String()); err != nil {
logger.Log(1, "unable to remove metrics from DB for node", node.ID.String(), err.Error()) logger.Log(1, "unable to remove metrics from DB for node", node.ID.String(), err.Error())
} }
//recycle ip address
if node.Address.IP != nil {
RemoveIpFromAllocatedIpMap(node.Network, node.Address.IP.String())
}
if node.Address6.IP != nil {
RemoveIpFromAllocatedIpMap(node.Network, node.Address6.IP.String())
}
return nil return nil
} }
@@ -585,6 +593,14 @@ func createNode(node *models.Node) error {
if servercfg.CacheEnabled() { if servercfg.CacheEnabled() {
storeNodeInCache(*node) storeNodeInCache(*node)
} }
if _, ok := allocatedIpMap[node.Network]; ok {
if node.Address.IP != nil {
AddIpToAllocatedIpMap(node.Network, node.Address.IP)
}
if node.Address6.IP != nil {
AddIpToAllocatedIpMap(node.Network, node.Address6.IP)
}
}
_, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal) _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal)
if err != nil { if err != nil {
logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error()) logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error())

View File

@@ -436,7 +436,7 @@ func GetEgressIPs(peer *models.Node) []net.IPNet {
peerHost, err := GetHost(peer.HostID.String()) peerHost, err := GetHost(peer.HostID.String())
if err != nil { if err != nil {
logger.Log(0, "error retrieving host for peer", peer.ID.String(), err.Error()) logger.Log(0, "error retrieving host for peer", peer.ID.String(), "host id", peer.HostID.String(), err.Error())
} }
// check for internet gateway // check for internet gateway

View File

@@ -13,8 +13,11 @@ import (
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
) )
// flags to keep for telemetry var (
var isFreeTier bool // flags to keep for telemetry
isFreeTier bool
telServerRecord = models.Telemetry{}
)
// posthog_pub_key - Key for sending data to PostHog // posthog_pub_key - Key for sending data to PostHog
const posthog_pub_key = "phc_1vEXhPOA1P7HP5jP2dVU9xDTUqXHAelmtravyZ1vvES" const posthog_pub_key = "phc_1vEXhPOA1P7HP5jP2dVU9xDTUqXHAelmtravyZ1vvES"
@@ -125,6 +128,9 @@ func setTelemetryTimestamp(telRecord *models.Telemetry) error {
return err return err
} }
err = database.Insert(database.SERVER_UUID_RECORD_KEY, string(jsonObj), database.SERVER_UUID_TABLE_NAME) err = database.Insert(database.SERVER_UUID_RECORD_KEY, string(jsonObj), database.SERVER_UUID_TABLE_NAME)
if err == nil {
telServerRecord = serverTelData
}
return err return err
} }
@@ -152,6 +158,9 @@ func getClientCount(nodes []models.Node) clientCount {
// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB // FetchTelemetryRecord - get the existing UUID and Timestamp from the DB
func FetchTelemetryRecord() (models.Telemetry, error) { func FetchTelemetryRecord() (models.Telemetry, error) {
if telServerRecord.TrafficKeyPub != nil {
return telServerRecord, nil
}
var rawData string var rawData string
var telObj models.Telemetry var telObj models.Telemetry
var err error var err error
@@ -160,6 +169,9 @@ func FetchTelemetryRecord() (models.Telemetry, error) {
return telObj, err return telObj, err
} }
err = json.Unmarshal([]byte(rawData), &telObj) err = json.Unmarshal([]byte(rawData), &telObj)
if err == nil {
telServerRecord = telObj
}
return telObj, err return telObj, err
} }

View File

@@ -48,6 +48,8 @@ func main() {
servercfg.SetVersion(version) servercfg.SetVersion(version)
fmt.Println(models.RetrieveLogo()) // print the logo fmt.Println(models.RetrieveLogo()) // print the logo
initialize() // initial db and acls initialize() // initial db and acls
logic.SetAllocatedIpMap()
defer logic.ClearAllocatedIpMap()
setGarbageCollection() setGarbageCollection()
setVerbosity() setVerbosity()
if servercfg.DeployedByOperator() && !servercfg.IsPro { if servercfg.DeployedByOperator() && !servercfg.IsPro {

View File

@@ -63,6 +63,7 @@ func getEmqxAuthToken() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close()
msg, err := io.ReadAll(resp.Body) msg, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", err return "", err
@@ -206,7 +207,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthenticator() error {
if err != nil { if err != nil {
return err return err
} }
return fmt.Errorf("error creating default EMQX authenticator %v", string(msg)) if !strings.ContainsAny(string(msg), "ALREADY_EXISTS") {
return fmt.Errorf("error creating default EMQX authenticator %v", string(msg))
}
} }
return nil return nil
} }
@@ -240,7 +243,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthorizer() error {
if err != nil { if err != nil {
return err return err
} }
return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg)) if !strings.ContainsAny(string(msg), "duplicated_authz_source_type") {
return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg))
}
} }
return nil return nil
} }

View File

@@ -34,9 +34,9 @@ func setMqOptions(user, password string, opts *mqtt.ClientOptions) {
opts.SetAutoReconnect(true) opts.SetAutoReconnect(true)
opts.SetConnectRetry(true) opts.SetConnectRetry(true)
opts.SetCleanSession(true) opts.SetCleanSession(true)
opts.SetConnectRetryInterval(time.Second * 4) opts.SetConnectRetryInterval(time.Second * 1)
opts.SetKeepAlive(time.Minute) opts.SetKeepAlive(time.Second * 10)
opts.SetCleanSession(true) opts.SetOrderMatters(false)
opts.SetWriteTimeout(time.Minute) opts.SetWriteTimeout(time.Minute)
} }
@@ -75,19 +75,15 @@ func SetupMQTT(fatal bool) {
opts.SetOnConnectHandler(func(client mqtt.Client) { opts.SetOnConnectHandler(func(client mqtt.Client) {
serverName := servercfg.GetServer() serverName := servercfg.GetServer()
if token := client.Subscribe(fmt.Sprintf("update/%s/#", serverName), 0, mqtt.MessageHandler(UpdateNode)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { if token := client.Subscribe(fmt.Sprintf("update/%s/#", serverName), 0, mqtt.MessageHandler(UpdateNode)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
client.Disconnect(240)
logger.Log(0, "node update subscription failed") logger.Log(0, "node update subscription failed")
} }
if token := client.Subscribe(fmt.Sprintf("host/serverupdate/%s/#", serverName), 0, mqtt.MessageHandler(UpdateHost)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { if token := client.Subscribe(fmt.Sprintf("host/serverupdate/%s/#", serverName), 0, mqtt.MessageHandler(UpdateHost)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
client.Disconnect(240)
logger.Log(0, "host update subscription failed") logger.Log(0, "host update subscription failed")
} }
if token := client.Subscribe(fmt.Sprintf("signal/%s/#", serverName), 0, mqtt.MessageHandler(ClientPeerUpdate)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { if token := client.Subscribe(fmt.Sprintf("signal/%s/#", serverName), 0, mqtt.MessageHandler(ClientPeerUpdate)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
client.Disconnect(240)
logger.Log(0, "node client subscription failed") logger.Log(0, "node client subscription failed")
} }
if token := client.Subscribe(fmt.Sprintf("metrics/%s/#", serverName), 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { if token := client.Subscribe(fmt.Sprintf("metrics/%s/#", serverName), 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
client.Disconnect(240)
logger.Log(0, "node metrics subscription failed") logger.Log(0, "node metrics subscription failed")
} }

View File

@@ -93,7 +93,6 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
// PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host // PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host
func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool) error { func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool) error {
peerUpdate, err := logic.GetPeerUpdateForHost("", host, allNodes, deletedNode, deletedClients) peerUpdate, err := logic.GetPeerUpdateForHost("", host, allNodes, deletedNode, deletedClients)
if err != nil { if err != nil {
return err return err
@@ -211,12 +210,6 @@ func PushMetricsToExporter(metrics models.Metrics) error {
// sendPeers - retrieve networks, send peer ports to all peers // sendPeers - retrieve networks, send peer ports to all peers
func sendPeers() { func sendPeers() {
hosts, err := logic.GetAllHosts()
if err != nil && len(hosts) > 0 {
logger.Log(1, "error retrieving networks for keepalive", err.Error())
}
peer_force_send++ peer_force_send++
if peer_force_send == 5 { if peer_force_send == 5 {
servercfg.SetHost() servercfg.SetHost()

View File

@@ -9,6 +9,7 @@ import (
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"golang.org/x/exp/slog"
) )
func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) { func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
@@ -87,6 +88,7 @@ func publish(host *models.Host, dest string, msg []byte) error {
if token.Error() == nil { if token.Error() == nil {
err = errors.New("connection timeout") err = errors.New("connection timeout")
} else { } else {
slog.Error("publish to mq error", "error", token.Error().Error())
err = token.Error() err = token.Error()
} }
return err return err

View File

@@ -44,7 +44,7 @@ func getfailOver(w http.ResponseWriter, r *http.Request) {
// confirm host exists // confirm host exists
node, err := logic.GetNodeByID(nodeid) node, err := logic.GetNodeByID(nodeid)
if err != nil { if err != nil {
slog.Error("failed to get node:", "error", err.Error()) slog.Error("failed to get node:", "node", nodeid, "error", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }

View File

@@ -88,7 +88,7 @@ func InitPro() {
} else { } else {
slog.Error("no OAuth provider found or not configured, continuing without OAuth") slog.Error("no OAuth provider found or not configured, continuing without OAuth")
} }
proLogic.LoadNodeMetricsToCache()
}) })
logic.ResetFailOver = proLogic.ResetFailOver logic.ResetFailOver = proLogic.ResetFailOver
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer

View File

@@ -3,6 +3,7 @@ package logic
import ( import (
"encoding/json" "encoding/json"
"math" "math"
"sync"
"time" "time"
mqtt "github.com/eclipse/paho.mqtt.golang" mqtt "github.com/eclipse/paho.mqtt.golang"
@@ -15,9 +16,64 @@ import (
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
) )
var (
metricsCacheMutex = &sync.RWMutex{}
metricsCacheMap map[string]models.Metrics
)
func getMetricsFromCache(key string) (metrics models.Metrics, ok bool) {
metricsCacheMutex.RLock()
metrics, ok = metricsCacheMap[key]
metricsCacheMutex.RUnlock()
return
}
func storeMetricsInCache(key string, metrics models.Metrics) {
metricsCacheMutex.Lock()
metricsCacheMap[key] = metrics
metricsCacheMutex.Unlock()
}
func deleteNetworkFromCache(key string) {
metricsCacheMutex.Lock()
delete(metricsCacheMap, key)
metricsCacheMutex.Unlock()
}
func LoadNodeMetricsToCache() error {
slog.Info("loading metrics to cache")
if metricsCacheMap == nil {
metricsCacheMap = map[string]models.Metrics{}
}
collection, err := database.FetchRecords(database.METRICS_TABLE_NAME)
if err != nil {
return err
}
for key, value := range collection {
var metrics models.Metrics
if err := json.Unmarshal([]byte(value), &metrics); err != nil {
slog.Error("parse metric record error", "error", err.Error())
continue
}
if servercfg.CacheEnabled() {
storeMetricsInCache(key, metrics)
}
}
slog.Info("metrics loading done")
return nil
}
// GetMetrics - gets the metrics // GetMetrics - gets the metrics
func GetMetrics(nodeid string) (*models.Metrics, error) { func GetMetrics(nodeid string) (*models.Metrics, error) {
var metrics models.Metrics var metrics models.Metrics
if servercfg.CacheEnabled() {
if metrics, ok := getMetricsFromCache(nodeid); ok {
return &metrics, nil
}
}
record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid) record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid)
if err != nil { if err != nil {
if database.IsEmptyRecord(err) { if database.IsEmptyRecord(err) {
@@ -29,6 +85,9 @@ func GetMetrics(nodeid string) (*models.Metrics, error) {
if err != nil { if err != nil {
return &metrics, err return &metrics, err
} }
if servercfg.CacheEnabled() {
storeMetricsInCache(nodeid, metrics)
}
return &metrics, nil return &metrics, nil
} }
@@ -38,12 +97,26 @@ func UpdateMetrics(nodeid string, metrics *models.Metrics) error {
if err != nil { if err != nil {
return err return err
} }
return database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME) err = database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
storeMetricsInCache(nodeid, *metrics)
}
return nil
} }
// DeleteMetrics - deletes metrics of a given node // DeleteMetrics - deletes metrics of a given node
func DeleteMetrics(nodeid string) error { func DeleteMetrics(nodeid string) error {
return database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid) err := database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteNetworkFromCache(nodeid)
}
return nil
} }
// MQUpdateMetricsFallBack - called when mq fallback thread is triggered on client // MQUpdateMetricsFallBack - called when mq fallback thread is triggered on client