diff --git a/controllers/config/dnsconfig/netmaker.hosts b/controllers/config/dnsconfig/netmaker.hosts index 655eaef6..45eab425 100644 --- a/controllers/config/dnsconfig/netmaker.hosts +++ b/controllers/config/dnsconfig/netmaker.hosts @@ -1,2 +1 @@ -10.0.0.1 testnode.skynet -10.0.0.2 myhost.skynet +10.0.0.2 testnode.skynet myhost.skynet diff --git a/controllers/node_test.go b/controllers/node_test.go index d4482c5a..28eb0300 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -5,6 +5,8 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/acls" + nodeacls "github.com/gravitl/netmaker/logic/acls/node-acls" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" ) @@ -143,6 +145,61 @@ func TestValidateEgressGateway(t *testing.T) { }) } +func TestNodeACLs(t *testing.T) { + deleteAllNodes() + node1 := models.Node{PublicKey: "DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=", Name: "testnode", Endpoint: "10.0.0.50", MacAddress: "01:02:03:04:05:06", Password: "password", Network: "skynet", OS: "linux"} + node2 := models.Node{PublicKey: "DM5qhLAE20FG7BbfBCger+Ac9D2NDOwCtY1rbYDXf14=", Name: "testnode", Endpoint: "10.0.0.100", MacAddress: "01:02:03:04:05:07", Password: "password", Network: "skynet", OS: "linux"} + logic.CreateNode(&node1) + logic.CreateNode(&node2) + currentACL, err := nodeacls.CreateNetworkACL(acls.NetworkID(node1.Network)) + t.Run("acls not present", func(t *testing.T) { + assert.Nil(t, err) + assert.Nil(t, currentACL[acls.NodeID(node1.ID)]) + assert.Nil(t, currentACL[acls.NodeID(node2.ID)]) + node1ACL, err := nodeacls.FetchNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID)) + assert.NotNil(t, err) + assert.Nil(t, node1ACL) + assert.EqualError(t, err, "no node ACL present for node "+node1.ID) + }) + t.Run("node acls exists after creates", func(t *testing.T) { + node1ACL, err := nodeacls.CreateNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID), acls.Allowed) + assert.Nil(t, err) + assert.NotNil(t, node1ACL) + assert.Equal(t, node1ACL[acls.NodeID(node2.ID)], acls.NotPresent) + node2ACL, err := nodeacls.CreateNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node2.ID), acls.Allowed) + assert.Nil(t, err) + assert.NotNil(t, node2ACL) + assert.Equal(t, acls.Allowed, node2ACL[acls.NodeID(node1.ID)]) + }) + t.Run("node acls correct after fetch", func(t *testing.T) { + node1ACL, err := nodeacls.FetchNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID)) + assert.Nil(t, err) + assert.Equal(t, acls.Allowed, node1ACL[acls.NodeID(node2.ID)]) + }) + t.Run("node acls correct after modify", func(t *testing.T) { + retNetworkACL, err := nodeacls.ChangeNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID), acls.NodeID(node2.ID), acls.NotAllowed) + assert.Nil(t, err) + assert.NotNil(t, retNetworkACL) + assert.Equal(t, acls.NotAllowed, retNetworkACL[acls.NodeID(node1.ID)][acls.NodeID(node2.ID)]) + assert.Equal(t, acls.NotAllowed, retNetworkACL[acls.NodeID(node2.ID)][acls.NodeID(node1.ID)]) + }) + t.Run("node acls correct after erroneous modify", func(t *testing.T) { + retNetworkACL, err := nodeacls.ChangeNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID), acls.NodeID(node2.ID), acls.NotPresent) + assert.Nil(t, err) + assert.NotNil(t, retNetworkACL) + assert.Equal(t, acls.NotAllowed, retNetworkACL[acls.NodeID(node1.ID)][acls.NodeID(node2.ID)]) + assert.Equal(t, acls.NotAllowed, retNetworkACL[acls.NodeID(node2.ID)][acls.NodeID(node1.ID)]) + }) + t.Run("node acls removed", func(t *testing.T) { + retNetworkACL, err := nodeacls.RemoveNodeACL(acls.NetworkID(node1.Network), acls.NodeID(node1.ID)) + assert.Nil(t, err) + assert.NotNil(t, retNetworkACL) + assert.Equal(t, acls.NotPresent, retNetworkACL[acls.NodeID(node2.ID)][acls.NodeID(node1.ID)]) + }) + + deleteAllNodes() +} + func deleteAllNodes() { database.DeleteAllRecords(database.NODES_TABLE_NAME) } diff --git a/logic/acls/node-acls/modify.go b/logic/acls/node-acls/modify.go index 090b7e2a..ff0f0d58 100644 --- a/logic/acls/node-acls/modify.go +++ b/logic/acls/node-acls/modify.go @@ -4,57 +4,97 @@ import ( "encoding/json" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic/acls" ) -// UpsertNodeACL - inserts or updates a node ACL on given network -func UpsertNodeACL(networkID NetworkID, nodeID NodeID, defaultVal byte) (NodeACL, error) { - if defaultVal != NotAllowed && defaultVal != Allowed { - defaultVal = NotAllowed +// ChangeNodeACL - takes in two node IDs of a given network and changes them to specified allowed or not value +// returns the total network's ACL and error +func ChangeNodeACL(networkID acls.NetworkID, node1, node2 acls.NodeID, value byte) (acls.NetworkACL, error) { + if value != acls.NotAllowed && value != acls.Allowed { // if invalid option make not allowed + value = acls.NotAllowed + } + currentACL, err := FetchCurrentACL(networkID) + if err != nil { + return nil, err + } + // == make the access control change == + currentACL[node1][node2] = value + currentACL[node2][node1] = value + return UpsertNetworkACL(networkID, currentACL) +} + +// CreateNodeACL - inserts or updates a node ACL on given network +func CreateNodeACL(networkID acls.NetworkID, nodeID acls.NodeID, defaultVal byte) (acls.NodeACL, error) { + if defaultVal != acls.NotAllowed && defaultVal != acls.Allowed { + defaultVal = acls.NotAllowed } var currentNetworkACL, err = FetchCurrentACL(networkID) if err != nil { return nil, err } - var newNodeACL = make(NodeACL) - for existingNode := range currentNetworkACL { - currentNetworkACL[existingNode][nodeID] = defaultVal - newNodeACL[existingNode] = defaultVal + var newNodeACL = make(acls.NodeACL) + for existingNodeID := range currentNetworkACL { + currentNetworkACL[existingNodeID][nodeID] = defaultVal // set the old nodes to default value for new node + newNodeACL[existingNodeID] = defaultVal // set the old nodes in new node ACL to default value } - currentNetworkACL[nodeID] = newNodeACL - return newNodeACL, nil + currentNetworkACL[nodeID] = newNodeACL // append the new node's ACL + retNetworkACL, err := UpsertNetworkACL(networkID, currentNetworkACL) // insert into db, return result + if err != nil { + return nil, err + } + return retNetworkACL[nodeID], nil +} + +// CreateNetworkACL - creates an empty ACL list in a given network +func CreateNetworkACL(networkID acls.NetworkID) (acls.NetworkACL, error) { + var networkACL = make(acls.NetworkACL) + return networkACL, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(&networkACL)), database.NODE_ACLS_TABLE_NAME) +} + +// UpsertNodeACL - applies a NodeACL to the db, overwrites or creates +func UpsertNodeACL(networkID acls.NetworkID, nodeID acls.NodeID, nodeACL acls.NodeACL) (acls.NodeACL, error) { + currentNetACL, err := FetchCurrentACL(networkID) + if err != nil { + return nodeACL, err + } + currentNetACL[nodeID] = nodeACL + _, err = UpsertNetworkACL(networkID, currentNetACL) + return nodeACL, err } // UpsertNetworkACL - Inserts or updates a network ACL given the json string of the ACL and the network name // if nil, create it -func UpsertNetworkACL(networkID NetworkID, networkACL NetworkACL) (NetworkACL, error) { +func UpsertNetworkACL(networkID acls.NetworkID, networkACL acls.NetworkACL) (acls.NetworkACL, error) { if networkACL == nil { - networkACL = make(NetworkACL) + networkACL = make(acls.NetworkACL) } return networkACL, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(&networkACL)), database.NODE_ACLS_TABLE_NAME) } // RemoveNodeACL - removes a specific Node's ACL, returns the NetworkACL and error -func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (NetworkACL, error) { +func RemoveNodeACL(networkID acls.NetworkID, nodeID acls.NodeID) (acls.NetworkACL, error) { var currentNeworkACL, err = FetchCurrentACL(networkID) if err != nil { return nil, err } for currentNodeID := range currentNeworkACL { - delete(currentNeworkACL[nodeID], currentNodeID) + if currentNodeID != nodeID { + delete(currentNeworkACL[currentNodeID], nodeID) + } } delete(currentNeworkACL, nodeID) return UpsertNetworkACL(networkID, currentNeworkACL) } // RemoveNetworkACL - just delete the network ACL -func RemoveNetworkACL(networkID NetworkID) error { +func RemoveNetworkACL(networkID acls.NetworkID) error { return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(networkID)) } -func convertNetworkACLtoACLJson(networkACL *NetworkACL) ACLJson { +func convertNetworkACLtoACLJson(networkACL *acls.NetworkACL) acls.ACLJson { data, err := json.Marshal(networkACL) if err != nil { return "" } - return ACLJson(data) + return acls.ACLJson(data) } diff --git a/logic/acls/node-acls/retrieve.go b/logic/acls/node-acls/retrieve.go index 15cc0360..e29b8892 100644 --- a/logic/acls/node-acls/retrieve.go +++ b/logic/acls/node-acls/retrieve.go @@ -2,30 +2,35 @@ package nodeacls import ( "encoding/json" + "fmt" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic/acls" ) // AreNodesAllowed - checks if nodes are allowed to communicate in their network ACL -func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool { +func AreNodesAllowed(networkID acls.NetworkID, node1, node2 acls.NodeID) bool { var currentNetworkACL, err = FetchCurrentACL(networkID) if err != nil { return false } - return currentNetworkACL[node1][node2] == Allowed && currentNetworkACL[node2][node1] == Allowed + return currentNetworkACL[node1][node2] == acls.Allowed && currentNetworkACL[node2][node1] == acls.Allowed } // FetchNodeACL - fetches a specific node's ACL in a given network -func FetchNodeACL(networkID NetworkID, nodeID NodeID) (NodeACL, error) { +func FetchNodeACL(networkID acls.NetworkID, nodeID acls.NodeID) (acls.NodeACL, error) { currentNetACL, err := FetchCurrentACL(networkID) if err != nil { return nil, err } + if currentNetACL[nodeID] == nil { + return nil, fmt.Errorf("no node ACL present for node %s", nodeID) + } return currentNetACL[nodeID], nil } // FetchNodeACLJson - fetches a node's acl in given network except returns the json string -func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (ACLJson, error) { +func FetchNodeACLJson(networkID acls.NetworkID, nodeID acls.NodeID) (acls.ACLJson, error) { currentNodeACL, err := FetchNodeACL(networkID, nodeID) if err != nil { return "", err @@ -34,16 +39,16 @@ func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (ACLJson, error) { if err != nil { return "", err } - return ACLJson(jsonData), nil + return acls.ACLJson(jsonData), nil } // FetchCurrentACL - fetches all current node rules in given network ACL -func FetchCurrentACL(networkID NetworkID) (NetworkACL, error) { - aclJson, err := FetchCurrentACLJson(NetworkID(networkID)) +func FetchCurrentACL(networkID acls.NetworkID) (acls.NetworkACL, error) { + aclJson, err := FetchCurrentACLJson(acls.NetworkID(networkID)) if err != nil { return nil, err } - var currentNetworkACL NetworkACL + var currentNetworkACL acls.NetworkACL if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { return nil, err } @@ -51,10 +56,10 @@ func FetchCurrentACL(networkID NetworkID) (NetworkACL, error) { } // FetchCurrentACLJson - fetch the current ACL of given network except in json string -func FetchCurrentACLJson(networkID NetworkID) (ACLJson, error) { +func FetchCurrentACLJson(networkID acls.NetworkID) (acls.ACLJson, error) { currentACLs, err := database.FetchRecord(database.NODE_ACLS_TABLE_NAME, string(networkID)) if err != nil { - return ACLJson(""), err + return acls.ACLJson(""), err } - return ACLJson(currentACLs), nil + return acls.ACLJson(currentACLs), nil } diff --git a/logic/acls/node-acls/types.go b/logic/acls/types.go similarity index 97% rename from logic/acls/node-acls/types.go rename to logic/acls/types.go index 289cb614..8f4836e7 100644 --- a/logic/acls/node-acls/types.go +++ b/logic/acls/types.go @@ -1,4 +1,4 @@ -package nodeacls +package acls var ( // NotPresent - 0 - not present (default)