refactored more

This commit is contained in:
0xdcarns
2022-01-10 18:32:49 -05:00
parent a4a8c368d4
commit 545f45d86d
10 changed files with 78 additions and 53 deletions

View File

@@ -566,7 +566,7 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "badrequest")) returnErrorResponse(w, r, formatError(err, "badrequest"))
return return
} }
err = logic.DeleteNodeByMacAddress(&node, false) err = logic.DeleteNodeByID(&node, false)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return

View File

@@ -139,12 +139,16 @@ func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.Object)
// NodeServiceServer.DeleteNode - deletes a node and responds over gRPC // NodeServiceServer.DeleteNode - deletes a node and responds over gRPC
func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.Object) (*nodepb.Object, error) { func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.Object) (*nodepb.Object, error) {
nodeID := req.GetData() nodeID := req.GetData()
var nodeInfo = strings.Split(nodeID, "###") var nodeInfo = make([]string, 2)
if strings.Contains(nodeID, "###") {
nodeInfo = strings.Split(nodeID, "###")
if len(nodeInfo) != 2 { if len(nodeInfo) != 2 {
return nil, errors.New("node not found") return nil, errors.New("node not found")
} }
var node, err = logic.GetNode(nodeInfo[0], nodeInfo[1]) }
err = logic.DeleteNodeByMacAddress(&node, true)
var node, err = logic.GetNodeByIDorMacAddress(nodeID, nodeInfo[0], nodeInfo[1])
err = logic.DeleteNodeByID(&node, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -157,10 +161,16 @@ func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.Object)
// NodeServiceServer.GetPeers - fetches peers over gRPC // NodeServiceServer.GetPeers - fetches peers over gRPC
func (s *NodeServiceServer) GetPeers(ctx context.Context, req *nodepb.Object) (*nodepb.Object, error) { func (s *NodeServiceServer) GetPeers(ctx context.Context, req *nodepb.Object) (*nodepb.Object, error) {
macAndNetwork := strings.Split(req.Data, "###") nodeID := req.GetData()
if len(macAndNetwork) == 2 { var nodeInfo = make([]string, 2)
// TODO: Make constant and new variable for isServer if strings.Contains(nodeID, "###") {
node, err := logic.GetNode(macAndNetwork[0], macAndNetwork[1]) nodeInfo = strings.Split(nodeID, "###")
if len(nodeInfo) != 2 {
return nil, errors.New("could not fetch peers, invalid node id")
}
}
node, err := logic.GetNodeByIDorMacAddress(nodeID, nodeInfo[0], nodeInfo[1])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -172,7 +182,7 @@ func (s *NodeServiceServer) GetPeers(ctx context.Context, req *nodepb.Object) (*
if node.IsRelayed == "yes" { if node.IsRelayed == "yes" {
relayedNode = node.Address relayedNode = node.Address
} }
peers, err := logic.GetPeersList(macAndNetwork[1], excludeIsRelayed, relayedNode) peers, err := logic.GetPeersList(node.Network, excludeIsRelayed, relayedNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -183,11 +193,6 @@ func (s *NodeServiceServer) GetPeers(ctx context.Context, req *nodepb.Object) (*
Data: string(peersData), Data: string(peersData),
Type: nodepb.NODE_TYPE, Type: nodepb.NODE_TYPE,
}, err }, err
}
return &nodepb.Object{
Data: "",
Type: nodepb.NODE_TYPE,
}, errors.New("could not fetch peers, invalid node id")
} }
// NodeServiceServer.GetExtPeers - returns ext peers for a gateway node // NodeServiceServer.GetExtPeers - returns ext peers for a gateway node
@@ -199,7 +204,21 @@ func (s *NodeServiceServer) GetExtPeers(ctx context.Context, req *nodepb.Object)
if len(macAndNetwork) != 2 { if len(macAndNetwork) != 2 {
return nil, errors.New("did not receive valid node id when fetching ext peers") return nil, errors.New("did not receive valid node id when fetching ext peers")
} }
peers, err := logic.GetExtPeersList(macAndNetwork[0], macAndNetwork[1]) nodeID := req.GetData()
var nodeInfo = make([]string, 2)
if strings.Contains(nodeID, "###") {
nodeInfo = strings.Split(nodeID, "###")
if len(nodeInfo) != 2 {
return nil, errors.New("could not fetch peers, invalid node id")
}
}
node, err := logic.GetNodeByIDorMacAddress(nodeID, nodeInfo[0], nodeInfo[1])
if err != nil {
return nil, err
}
peers, err := logic.GetExtPeersList(&node)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -142,7 +142,7 @@ func TestValidateEgressGateway(t *testing.T) {
func deleteAllNodes() { func deleteAllNodes() {
nodes, _ := logic.GetAllNodes() nodes, _ := logic.GetAllNodes()
for _, node := range nodes { for _, node := range nodes {
logic.DeleteNodeByMacAddress(&node, true) logic.DeleteNodeByID(&node, true)
} }
} }

View File

@@ -11,7 +11,7 @@ import (
) )
// GetExtPeersList - gets the ext peers lists // GetExtPeersList - gets the ext peers lists
func GetExtPeersList(macaddress string, networkName string) ([]models.ExtPeersResponse, error) { func GetExtPeersList(node *models.Node) ([]models.ExtPeersResponse, error) {
var peers []models.ExtPeersResponse var peers []models.ExtPeersResponse
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
@@ -33,7 +33,7 @@ func GetExtPeersList(macaddress string, networkName string) ([]models.ExtPeersRe
logger.Log(2, "failed to unmarshal ext client") logger.Log(2, "failed to unmarshal ext client")
continue continue
} }
if extClient.Network == networkName && extClient.IngressGatewayID == macaddress { if extClient.Network == node.Network && extClient.IngressGatewayID == node.ID {
peers = append(peers, peer) peers = append(peers, peer)
} }
} }

View File

@@ -47,7 +47,7 @@ func DeleteNetwork(network string) error {
servers, err := GetSortedNetworkServerNodes(network) servers, err := GetSortedNetworkServerNodes(network)
if err == nil { if err == nil {
for _, s := range servers { for _, s := range servers {
if err = DeleteNodeByMacAddress(&s, true); err != nil { if err = DeleteNodeByID(&s, true); err != nil {
logger.Log(2, "could not removed server", s.Name, "before deleting network", network) logger.Log(2, "could not removed server", s.Name, "before deleting network", network)
} else { } else {
logger.Log(2, "removed server", s.Name, "before deleting network", network) logger.Log(2, "removed server", s.Name, "before deleting network", network)

View File

@@ -395,6 +395,7 @@ func GetNodeByIDorMacAddress(uuid string, macaddress string, network string) (mo
return models.Node{}, err return models.Node{}, err
} }
err = CreateNode(&node) err = CreateNode(&node)
logger.Log(2, "rewriting legacy node data; node now has id,", node.ID)
if err != nil { if err != nil {
return models.Node{}, err return models.Node{}, err
} }

View File

@@ -39,6 +39,7 @@ func ServerJoin(networkSettings *models.Network, serverID string) error {
IsStatic: "yes", IsStatic: "yes",
Name: models.NODE_SERVER_NAME, Name: models.NODE_SERVER_NAME,
MacAddress: serverID, MacAddress: serverID,
ID: serverID,
UDPHolePunch: "no", UDPHolePunch: "no",
IsLocal: networkSettings.IsLocal, IsLocal: networkSettings.IsLocal,
LocalRange: networkSettings.LocalRange, LocalRange: networkSettings.LocalRange,
@@ -135,9 +136,9 @@ func ServerJoin(networkSettings *models.Network, serverID string) error {
} }
// ServerCheckin - runs pulls and pushes for server // ServerCheckin - runs pulls and pushes for server
func ServerCheckin(mac string, network string) error { func ServerCheckin(serverID string, mac string, network string) error {
var serverNode = &models.Node{} var serverNode = &models.Node{}
var currentNode, err = GetNode(mac, network) var currentNode, err = GetNodeByIDorMacAddress(serverID, mac, network)
if err != nil { if err != nil {
return err return err
} }
@@ -145,7 +146,7 @@ func ServerCheckin(mac string, network string) error {
err = ServerPull(serverNode, false) err = ServerPull(serverNode, false)
if isDeleteError(err) { if isDeleteError(err) {
return ServerLeave(mac, network) return ServerLeave(currentNode.ID)
} else if err != nil { } else if err != nil {
return err return err
} }
@@ -208,13 +209,13 @@ func ServerPush(serverNode *models.Node) error {
} }
// ServerLeave - removes a server node // ServerLeave - removes a server node
func ServerLeave(mac string, network string) error { func ServerLeave(serverID string) error {
var serverNode, err = GetNode(mac, network) var serverNode, err = GetNodeByID(serverID)
if err != nil { if err != nil {
return err return err
} }
return DeleteNodeByMacAddress(&serverNode, true) return DeleteNodeByID(&serverNode, true)
} }
/** /**
@@ -229,7 +230,7 @@ func GetServerPeers(serverNode *models.Node) ([]wgtypes.PeerConfig, bool, []stri
var peers []wgtypes.PeerConfig var peers []wgtypes.PeerConfig
var nodes []models.Node // fill above fields from server or client var nodes []models.Node // fill above fields from server or client
var nodecfg, err = GetNode(serverNode.MacAddress, serverNode.Network) var nodecfg, err = GetNodeByIDorMacAddress(serverNode.ID, serverNode.MacAddress, serverNode.Network)
if err != nil { if err != nil {
return nil, hasGateway, gateways, err return nil, hasGateway, gateways, err
} }
@@ -348,7 +349,7 @@ func GetServerExtPeers(serverNode *models.Node) ([]wgtypes.PeerConfig, error) {
var err error var err error
var tempPeers []models.ExtPeersResponse var tempPeers []models.ExtPeersResponse
tempPeers, err = GetExtPeersList(serverNode.MacAddress, serverNode.Network) tempPeers, err = GetExtPeersList(serverNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -419,7 +420,7 @@ func checkNodeActions(node *models.Node) string {
} }
} }
if node.Action == models.NODE_DELETE { if node.Action == models.NODE_DELETE {
err := ServerLeave(node.MacAddress, node.Network) err := ServerLeave(node.ID)
if err != nil { if err != nil {
logger.Log(1, "error deleting locally:", err.Error()) logger.Log(1, "error deleting locally:", err.Error())
} }

View File

@@ -104,7 +104,10 @@ func CreateNode(node *models.Node) error {
if err != nil { if err != nil {
return err return err
} }
if node.IsServer != "yes" || (node.IsServer == "yes" && servercfg.GetNodeID() == "") {
node.ID = uuid.NewString() node.ID = uuid.NewString()
}
//Create a JWT for the node //Create a JWT for the node
tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network) tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network)

View File

@@ -23,9 +23,9 @@ const NODE_NOOP = "noop"
var seededRand *rand.Rand = rand.New( var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano())) rand.NewSource(time.Now().UnixNano()))
// node struct // Node - struct for node model
type Node struct { type Node struct {
ID string `json:"id,omitempty" bson:"id,omitempty"` ID string `json:"id,omitempty" bson:"id,omitempty" yaml:"id,omitempty" validate:"required,min=5"`
Address string `json:"address" bson:"address" yaml:"address" validate:"omitempty,ipv4"` Address string `json:"address" bson:"address" yaml:"address" validate:"omitempty,ipv4"`
Address6 string `json:"address6" bson:"address6" yaml:"address6" validate:"omitempty,ipv6"` Address6 string `json:"address6" bson:"address6" yaml:"address6" validate:"omitempty,ipv6"`
LocalAddress string `json:"localaddress" bson:"localaddress" yaml:"localaddress" validate:"omitempty,ip"` LocalAddress string `json:"localaddress" bson:"localaddress" yaml:"localaddress" validate:"omitempty,ip"`
@@ -46,7 +46,7 @@ type Node struct {
ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime" yaml:"expdatetime"` ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime" yaml:"expdatetime"`
LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate" yaml:"lastpeerupdate"` LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate" yaml:"lastpeerupdate"`
LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin" yaml:"lastcheckin"` LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin" yaml:"lastcheckin"`
MacAddress string `json:"macaddress" bson:"macaddress" yaml:"macaddress" validate:"required,min=5,macaddress_unique"` MacAddress string `json:"macaddress" bson:"macaddress" yaml:"macaddress"`
// checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL // checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL
CheckInInterval int32 `json:"checkininterval" bson:"checkininterval" yaml:"checkininterval"` CheckInInterval int32 `json:"checkininterval" bson:"checkininterval" yaml:"checkininterval"`
Password string `json:"password" bson:"password" yaml:"password" validate:"required,min=6"` Password string `json:"password" bson:"password" yaml:"password" validate:"required,min=6"`
@@ -72,6 +72,7 @@ type Node struct {
IPForwarding string `json:"ipforwarding" bson:"ipforwarding" yaml:"ipforwarding" validate:"checkyesorno"` IPForwarding string `json:"ipforwarding" bson:"ipforwarding" yaml:"ipforwarding" validate:"checkyesorno"`
OS string `json:"os" bson:"os" yaml:"os"` OS string `json:"os" bson:"os" yaml:"os"`
MTU int32 `json:"mtu" bson:"mtu" yaml:"mtu"` MTU int32 `json:"mtu" bson:"mtu" yaml:"mtu"`
Version string `json:"version" bson:"version" yaml:"version"`
} }
// NodesArray - used for node sorting // NodesArray - used for node sorting

View File

@@ -42,7 +42,7 @@ func FileExists(f string) bool {
// RemoveNetwork - removes a network locally on server // RemoveNetwork - removes a network locally on server
func RemoveNetwork(network string) (bool, error) { func RemoveNetwork(network string) (bool, error) {
err := logic.ServerLeave(servercfg.GetNodeID(), network) err := logic.ServerLeave(servercfg.GetNodeID())
return true, err return true, err
} }
@@ -70,7 +70,7 @@ func HandleContainedClient() error {
return err return err
} }
for _, serverNet := range servernets { for _, serverNet := range servernets {
err = logic.ServerCheckin(servercfg.GetNodeID(), serverNet.NetID) err = logic.ServerCheckin(servercfg.GetNodeID(), servercfg.GetNodeID(), serverNet.NetID)
if err != nil { if err != nil {
logger.Log(1, "error occurred during server checkin:", err.Error()) logger.Log(1, "error occurred during server checkin:", err.Error())
} else { } else {