diff --git a/controllers/node_test.go b/controllers/node_test.go index d20a3971..dee08e73 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -5,6 +5,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/acls" nodeacls "github.com/gravitl/netmaker/logic/acls/node-acls" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" @@ -153,49 +154,49 @@ func TestNodeACLs(t *testing.T) { t.Run("acls not present", func(t *testing.T) { currentACL, err := nodeacls.CreateNetworkACL(nodeacls.NetworkID(node1.Network)) assert.Nil(t, err) - assert.Nil(t, currentACL[nodeacls.NodeID(node1.ID)]) - assert.Nil(t, currentACL[nodeacls.NodeID(node2.ID)]) - node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID)) + assert.Nil(t, currentACL[acls.AclID(node1.ID)]) + assert.Nil(t, currentACL[acls.AclID(node2.ID)]) + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node1.ID)) assert.NotNil(t, err) assert.Nil(t, node1ACL) assert.EqualError(t, err, "no node ACL present for node "+node1.ID) }) t.Run("node acls exists after creates", func(t *testing.T) { - node1ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID), nodeacls.Allowed) + node1ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node1.ID), acls.Allowed) assert.Nil(t, err) assert.NotNil(t, node1ACL) - assert.Equal(t, node1ACL[nodeacls.NodeID(node2.ID)], nodeacls.NotPresent) - node2ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node2.ID), nodeacls.Allowed) + assert.Equal(t, node1ACL[acls.AclID(node2.ID)], acls.NotPresent) + node2ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node2.ID), acls.Allowed) assert.Nil(t, err) assert.NotNil(t, node2ACL) - assert.Equal(t, nodeacls.Allowed, node2ACL[nodeacls.NodeID(node1.ID)]) + assert.Equal(t, acls.Allowed, node2ACL[acls.AclID(node1.ID)]) }) t.Run("node acls correct after fetch", func(t *testing.T) { - node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID)) + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node1.ID)) assert.Nil(t, err) - assert.Equal(t, nodeacls.Allowed, node1ACL[nodeacls.NodeID(node2.ID)]) + assert.Equal(t, acls.Allowed, node1ACL[acls.AclID(node2.ID)]) }) t.Run("node acls correct after modify", func(t *testing.T) { currentACL, err := nodeacls.CreateNetworkACL(nodeacls.NetworkID(node1.Network)) assert.Nil(t, err) assert.NotNil(t, currentACL) - node1ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID), nodeacls.Allowed) + node1ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node1.ID), acls.Allowed) assert.Nil(t, err) - node2ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node2.ID), nodeacls.Allowed) + node2ACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node2.ID), acls.Allowed) assert.Nil(t, err) assert.NotNil(t, node1ACL) assert.NotNil(t, node2ACL) currentACL, err = nodeacls.FetchCurrentACL(nodeacls.NetworkID(node1.Network)) assert.Nil(t, err) - currentACL.ChangeNodesAccess(nodeacls.NodeID(node1.ID), nodeacls.NodeID(node2.ID), nodeacls.NotAllowed) - assert.Equal(t, nodeacls.NotAllowed, currentACL[nodeacls.NodeID(node1.ID)][nodeacls.NodeID(node2.ID)]) - assert.Equal(t, nodeacls.NotAllowed, currentACL[nodeacls.NodeID(node2.ID)][nodeacls.NodeID(node1.ID)]) + currentACL.ChangeNodesAccess(acls.AclID(node1.ID), acls.AclID(node2.ID), acls.NotAllowed) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID)][acls.AclID(node2.ID)]) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node2.ID)][acls.AclID(node1.ID)]) }) t.Run("node acls removed", func(t *testing.T) { - retNetworkACL, err := nodeacls.RemoveNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID)) + retNetworkACL, err := nodeacls.RemoveNodeACL(nodeacls.NetworkID(node1.Network), acls.AclID(node1.ID)) assert.Nil(t, err) assert.NotNil(t, retNetworkACL) - assert.Equal(t, nodeacls.NotPresent, retNetworkACL[nodeacls.NodeID(node2.ID)][nodeacls.NodeID(node1.ID)]) + assert.Equal(t, acls.NotPresent, retNetworkACL[acls.AclID(node2.ID)][acls.AclID(node1.ID)]) }) deleteAllNodes() diff --git a/logic/acls/common.go b/logic/acls/common.go new file mode 100644 index 00000000..63b2c69a --- /dev/null +++ b/logic/acls/common.go @@ -0,0 +1,115 @@ +package acls + +import ( + "encoding/json" + + "github.com/gravitl/netmaker/database" +) + +// CreateACLContainer - creates an empty ACL list in a given network +func CreateACLContainer(networkID ContainerID) (ACLContainer, error) { + var aclContainer = make(ACLContainer) + return aclContainer, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME) +} + +// FetchACLContainer - fetches all current node rules in given network ACL +func FetchACLContainer(networkID ContainerID) (ACLContainer, error) { + aclJson, err := FetchACLContainerJson(ContainerID(networkID)) + if err != nil { + return nil, err + } + var currentNetworkACL ACLContainer + if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { + return nil, err + } + return currentNetworkACL, nil +} + +// FetchACLContainerJson - fetch the current ACL of given network except in json string +func FetchACLContainerJson(networkID ContainerID) (ACLJson, error) { + currentACLs, err := database.FetchRecord(database.NODE_ACLS_TABLE_NAME, string(networkID)) + if err != nil { + return ACLJson(""), err + } + return ACLJson(currentACLs), nil +} + +// == type functions == + +// ACL.AllowNode - allows a node by ID in memory +func (acl ACL) Allow(ID AclID) { + acl[ID] = Allowed +} + +// ACL.DisallowNode - disallows a node access by ID in memory +func (acl ACL) Disallow(ID AclID) { + acl[ID] = NotAllowed +} + +// ACL.Remove - removes a node from a ACL +func (acl ACL) Remove(ID AclID) { + delete(acl, ID) +} + +// ACL.Update - updates a ACL in DB +func (acl ACL) Save(networkID ContainerID, ID AclID) (ACL, error) { + return upsertACL(networkID, ID, acl) +} + +// ACL.IsNodeAllowed - sees if ID is allowed in referring ACL +func (acl ACL) IsNodeAllowed(ID AclID) bool { + return acl[ID] == Allowed +} + +// ACLContainer.UpdateNodeACL - saves the state of a ACL in the ACLContainer in memory +func (aclContainer ACLContainer) UpdateNodeACL(ID AclID, acl ACL) ACLContainer { + aclContainer[ID] = acl + return aclContainer +} + +// ACLContainer.RemoveNodeACL - removes the state of a ACL in the ACLContainer in memory +func (aclContainer ACLContainer) RemoveNodeACL(ID AclID) ACLContainer { + delete(aclContainer, ID) + return aclContainer +} + +// ACLContainer.ChangeNodesAccess - changes the relationship between two nodes in memory +func (networkACL ACLContainer) ChangeNodesAccess(ID1, ID2 AclID, value byte) { + networkACL[ID1][ID2] = value + networkACL[ID2][ID1] = value +} + +// ACLContainer.Save - saves the state of a ACLContainer to the db +func (aclContainer ACLContainer) Save(networkID ContainerID) (ACLContainer, error) { + return upsertACLContainer(networkID, aclContainer) +} + +// == private == + +// upsertACL - applies a ACL to the db, overwrites or creates +func upsertACL(networkID ContainerID, ID AclID, acl ACL) (ACL, error) { + currentNetACL, err := FetchACLContainer(networkID) + if err != nil { + return acl, err + } + currentNetACL[ID] = acl + _, err = upsertACLContainer(networkID, currentNetACL) + return acl, err +} + +// upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the network name +// if nil, create it +func upsertACLContainer(networkID ContainerID, aclContainer ACLContainer) (ACLContainer, error) { + if aclContainer == nil { + aclContainer = make(ACLContainer) + } + return aclContainer, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME) +} + +func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson { + data, err := json.Marshal(networkACL) + if err != nil { + return "" + } + return ACLJson(data) +} diff --git a/logic/acls/node-acls/modify.go b/logic/acls/node-acls/modify.go index 345df599..e9a7cdea 100644 --- a/logic/acls/node-acls/modify.go +++ b/logic/acls/node-acls/modify.go @@ -1,133 +1,48 @@ package nodeacls import ( - "encoding/json" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic/acls" ) // CreateNodeACL - inserts or updates a node ACL on given network -func CreateNodeACL(networkID NetworkID, nodeID NodeID, defaultVal byte) (NodeACL, error) { - if defaultVal != NotAllowed && defaultVal != Allowed { - defaultVal = NotAllowed +func CreateNodeACL(networkID NetworkID, nodeID NodeID, defaultVal byte) (acls.ACL, error) { + if defaultVal != acls.NotAllowed && defaultVal != acls.Allowed { + defaultVal = acls.NotAllowed } - var currentNetworkACL, err = FetchCurrentACL(networkID) + var currentNetworkACL, err = acls.FetchACLContainer(acls.ContainerID(networkID)) if err != nil { return nil, err } - var newNodeACL = make(NodeACL) + var newNodeACL = make(acls.ACL) for existingNodeID := range currentNetworkACL { - currentNetworkACL[existingNodeID][nodeID] = defaultVal // set the old nodes to default value for new node - newNodeACL[existingNodeID] = defaultVal // set the old nodes in new node ACL to default value + currentNetworkACL[existingNodeID][acls.AclID(nodeID)] = defaultVal // set the old nodes to default value for new node + newNodeACL[existingNodeID] = defaultVal // set the old nodes in new node ACL to default value } - currentNetworkACL[nodeID] = newNodeACL // append the new node's ACL - retNetworkACL, err := upsertNetworkACL(networkID, currentNetworkACL) // insert into db, return result + currentNetworkACL[acls.AclID(nodeID)] = newNodeACL // append the new node's ACL + retNetworkACL, err := currentNetworkACL.Save(acls.ContainerID(networkID)) // insert into db if err != nil { return nil, err } - return retNetworkACL[nodeID], nil -} - -// CreateNetworkACL - creates an empty ACL list in a given network -func CreateNetworkACL(networkID NetworkID) (NetworkACL, error) { - var networkACL = make(NetworkACL) - return networkACL, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(&networkACL)), database.NODE_ACLS_TABLE_NAME) + return retNetworkACL[acls.AclID(nodeID)], nil } // RemoveNodeACL - removes a specific Node's ACL, returns the NetworkACL and error -func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (NetworkACL, error) { - var currentNeworkACL, err = FetchCurrentACL(networkID) +func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error) { + var currentNeworkACL, err = acls.FetchACLContainer(acls.ContainerID(networkID)) if err != nil { return nil, err } for currentNodeID := range currentNeworkACL { - if currentNodeID != nodeID { - currentNeworkACL[currentNodeID].RemoveNode(nodeID) + if NodeID(currentNodeID) != nodeID { + currentNeworkACL[currentNodeID].Remove(acls.AclID(nodeID)) } } - delete(currentNeworkACL, nodeID) - return currentNeworkACL.Save(networkID) + delete(currentNeworkACL, acls.AclID(nodeID)) + return currentNeworkACL.Save(acls.ContainerID(networkID)) } // RemoveNetworkACL - just delete the network ACL func RemoveNetworkACL(networkID NetworkID) error { return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(networkID)) } - -// NodeACL.AllowNode - allows a node by ID in memory -func (nodeACL NodeACL) AllowNode(nodeID NodeID) { - nodeACL[nodeID] = Allowed -} - -// NodeACL.DisallowNode - disallows a node access by ID in memory -func (nodeACL NodeACL) DisallowNode(nodeID NodeID) { - nodeACL[nodeID] = NotAllowed -} - -// NodeACL.RemoveNode - removes a node from a NodeACL -func (nodeACL NodeACL) RemoveNode(nodeID NodeID) { - delete(nodeACL, nodeID) -} - -// NodeACL.Update - updates a nodeACL in DB -func (nodeACL NodeACL) Save(networkID NetworkID, nodeID NodeID) (NodeACL, error) { - return upsertNodeACL(networkID, nodeID, nodeACL) -} - -// NodeACL.IsNodeAllowed - sees if nodeID is allowed in referring NodeACL -func (nodeACL NodeACL) IsNodeAllowed(nodeID NodeID) bool { - return nodeACL[nodeID] == Allowed -} - -// NetworkACL.UpdateNodeACL - saves the state of a NodeACL in the NetworkACL in memory -func (networkACL NetworkACL) UpdateNodeACL(nodeID NodeID, nodeACL NodeACL) NetworkACL { - networkACL[nodeID] = nodeACL - return networkACL -} - -// NetworkACL.RemoveNodeACL - removes the state of a NodeACL in the NetworkACL in memory -func (networkACL NetworkACL) RemoveNodeACL(nodeID NodeID) NetworkACL { - delete(networkACL, nodeID) - return networkACL -} - -// NetworkACL.ChangeNodesAccess - changes the relationship between two nodes in memory -func (networkACL NetworkACL) ChangeNodesAccess(nodeID1, nodeID2 NodeID, value byte) { - networkACL[nodeID1][nodeID2] = value - networkACL[nodeID2][nodeID1] = value -} - -// NetworkACL.Save - saves the state of a NetworkACL to the db -func (networkACL NetworkACL) Save(networkID NetworkID) (NetworkACL, error) { - return upsertNetworkACL(networkID, networkACL) -} - -// == private == - -// upsertNodeACL - applies a NodeACL to the db, overwrites or creates -func upsertNodeACL(networkID NetworkID, nodeID NodeID, nodeACL NodeACL) (NodeACL, error) { - currentNetACL, err := FetchCurrentACL(networkID) - if err != nil { - return nodeACL, err - } - currentNetACL[nodeID] = nodeACL - _, err = upsertNetworkACL(networkID, currentNetACL) - return nodeACL, err -} - -// upsertNetworkACL - Inserts or updates a network ACL given the json string of the ACL and the network name -// if nil, create it -func upsertNetworkACL(networkID NetworkID, networkACL NetworkACL) (NetworkACL, error) { - if networkACL == nil { - networkACL = make(NetworkACL) - } - return networkACL, database.Insert(string(networkID), string(convertNetworkACLtoACLJson(&networkACL)), database.NODE_ACLS_TABLE_NAME) -} - -func convertNetworkACLtoACLJson(networkACL *NetworkACL) ACLJson { - data, err := json.Marshal(networkACL) - if err != nil { - return "" - } - return ACLJson(data) -} diff --git a/logic/acls/node-acls/retrieve.go b/logic/acls/node-acls/retrieve.go index bb53ead9..7568516e 100644 --- a/logic/acls/node-acls/retrieve.go +++ b/logic/acls/node-acls/retrieve.go @@ -4,32 +4,32 @@ import ( "encoding/json" "fmt" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic/acls" ) // AreNodesAllowed - checks if nodes are allowed to communicate in their network ACL func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool { - var currentNetworkACL, err = FetchCurrentACL(networkID) + var currentNetworkACL, err = acls.FetchACLContainer(acls.ContainerID(networkID)) if err != nil { return false } - return currentNetworkACL[node1].IsNodeAllowed(node2) && currentNetworkACL[node2].IsNodeAllowed(node1) + return currentNetworkACL[acls.AclID(node1)].IsNodeAllowed(acls.AclID(node2)) && currentNetworkACL[acls.AclID(node2)].IsNodeAllowed(acls.AclID(node1)) } // FetchNodeACL - fetches a specific node's ACL in a given network -func FetchNodeACL(networkID NetworkID, nodeID NodeID) (NodeACL, error) { - currentNetACL, err := FetchCurrentACL(networkID) +func FetchNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACL, error) { + currentNetACL, err := acls.FetchACLContainer(acls.ContainerID(networkID)) if err != nil { return nil, err } - if currentNetACL[nodeID] == nil { + if currentNetACL[acls.AclID(nodeID)] == nil { return nil, fmt.Errorf("no node ACL present for node %s", nodeID) } - return currentNetACL[nodeID], nil + return currentNetACL[acls.AclID(nodeID)], nil } // FetchNodeACLJson - fetches a node's acl in given network except returns the json string -func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (ACLJson, error) { +func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (acls.ACLJson, error) { currentNodeACL, err := FetchNodeACL(networkID, nodeID) if err != nil { return "", err @@ -38,27 +38,5 @@ func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (ACLJson, error) { if err != nil { return "", err } - return ACLJson(jsonData), nil -} - -// FetchCurrentACL - fetches all current node rules in given network ACL -func FetchCurrentACL(networkID NetworkID) (NetworkACL, error) { - aclJson, err := FetchCurrentACLJson(NetworkID(networkID)) - if err != nil { - return nil, err - } - var currentNetworkACL NetworkACL - if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { - return nil, err - } - return currentNetworkACL, nil -} - -// FetchCurrentACLJson - fetch the current ACL of given network except in json string -func FetchCurrentACLJson(networkID NetworkID) (ACLJson, error) { - currentACLs, err := database.FetchRecord(database.NODE_ACLS_TABLE_NAME, string(networkID)) - if err != nil { - return ACLJson(""), err - } - return ACLJson(currentACLs), nil + return acls.ACLJson(jsonData), nil } diff --git a/logic/acls/node-acls/types.go b/logic/acls/node-acls/types.go index 289cb614..3cfc71e9 100644 --- a/logic/acls/node-acls/types.go +++ b/logic/acls/node-acls/types.go @@ -1,27 +1,12 @@ package nodeacls -var ( - // NotPresent - 0 - not present (default) - NotPresent = byte(0) - // NotAllowed - 1 - not allowed access - NotAllowed = byte(1) // 1 - not allowed - // Allowed - 2 - allowed access - Allowed = byte(2) +import ( + "github.com/gravitl/netmaker/logic/acls" ) type ( - // NodeID - the node id of a given node - NodeID string - - // NetworkID - the networkID of a given network - NetworkID string - - // NodeACL - the ACL of other nodes in a NetworkACL for a single unique node - NodeACL map[NodeID]byte - - // NetworkACL - the total list of all node's ACL in a given network - NetworkACL map[NodeID]NodeACL - - // ACLJson - the string representation in JSON of an ACL Node or Network - ACLJson string + // NodeID - node ID for ACLs + NodeID acls.AclID + // NetworkID - ACL container based on network ID for nodes + NetworkID acls.ContainerID ) diff --git a/logic/acls/types.go b/logic/acls/types.go new file mode 100644 index 00000000..57364508 --- /dev/null +++ b/logic/acls/types.go @@ -0,0 +1,27 @@ +package acls + +var ( + // NotPresent - 0 - not present (default) + NotPresent = byte(0) + // NotAllowed - 1 - not allowed access + NotAllowed = byte(1) // 1 - not allowed + // Allowed - 2 - allowed access + Allowed = byte(2) +) + +type ( + // AclID - the node id of a given node + AclID string + + // ACL - the ACL of other nodes in a NetworkACL for a single unique node + ACL map[AclID]byte + + // ACLJson - the string representation in JSON of an ACL Node or Network + ACLJson string + + // ContainerID - the networkID of a given network + ContainerID string + + // ACLContainer - the total list of all node's ACL in a given network + ACLContainer map[AclID]ACL +)