diff --git a/controllers/dns_test.go b/controllers/dns_test.go index a55571f4..401aeb74 100644 --- a/controllers/dns_test.go +++ b/controllers/dns_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/google/uuid" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" @@ -16,7 +15,6 @@ import ( var dnsHost models.Host func TestGetAllDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -47,7 +45,6 @@ func TestGetAllDNS(t *testing.T) { } func TestGetNodeDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -94,7 +91,6 @@ func TestGetNodeDNS(t *testing.T) { }) } func TestGetCustomDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() t.Run("NoNetworks", func(t *testing.T) { @@ -133,7 +129,6 @@ func TestGetCustomDNS(t *testing.T) { } func TestGetDNSEntryNum(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -152,7 +147,6 @@ func TestGetDNSEntryNum(t *testing.T) { }) } func TestGetDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -196,7 +190,6 @@ func TestGetDNS(t *testing.T) { } func TestCreateDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -207,7 +200,6 @@ func TestCreateDNS(t *testing.T) { } func TestSetDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() t.Run("NoNetworks", func(t *testing.T) { @@ -255,7 +247,6 @@ func TestSetDNS(t *testing.T) { } func TestGetDNSEntry(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -285,7 +276,6 @@ func TestGetDNSEntry(t *testing.T) { } func TestDeleteDNS(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -307,7 +297,6 @@ func TestDeleteDNS(t *testing.T) { } func TestValidateDNSUpdate(t *testing.T) { - database.InitializeDatabase() deleteAllDNS(t) deleteAllNetworks() createNet() @@ -369,7 +358,6 @@ func TestValidateDNSUpdate(t *testing.T) { } func TestValidateDNSCreate(t *testing.T) { - database.InitializeDatabase() _ = logic.DeleteDNS("mynode", "skynet") t.Run("NoNetwork", func(t *testing.T) { entry := models.DNSEntry{"10.0.0.2", "", "myhost", "badnet"} diff --git a/controllers/network_test.go b/controllers/network_test.go index 995349be..0d19c7df 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -1,11 +1,13 @@ package controller import ( + "context" "os" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" @@ -20,8 +22,27 @@ type NetworkValidationTestCase struct { var netHost models.Host +func TestMain(m *testing.M) { + database.InitializeDatabase() + defer database.CloseDB() + logic.CreateAdmin(&models.User{ + UserName: "admin", + Password: "password", + IsAdmin: true, + Networks: []string{}, + Groups: []string{}, + }) + peerUpdate := make(chan *models.Node) + go logic.ManageZombies(context.Background(), peerUpdate) + go func() { + for update := range peerUpdate { + //do nothing + logger.Log(3, "received node update", update.Action) + } + }() +} + func TestCreateNetwork(t *testing.T) { - initialize() deleteAllNetworks() var network models.Network @@ -34,7 +55,6 @@ func TestCreateNetwork(t *testing.T) { assert.Nil(t, err) } func TestGetNetwork(t *testing.T) { - initialize() createNet() t.Run("GetExistingNetwork", func(t *testing.T) { @@ -50,7 +70,6 @@ func TestGetNetwork(t *testing.T) { } func TestDeleteNetwork(t *testing.T) { - initialize() createNet() //create nodes t.Run("NetworkwithNodes", func(t *testing.T) { @@ -66,7 +85,6 @@ func TestDeleteNetwork(t *testing.T) { } func TestCreateKey(t *testing.T) { - initialize() createNet() keys, _ := logic.GetKeys("skynet") for _, key := range keys { @@ -138,7 +156,6 @@ func TestCreateKey(t *testing.T) { } func TestGetKeys(t *testing.T) { - initialize() deleteAllNetworks() createNet() network, err := logic.GetNetwork("skynet") @@ -161,7 +178,6 @@ func TestGetKeys(t *testing.T) { }) } func TestDeleteKey(t *testing.T) { - initialize() createNet() network, err := logic.GetNetwork("skynet") assert.Nil(t, err) @@ -183,7 +199,6 @@ func TestDeleteKey(t *testing.T) { func TestSecurityCheck(t *testing.T) { //these seem to work but not sure it the tests are really testing the functionality - initialize() os.Setenv("MASTER_KEY", "secretkey") t.Run("NoNetwork", func(t *testing.T) { networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey") @@ -214,7 +229,6 @@ func TestValidateNetwork(t *testing.T) { //t.Skip() //This functions is not called by anyone //it panics as validation function 'display_name_valid' is not defined - initialize() //yes := true //no := false //deleteNet(t) @@ -291,7 +305,6 @@ func TestValidateNetwork(t *testing.T) { func TestIpv6Network(t *testing.T) { //these seem to work but not sure it the tests are really testing the functionality - initialize() os.Setenv("MASTER_KEY", "secretkey") deleteAllNetworks() createNet() @@ -318,21 +331,6 @@ func deleteAllNetworks() { } } -func initialize() { - database.InitializeDatabase() - createAdminUser() -} - -func createAdminUser() { - logic.CreateAdmin(&models.User{ - UserName: "admin", - Password: "password", - IsAdmin: true, - Networks: []string{}, - Groups: []string{}, - }) -} - func createNet() { var network models.Network network.NetID = "skynet" diff --git a/controllers/node_test.go b/controllers/node_test.go index b16c3c9e..bb1e5f00 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -21,7 +21,6 @@ func TestCreateEgressGateway(t *testing.T) { var gateway models.EgressGatewayRequest gateway.Ranges = []string{"10.100.100.0/24"} gateway.NetID = "skynet" - database.InitializeDatabase() deleteAllNetworks() createNet() t.Run("NoNodes", func(t *testing.T) { @@ -78,7 +77,6 @@ func TestCreateEgressGateway(t *testing.T) { } func TestDeleteEgressGateway(t *testing.T) { var gateway models.EgressGatewayRequest - database.InitializeDatabase() deleteAllNetworks() createNet() testnode := createTestNode() @@ -110,7 +108,6 @@ func TestDeleteEgressGateway(t *testing.T) { } func TestGetNetworkNodes(t *testing.T) { - database.InitializeDatabase() deleteAllNetworks() createNet() t.Run("BadNet", func(t *testing.T) { diff --git a/controllers/user_test.go b/controllers/user_test.go index 629964e4..4c7eb59b 100644 --- a/controllers/user_test.go +++ b/controllers/user_test.go @@ -3,7 +3,6 @@ package controller import ( "testing" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" @@ -18,7 +17,6 @@ func deleteAllUsers() { func TestHasAdmin(t *testing.T) { //delete all current users - database.InitializeDatabase() users, _ := logic.GetUsers() for _, user := range users { success, err := logic.DeleteUser(user.UserName) @@ -48,7 +46,7 @@ func TestHasAdmin(t *testing.T) { }) t.Run("multiple admins", func(t *testing.T) { var user = models.User{"admin1", "password", nil, true, nil} - err := logic.CreateUser(&user) + err := logic.CreateUser(&user) assert.Nil(t, err) found, err := logic.HasAdmin() assert.Nil(t, err) @@ -57,7 +55,6 @@ func TestHasAdmin(t *testing.T) { } func TestCreateUser(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() user := models.User{"admin", "password", nil, true, nil} t.Run("NoUser", func(t *testing.T) { @@ -72,7 +69,6 @@ func TestCreateUser(t *testing.T) { } func TestCreateAdmin(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() var user models.User t.Run("NoAdmin", func(t *testing.T) { @@ -90,7 +86,6 @@ func TestCreateAdmin(t *testing.T) { } func TestDeleteUser(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() t.Run("NonExistent User", func(t *testing.T) { deleted, err := logic.DeleteUser("admin") @@ -107,7 +102,6 @@ func TestDeleteUser(t *testing.T) { } func TestValidateUser(t *testing.T) { - database.InitializeDatabase() var user models.User t.Run("Valid Create", func(t *testing.T) { user.UserName = "admin" @@ -155,7 +149,6 @@ func TestValidateUser(t *testing.T) { } func TestGetUser(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() t.Run("NonExistantUser", func(t *testing.T) { admin, err := logic.GetUser("admin") @@ -172,7 +165,6 @@ func TestGetUser(t *testing.T) { } func TestGetUsers(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() t.Run("NonExistantUser", func(t *testing.T) { admin, err := logic.GetUsers() @@ -203,7 +195,6 @@ func TestGetUsers(t *testing.T) { } func TestUpdateUser(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() user := models.User{"admin", "password", nil, true, nil} newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true, []string{}} @@ -246,7 +237,6 @@ func TestUpdateUser(t *testing.T) { // } func TestVerifyAuthRequest(t *testing.T) { - database.InitializeDatabase() deleteAllUsers() var authRequest models.UserAuthParams t.Run("EmptyUserName", func(t *testing.T) { diff --git a/functions/helpers_test.go b/functions/helpers_test.go index e2737f48..220ecaad 100644 --- a/functions/helpers_test.go +++ b/functions/helpers_test.go @@ -1,10 +1,12 @@ package functions import ( + "context" "encoding/json" "testing" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" ) @@ -19,11 +21,27 @@ var ( } ) +func TestMain(m *testing.M) { + database.InitializeDatabase() + defer database.CloseDB() + logic.CreateAdmin(&models.User{ + UserName: "admin", + Password: "password", + IsAdmin: true, + Networks: []string{}, + Groups: []string{}, + }) + peerUpdate := make(chan *models.Node) + go logic.ManageZombies(context.Background(), peerUpdate) + go func() { + for update := range peerUpdate { + //do nothing + logger.Log(3, "received node update", update.Action) + } + }() +} + func TestNetworkExists(t *testing.T) { - err := database.InitializeDatabase() - if err != nil { - t.Fatalf("error initilizing database: %s", err) - } database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) defer database.CloseDB() exists, err := logic.NetworkExists(testNetwork.NetID) @@ -53,10 +71,6 @@ func TestNetworkExists(t *testing.T) { } func TestGetAllExtClients(t *testing.T) { - err := database.InitializeDatabase() - if err != nil { - t.Fatalf("error initilizing database: %s", err) - } defer database.CloseDB() database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, testExternalClient.ClientID) diff --git a/logic/host_test.go b/logic/host_test.go index 75ff7a16..fdde345e 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -1,17 +1,38 @@ package logic import ( + "context" "net" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/matryer/is" ) -func TestCheckPorts(t *testing.T) { +func TestMain(m *testing.M) { database.InitializeDatabase() + defer database.CloseDB() + CreateAdmin(&models.User{ + UserName: "admin", + Password: "password", + IsAdmin: true, + Networks: []string{}, + Groups: []string{}, + }) + peerUpdate := make(chan *models.Node) + go ManageZombies(context.Background(), peerUpdate) + go func() { + for update := range peerUpdate { + //do nothing + logger.Log(3, "received node update", update.Action) + } + }() +} + +func TestCheckPorts(t *testing.T) { h := models.Host{ ID: uuid.New(), EndpointIP: net.ParseIP("192.168.1.1"), diff --git a/logic/hosts.go b/logic/hosts.go index f6d14391..de05caa7 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -96,6 +96,7 @@ func CreateHost(h *models.Host) error { return err } h.HostPass = string(hash) + checkForZombieHosts(h) return UpsertHost(h) } diff --git a/logic/nodes.go b/logic/nodes.go index 0e30cea9..1ef27229 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -534,7 +534,7 @@ func createNode(node *models.Node) error { if err != nil { return err } - CheckZombies(node, host.MacAddress) + CheckZombies(node) nodebytes, err := json.Marshal(&node) if err != nil { diff --git a/logic/pro/networkuser_test.go b/logic/pro/networkuser_test.go index ac9994dd..2568ae07 100644 --- a/logic/pro/networkuser_test.go +++ b/logic/pro/networkuser_test.go @@ -10,8 +10,12 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNetworkUserLogic(t *testing.T) { +func TestMain(m *testing.M) { database.InitializeDatabase() + defer database.CloseDB() +} + +func TestNetworkUserLogic(t *testing.T) { networkUser := promodels.NetworkUser{ ID: "helloworld", } diff --git a/logic/pro/usergroups_test.go b/logic/pro/usergroups_test.go index cd472e25..3ca32cef 100644 --- a/logic/pro/usergroups_test.go +++ b/logic/pro/usergroups_test.go @@ -3,13 +3,11 @@ package pro import ( "testing" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/models/promodels" "github.com/stretchr/testify/assert" ) func TestUserGroupLogic(t *testing.T) { - database.InitializeDatabase() t.Run("User Groups initialized successfully", func(t *testing.T) { err := InitializeGroups() diff --git a/logic/zombie.go b/logic/zombie.go index 3b0abfb9..488bfaa5 100644 --- a/logic/zombie.go +++ b/logic/zombie.go @@ -2,7 +2,6 @@ package logic import ( "context" - "net" "time" "github.com/google/uuid" @@ -18,15 +17,16 @@ const ( ) var ( - zombies []uuid.UUID - removeZombie chan uuid.UUID = make(chan (uuid.UUID), 10) - newZombie chan uuid.UUID = make(chan (uuid.UUID), 10) + zombies []uuid.UUID + hostZombies []uuid.UUID + newZombie chan uuid.UUID = make(chan (uuid.UUID), 10) + newHostZombie chan uuid.UUID = make(chan (uuid.UUID), 10) ) -// CheckZombies - checks if new node has same macaddress as existing node +// CheckZombies - checks if new node has same hostid as existing node // if so, existing node is added to zombie node quarantine list // also cleans up nodes past their expiration date -func CheckZombies(newnode *models.Node, mac net.HardwareAddr) { +func CheckZombies(newnode *models.Node) { nodes, err := GetNetworkNodes(newnode.Network) if err != nil { logger.Log(1, "Failed to retrieve network nodes", newnode.Network, err.Error()) @@ -44,6 +44,35 @@ func CheckZombies(newnode *models.Node, mac net.HardwareAddr) { } } +// checkForZombieHosts - checks if new host has the same macAddress as an existing host +// if true, existing host is added to host zombie collection +func checkForZombieHosts(h *models.Host) { + hosts, err := GetAllHosts() + if err != nil { + logger.Log(3, "errror retrieving all hosts", err.Error()) + } + for _, existing := range hosts { + if existing.ID == h.ID { + //probably an unnecessary check as new host should not be in database yet, but just in case + //skip self + continue + } + if existing.MacAddress.String() == h.MacAddress.String() { + //add to hostZombies + newHostZombie <- existing.ID + //add all nodes belonging to host to zombile list + for _, node := range existing.Nodes { + id, err := uuid.Parse(node) + if err != nil { + logger.Log(3, "error parsing uuid from host.Nodes", err.Error()) + continue + } + newHostZombie <- id + } + } + } +} + // ManageZombies - goroutine which adds/removes/deletes nodes from the zombie node quarantine list func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) { logger.Log(2, "Zombie management started") @@ -51,24 +80,12 @@ func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) { for { select { case <-ctx.Done(): + close(peerUpdate) return case id := <-newZombie: - logger.Log(1, "adding", id.String(), "to zombie quaratine list") zombies = append(zombies, id) - case id := <-removeZombie: - found := false - if len(zombies) > 0 { - for i := len(zombies) - 1; i >= 0; i-- { - if zombies[i] == id { - logger.Log(1, "removing zombie from quaratine list", zombies[i].String()) - zombies = append(zombies[:i], zombies[i+1:]...) - found = true - } - } - } - if !found { - logger.Log(3, "no zombies found") - } + case id := <-newHostZombie: + hostZombies = append(hostZombies, id) case <-time.After(time.Second * ZOMBIE_TIMEOUT): logger.Log(3, "checking for zombie nodes") if len(zombies) > 0 { @@ -92,6 +109,23 @@ func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) { } } } + if len(hostZombies) > 0 { + logger.Log(3, "checking host zombies") + for i := len(hostZombies) - 1; i >= 0; i-- { + host, err := GetHost(hostZombies[i].String()) + if err != nil { + logger.Log(1, "error retrieving zombie host", err.Error()) + logger.Log(1, "deleting ", host.ID.String(), " from zombie list") + zombies = append(zombies[:i], zombies[i+1:]...) + continue + } + if len(host.Nodes) == 0 { + if err := RemoveHost(host); err != nil { + logger.Log(0, "error deleting zombie host", host.ID.String(), err.Error()) + } + } + } + } } } } @@ -115,10 +149,10 @@ func InitializeZombies() { } if node.HostID == othernode.HostID { if node.LastCheckIn.After(othernode.LastCheckIn) { - zombies = append(zombies, othernode.ID) + newZombie <- othernode.ID logger.Log(1, "adding", othernode.ID.String(), "to zombie list") } else { - zombies = append(zombies, node.ID) + newZombie <- node.ID logger.Log(1, "adding", node.ID.String(), "to zombie list") } } diff --git a/models/network_test.go b/models/network_test.go index bd79398d..3db60cc3 100644 --- a/models/network_test.go +++ b/models/network_test.go @@ -2,7 +2,7 @@ package models // moved from controllers need work //func TestUpdateNetwork(t *testing.T) { -// database.InitializeDatabase() +// initialize() // createNet() // network := getNet() // t.Run("NetID", func(t *testing.T) {