diff --git a/config/config.go b/config/config.go index 95756c9c..044d1a88 100644 --- a/config/config.go +++ b/config/config.go @@ -12,7 +12,7 @@ import ( "gopkg.in/yaml.v3" ) -//setting dev by default +// setting dev by default func getEnv() string { env := os.Getenv("NETMAKER_ENV") @@ -27,13 +27,13 @@ func getEnv() string { // Config : application config stored as global variable var Config *EnvironmentConfig -// EnvironmentConfig : +// EnvironmentConfig - environment conf struct type EnvironmentConfig struct { Server ServerConfig `yaml:"server"` SQL SQLConfig `yaml:"sql"` } -// ServerConfig : +// ServerConfig - server conf struct type ServerConfig struct { CoreDNSAddr string `yaml:"corednsaddr"` APIConnString string `yaml:"apiconn"` @@ -58,8 +58,8 @@ type ServerConfig struct { Version string `yaml:"version"` SQLConn string `yaml:"sqlconn"` Platform string `yaml:"platform"` - Database string `yaml:database` - CheckinInterval string `yaml:checkininterval` + Database string `yaml:"database"` + CheckinInterval string `yaml:"checkininterval"` DefaultNodeLimit int32 `yaml:"defaultnodelimit"` Verbosity int32 `yaml:"verbosity"` ServerCheckinInterval int64 `yaml:"servercheckininterval"` @@ -71,7 +71,7 @@ type ServerConfig struct { AzureTenant string `yaml:"azuretenant"` } -// Generic SQL Config +// SQLConfig - Generic SQL Config type SQLConfig struct { Host string `yaml:"host"` Port int32 `yaml:"port"` @@ -81,7 +81,7 @@ type SQLConfig struct { SSLMode string `yaml:"sslmode"` } -//reading in the env file +// reading in the env file func readConfig() *EnvironmentConfig { file := fmt.Sprintf("config/environments/%s.yaml", getEnv()) f, err := os.Open(file) diff --git a/logic/util.go b/logic/util.go index d7ce5534..d89b0daa 100644 --- a/logic/util.go +++ b/logic/util.go @@ -184,26 +184,18 @@ func GetNode(macaddress string, network string) (models.Node, error) { // GetNodePeers - fetches peers for a given node func GetNodePeers(networkName string, excludeRelayed bool) ([]models.Node, error) { var peers []models.Node - collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + var networkNodes, egressNetworkNodes, err = getNetworkEgressAndNodes(networkName) if err != nil { - if database.IsEmptyRecord(err) { - return peers, nil - } - logger.Log(2, err.Error()) - return nil, err + return peers, nil } + udppeers, errN := database.GetPeers(networkName) if errN != nil { logger.Log(2, errN.Error()) } - for _, value := range collection { - var node = &models.Node{} + + for _, node := range networkNodes { var peer = models.Node{} - err := json.Unmarshal([]byte(value), node) - if err != nil { - logger.Log(2, err.Error()) - continue - } if node.IsEgressGateway == "yes" { // handle egress stuff peer.EgressGatewayRanges = node.EgressGatewayRanges peer.IsEgressGateway = node.IsEgressGateway @@ -211,7 +203,7 @@ func GetNodePeers(networkName string, excludeRelayed bool) ([]models.Node, error allow := node.IsRelayed != "yes" || !excludeRelayed if node.Network == networkName && node.IsPending != "yes" && allow { - peer = setPeerInfo(node) + peer = setPeerInfo(&node) if node.UDPHolePunch == "yes" && errN == nil && CheckEndpoint(udppeers[node.PublicKey]) { endpointstring := udppeers[node.PublicKey] endpointarr := strings.Split(endpointstring, ":") @@ -230,6 +222,11 @@ func GetNodePeers(networkName string, excludeRelayed bool) ([]models.Node, error } else { peer.AllowedIPs = append(peer.AllowedIPs, node.RelayAddrs...) } + for _, egressNode := range egressNetworkNodes { + if egressNode.IsRelayed == "yes" && StringSliceContains(node.RelayAddrs, egressNode.Address) { + peer.AllowedIPs = append(peer.AllowedIPs, egressNode.EgressGatewayRanges...) + } + } } peers = append(peers, peer) } @@ -286,6 +283,34 @@ func RandomString(length int) string { // == Private Methods == +func getNetworkEgressAndNodes(networkName string) ([]models.Node, []models.Node, error) { + var networkNodes, egressNetworkNodes []models.Node + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + if err != nil { + if database.IsEmptyRecord(err) { + return networkNodes, egressNetworkNodes, nil + } + logger.Log(2, err.Error()) + return nil, nil, err + } + + for _, value := range collection { + var node = models.Node{} + err := json.Unmarshal([]byte(value), &node) + if err != nil { + logger.Log(2, err.Error()) + continue + } + if node.Network == networkName { + networkNodes = append(networkNodes, node) + if node.IsEgressGateway == "yes" { + egressNetworkNodes = append(egressNetworkNodes, node) + } + } + } + return networkNodes, egressNetworkNodes, nil +} + func setPeerInfo(node *models.Node) models.Node { var peer models.Node peer.RelayAddrs = node.RelayAddrs @@ -326,3 +351,13 @@ func setIPForwardingLinux() error { } return nil } + +// StringSliceContains - sees if a string slice contains a string element +func StringSliceContains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +}