refactor validation for node creation

This commit is contained in:
Matthew R Kasun
2021-05-06 11:57:32 -04:00
parent 8ec8731eb5
commit 646f613b93
5 changed files with 375 additions and 255 deletions

View File

@@ -59,61 +59,15 @@ func GetPeersList(networkName string) ([]models.PeersResponse, error) {
} }
func ValidateNodeCreate(networkName string, node models.Node) error { func ValidateNodeCreate(networkName string, node models.Node) error {
v := validator.New() v := validator.New()
_ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool {
isIpv4 := functions.IsIpNet(node.Address)
empty := node.Address == ""
return (empty || isIpv4)
})
_ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool {
isIpv6 := functions.IsIpNet(node.Address6)
empty := node.Address6 == ""
return (empty || isIpv6)
})
_ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool {
//var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint)
isIp := functions.IsIpNet(node.Endpoint)
notEmptyCheck := node.Endpoint != ""
return (notEmptyCheck && isIp)
})
_ = v.RegisterValidation("localaddress_check", func(fl validator.FieldLevel) bool {
//var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint)
isIp := functions.IsIpNet(node.LocalAddress)
empty := node.LocalAddress == ""
return (empty || isIp)
})
_ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool { _ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool {
var isFieldUnique bool = functions.IsFieldUnique(networkName, "macaddress", node.MacAddress) var isFieldUnique bool = functions.IsFieldUnique(networkName, "macaddress", node.MacAddress)
return isFieldUnique return isFieldUnique
}) })
_ = v.RegisterValidation("macaddress_valid", func(fl validator.FieldLevel) bool {
_, err := net.ParseMAC(node.MacAddress)
return err == nil
})
_ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool {
isvalid := functions.NameInNodeCharSet(node.Name)
return isvalid
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := node.GetNetwork() _, err := node.GetNetwork()
return err == nil return err == nil
}) })
_ = v.RegisterValidation("pubkey_check", func(fl validator.FieldLevel) bool {
notEmptyCheck := node.PublicKey != ""
isBase64 := functions.IsBase64(node.PublicKey)
return (notEmptyCheck && isBase64)
})
_ = v.RegisterValidation("password_check", func(fl validator.FieldLevel) bool {
notEmptyCheck := node.Password != ""
goodLength := len(node.Password) > 5
return (notEmptyCheck && goodLength)
})
err := v.Struct(node) err := v.Struct(node)
if err != nil { if err != nil {
@@ -124,7 +78,7 @@ func ValidateNodeCreate(networkName string, node models.Node) error {
return err return err
} }
func ValidateNodeUpdate(networkName string, node models.Node) error { func ValidateNodeUpdate(networkName string, node models.NodeUpdate) error {
v := validator.New() v := validator.New()
_ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool { _ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool {
@@ -188,7 +142,7 @@ func ValidateNodeUpdate(networkName string, node models.Node) error {
return err return err
} }
func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) { func UpdateNode(nodechange models.NodeUpdate, node models.Node) (models.Node, error) {
//Question: Is there a better way of doing this than a bunch of "if" statements? probably... //Question: Is there a better way of doing this than a bunch of "if" statements? probably...
//Eventually, lets have a better way to check if any of the fields are filled out... //Eventually, lets have a better way to check if any of the fields are filled out...
queryMac := node.MacAddress queryMac := node.MacAddress

View File

@@ -13,6 +13,12 @@ type NodeValidationTC struct {
errorMessage string errorMessage string
} }
type NodeValidationUpdateTC struct {
testname string
node models.NodeUpdate
errorMessage string
}
func TestCreateNode(t *testing.T) { func TestCreateNode(t *testing.T) {
} }
func TestDeleteNode(t *testing.T) { func TestDeleteNode(t *testing.T) {
@@ -43,28 +49,28 @@ func TestValidateNodeCreate(t *testing.T) {
node: models.Node{ node: models.Node{
Address: "256.0.0.1", Address: "256.0.0.1",
}, },
errorMessage: "Field validation for 'Address' failed on the 'address_check' tag", errorMessage: "Field validation for 'Address' failed on the 'ipv4' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "BadAddress6", testname: "BadAddress6",
node: models.Node{ node: models.Node{
Address6: "2607::abcd:efgh::1", Address6: "2607::abcd:efgh::1",
}, },
errorMessage: "Field validation for 'Address6' failed on the 'address6_check' tag", errorMessage: "Field validation for 'Address6' failed on the 'ipv6' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "BadLocalAddress", testname: "BadLocalAddress",
node: models.Node{ node: models.Node{
LocalAddress: "10.0.200.300", LocalAddress: "10.0.200.300",
}, },
errorMessage: "Field validation for 'LocalAddress' failed on the 'localaddress_check' tag", errorMessage: "Field validation for 'LocalAddress' failed on the 'ip' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "InvalidName", testname: "InvalidName",
node: models.Node{ node: models.Node{
Name: "mynode*", Name: "mynode*",
}, },
errorMessage: "Field validation for 'Name' failed on the 'name_valid' tag", errorMessage: "Field validation for 'Name' failed on the 'alphanum' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "NameTooLong", testname: "NameTooLong",
@@ -88,18 +94,32 @@ func TestValidateNodeCreate(t *testing.T) {
errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag", errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "PublicKeyInvalid", testname: "PublicKeyEmpty",
node: models.Node{ node: models.Node{
PublicKey: "", PublicKey: "",
}, },
errorMessage: "Field validation for 'PublicKey' failed on the 'pubkey_check' tag", errorMessage: "Field validation for 'PublicKey' failed on the 'required' tag",
},
NodeValidationTC{
testname: "PublicKeyInvalid",
node: models.Node{
PublicKey: "junk%key",
},
errorMessage: "Field validation for 'PublicKey' failed on the 'base64' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "EndpointInvalid", testname: "EndpointInvalid",
node: models.Node{ node: models.Node{
Endpoint: "10.2.0.300", Endpoint: "10.2.0.300",
}, },
errorMessage: "Field validation for 'Endpoint' failed on the 'endpoint_check' tag", errorMessage: "Field validation for 'Endpoint' failed on the 'ip' tag",
},
NodeValidationTC{
testname: "EndpointEmpty",
node: models.Node{
Endpoint: "",
},
errorMessage: "Field validation for 'Endpoint' failed on the 'required' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "PersistentKeepaliveMax", testname: "PersistentKeepaliveMax",
@@ -113,7 +133,7 @@ func TestValidateNodeCreate(t *testing.T) {
node: models.Node{ node: models.Node{
MacAddress: "01:02:03:04:05", MacAddress: "01:02:03:04:05",
}, },
errorMessage: "Field validation for 'MacAddress' failed on the 'macaddress_valid' tag", errorMessage: "Field validation for 'MacAddress' failed on the 'mac' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "MacAddressMissing", testname: "MacAddressMissing",
@@ -127,14 +147,14 @@ func TestValidateNodeCreate(t *testing.T) {
node: models.Node{ node: models.Node{
Password: "", Password: "",
}, },
errorMessage: "Field validation for 'Password' failed on the 'password_check' tag", errorMessage: "Field validation for 'Password' failed on the 'required' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "ShortPassword", testname: "ShortPassword",
node: models.Node{ node: models.Node{
Password: "1234", Password: "1234",
}, },
errorMessage: "Field validation for 'Password' failed on the 'password_check' tag", errorMessage: "Field validation for 'Password' failed on the 'min' tag",
}, },
NodeValidationTC{ NodeValidationTC{
testname: "NoNetwork", testname: "NoNetwork",
@@ -170,18 +190,119 @@ func TestValidateNodeCreate(t *testing.T) {
} }
func TestValidateNodeUpdate(t *testing.T) { func TestValidateNodeUpdate(t *testing.T) {
//cases //cases
t.Run("BlankAddress", func(t *testing.T) { cases := []NodeValidationUpdateTC{
}) NodeValidationUpdateTC{
t.Run("BlankAddress6", func(t *testing.T) { testname: "BadAddress",
}) node: models.NodeUpdate{
t.Run("Blank", func(t *testing.T) { Address: "256.0.0.1",
}) },
errorMessage: "Field validation for 'Address' failed on the 'address_check' tag",
},
NodeValidationUpdateTC{
testname: "BadAddress6",
node: models.NodeUpdate{
Address6: "2607::abcd:efgh::1",
},
errorMessage: "Field validation for 'Address6' failed on the 'address6_check' tag",
},
NodeValidationUpdateTC{
testname: "BadLocalAddress",
node: models.NodeUpdate{
LocalAddress: "10.0.200.300",
},
errorMessage: "Field validation for 'LocalAddress' failed on the 'localaddress_check' tag",
},
NodeValidationUpdateTC{
testname: "InvalidName",
node: models.NodeUpdate{
Name: "mynode*",
},
errorMessage: "Field validation for 'Name' failed on the 'name_valid' tag",
},
NodeValidationUpdateTC{
testname: "NameTooLong",
node: models.NodeUpdate{
Name: "mynodexmynode",
},
errorMessage: "Field validation for 'Name' failed on the 'max' tag",
},
NodeValidationUpdateTC{
testname: "ListenPortMin",
node: models.NodeUpdate{
ListenPort: 1023,
},
errorMessage: "Field validation for 'ListenPort' failed on the 'min' tag",
},
NodeValidationUpdateTC{
testname: "ListenPortMax",
node: models.NodeUpdate{
ListenPort: 65536,
},
errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag",
},
NodeValidationUpdateTC{
testname: "PublicKeyInvalid",
node: models.NodeUpdate{
PublicKey: "",
},
errorMessage: "Field validation for 'PublicKey' failed on the 'pubkey_check' tag",
},
NodeValidationUpdateTC{
testname: "EndpointInvalid",
node: models.NodeUpdate{
Endpoint: "10.2.0.300",
},
errorMessage: "Field validation for 'Endpoint' failed on the 'endpoint_check' tag",
},
NodeValidationUpdateTC{
testname: "PersistentKeepaliveMax",
node: models.NodeUpdate{
PersistentKeepalive: 1001,
},
errorMessage: "Field validation for 'PersistentKeepalive' failed on the 'max' tag",
},
NodeValidationUpdateTC{
testname: "MacAddressInvalid",
node: models.NodeUpdate{
MacAddress: "01:02:03:04:05",
},
errorMessage: "Field validation for 'MacAddress' failed on the 'macaddress_valid' tag",
},
NodeValidationUpdateTC{
testname: "MacAddressMissing",
node: models.NodeUpdate{
MacAddress: "",
},
errorMessage: "Field validation for 'MacAddress' failed on the 'required' tag",
},
NodeValidationUpdateTC{
testname: "EmptyPassword",
node: models.NodeUpdate{
Password: "",
},
errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
},
NodeValidationUpdateTC{
testname: "ShortPassword",
node: models.NodeUpdate{
Password: "1234",
},
errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
},
NodeValidationUpdateTC{
testname: "NoNetwork",
node: models.NodeUpdate{
Network: "badnet",
},
errorMessage: "Field validation for 'Network' failed on the 'network_exists' tag",
},
}
for _, tc := range cases {
t.Run(tc.testname, func(t *testing.T) {
err := ValidateNodeUpdate("skynet", tc.node)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), tc.errorMessage)
})
}
// for _, tc := range cases {
// t.Run(tc.testname, func(t *testing.T) {
// err := ValidateNodeUpdate(tc.node)
// assert.NotNil(t, err)
// assert.Contains(t, err.Error(), tc.errorMessage)
// })
// }
} }

View File

@@ -1,12 +1,13 @@
package controller package controller
import ( import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"github.com/gravitl/netmaker/functions"
nodepb "github.com/gravitl/netmaker/grpc" nodepb "github.com/gravitl/netmaker/grpc"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/functions"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@@ -15,12 +16,12 @@ import (
type NodeServiceServer struct { type NodeServiceServer struct {
NodeDB *mongo.Collection NodeDB *mongo.Collection
nodepb.UnimplementedNodeServiceServer nodepb.UnimplementedNodeServiceServer
} }
func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeReq) (*nodepb.ReadNodeRes, error) { func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeReq) (*nodepb.ReadNodeRes, error) {
// convert string id (from proto) to mongoDB ObjectId // convert string id (from proto) to mongoDB ObjectId
macaddress := req.GetMacaddress() macaddress := req.GetMacaddress()
networkName := req.GetNetwork() networkName := req.GetNetwork()
network, _ := functions.GetParentNetwork(networkName) network, _ := functions.GetParentNetwork(networkName)
node, err := GetNode(macaddress, networkName) node, err := GetNode(macaddress, networkName)
@@ -30,31 +31,30 @@ func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeRe
} }
/* /*
if node == nil { if node == nil {
return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find node with Mac Address %s: %v", req.GetMacaddress(), err)) return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find node with Mac Address %s: %v", req.GetMacaddress(), err))
} }
*/ */
// Cast to ReadNodeRes type // Cast to ReadNodeRes type
response := &nodepb.ReadNodeRes{ response := &nodepb.ReadNodeRes{
Node: &nodepb.Node{ Node: &nodepb.Node{
Macaddress: node.MacAddress, Macaddress: node.MacAddress,
Name: node.Name, Name: node.Name,
Address: node.Address, Address: node.Address,
Endpoint: node.Endpoint, Endpoint: node.Endpoint,
Password: node.Password, Password: node.Password,
Nodenetwork: node.Network, Nodenetwork: node.Network,
Interface: node.Interface, Interface: node.Interface,
Localaddress: node.LocalAddress, Localaddress: node.LocalAddress,
Postdown: node.PostDown, Postdown: node.PostDown,
Postup: node.PostUp, Postup: node.PostUp,
Checkininterval: node.CheckInInterval, Checkininterval: node.CheckInInterval,
Ispending: node.IsPending, Ispending: node.IsPending,
Publickey: node.PublicKey, Publickey: node.PublicKey,
Listenport: node.ListenPort, Listenport: node.ListenPort,
Keepalive: node.PersistentKeepalive, Keepalive: node.PersistentKeepalive,
Islocal: *network.IsLocal, Islocal: *network.IsLocal,
Localrange: network.LocalRange, Localrange: network.LocalRange,
}, },
} }
return response, nil return response, nil
@@ -67,54 +67,52 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.CreateNo
// Now we have to convert this into a NodeItem type to convert into BSON // Now we have to convert this into a NodeItem type to convert into BSON
node := models.Node{ node := models.Node{
// ID: primitive.NilObjectID, // ID: primitive.NilObjectID,
MacAddress: data.GetMacaddress(), MacAddress: data.GetMacaddress(),
LocalAddress: data.GetLocaladdress(), LocalAddress: data.GetLocaladdress(),
Name: data.GetName(), Name: data.GetName(),
Address: data.GetAddress(), Address: data.GetAddress(),
AccessKey: data.GetAccesskey(), AccessKey: data.GetAccesskey(),
Endpoint: data.GetEndpoint(), Endpoint: data.GetEndpoint(),
PersistentKeepalive: data.GetKeepalive(), PersistentKeepalive: data.GetKeepalive(),
Password: data.GetPassword(), Password: data.GetPassword(),
Interface: data.GetInterface(), Interface: data.GetInterface(),
Network: data.GetNodenetwork(), Network: data.GetNodenetwork(),
IsPending: data.GetIspending(), IsPending: data.GetIspending(),
PublicKey: data.GetPublickey(), PublicKey: data.GetPublickey(),
ListenPort: data.GetListenport(), ListenPort: data.GetListenport(),
} }
err := ValidateNodeCreate(node.Network, node) err := ValidateNodeCreate(node.Network, node)
if err != nil { if err != nil {
// return internal gRPC error to be handled later // return internal gRPC error to be handled later
return nil, err return nil, err
} }
//Check to see if key is valid //Check to see if key is valid
//TODO: Triple inefficient!!! This is the third call to the DB we make for networks //TODO: Triple inefficient!!! This is the third call to the DB we make for networks
validKey := functions.IsKeyValid(node.Network, node.AccessKey) validKey := functions.IsKeyValid(node.Network, node.AccessKey)
network, err := functions.GetParentNetwork(node.Network) network, err := functions.GetParentNetwork(node.Network)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find network: %v", err)) return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find network: %v", err))
} else { } else {
fmt.Println("Creating node in network " + network.NetID) fmt.Println("Creating node in network " + network.NetID)
fmt.Println("Network is local? " + strconv.FormatBool(*network.IsLocal)) fmt.Println("Network is local? " + strconv.FormatBool(*network.IsLocal))
fmt.Println("Range if local: " + network.LocalRange) fmt.Println("Range if local: " + network.LocalRange)
} }
if !validKey {
//Check to see if network will allow manual sign up
if !validKey { //may want to switch this up with the valid key check and avoid a DB call that way.
//Check to see if network will allow manual sign up if *network.AllowManualSignUp {
//may want to switch this up with the valid key check and avoid a DB call that way. node.IsPending = true
if *network.AllowManualSignUp { } else {
node.IsPending = true return nil, status.Errorf(
} else { codes.Internal,
return nil, status.Errorf(
codes.Internal,
fmt.Sprintf("Invalid key, and network does not allow no-key signups"), fmt.Sprintf("Invalid key, and network does not allow no-key signups"),
) )
} }
} }
node, err = CreateNode(node, node.Network) node, err = CreateNode(node, node.Network)
@@ -128,118 +126,114 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.CreateNo
// return the node in a CreateNodeRes type // return the node in a CreateNodeRes type
response := &nodepb.CreateNodeRes{ response := &nodepb.CreateNodeRes{
Node: &nodepb.Node{ Node: &nodepb.Node{
Macaddress: node.MacAddress, Macaddress: node.MacAddress,
Localaddress: node.LocalAddress, Localaddress: node.LocalAddress,
Name: node.Name, Name: node.Name,
Address: node.Address, Address: node.Address,
Endpoint: node.Endpoint, Endpoint: node.Endpoint,
Password: node.Password, Password: node.Password,
Interface: node.Interface, Interface: node.Interface,
Nodenetwork: node.Network, Nodenetwork: node.Network,
Ispending: node.IsPending, Ispending: node.IsPending,
Publickey: node.PublicKey, Publickey: node.PublicKey,
Listenport: node.ListenPort, Listenport: node.ListenPort,
Keepalive: node.PersistentKeepalive, Keepalive: node.PersistentKeepalive,
Islocal: *network.IsLocal, Islocal: *network.IsLocal,
Localrange: network.LocalRange, Localrange: network.LocalRange,
}, },
} }
err = SetNetworkNodesLastModified(node.Network) err = SetNetworkNodesLastModified(node.Network)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err)) return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
} }
return response, nil return response, nil
} }
func (s *NodeServiceServer) CheckIn(ctx context.Context, req *nodepb.CheckInReq) (*nodepb.CheckInRes, error) { func (s *NodeServiceServer) CheckIn(ctx context.Context, req *nodepb.CheckInReq) (*nodepb.CheckInRes, error) {
// Get the protobuf node type from the protobuf request type // Get the protobuf node type from the protobuf request type
// Essentially doing req.Node to access the struct with a nil check // Essentially doing req.Node to access the struct with a nil check
data := req.GetNode() data := req.GetNode()
//postchanges := req.GetPostchanges() //postchanges := req.GetPostchanges()
// Now we have to convert this into a NodeItem type to convert into BSON // Now we have to convert this into a NodeItem type to convert into BSON
node := models.Node{ node := models.Node{
// ID: primitive.NilObjectID, // ID: primitive.NilObjectID,
MacAddress: data.GetMacaddress(), MacAddress: data.GetMacaddress(),
Address: data.GetAddress(), Address: data.GetAddress(),
Endpoint: data.GetEndpoint(), Endpoint: data.GetEndpoint(),
Network: data.GetNodenetwork(), Network: data.GetNodenetwork(),
Password: data.GetPassword(), Password: data.GetPassword(),
LocalAddress: data.GetLocaladdress(), LocalAddress: data.GetLocaladdress(),
ListenPort: data.GetListenport(), ListenPort: data.GetListenport(),
PersistentKeepalive: data.GetKeepalive(), PersistentKeepalive: data.GetKeepalive(),
PublicKey: data.GetPublickey(), PublicKey: data.GetPublickey(),
} }
checkinresponse, err := NodeCheckIn(node, node.Network) checkinresponse, err := NodeCheckIn(node, node.Network)
if err != nil { if err != nil {
// return internal gRPC error to be handled later // return internal gRPC error to be handled later
if checkinresponse == (models.CheckInResponse{}) || !checkinresponse.IsPending { if checkinresponse == (models.CheckInResponse{}) || !checkinresponse.IsPending {
return nil, status.Errorf( return nil, status.Errorf(
codes.Internal, codes.Internal,
fmt.Sprintf("Internal error: %v", err), fmt.Sprintf("Internal error: %v", err),
) )
} }
} }
// return the node in a CreateNodeRes type // return the node in a CreateNodeRes type
response := &nodepb.CheckInRes{ response := &nodepb.CheckInRes{
Checkinresponse: &nodepb.CheckInResponse{ Checkinresponse: &nodepb.CheckInResponse{
Success: checkinresponse.Success, Success: checkinresponse.Success,
Needpeerupdate: checkinresponse.NeedPeerUpdate, Needpeerupdate: checkinresponse.NeedPeerUpdate,
Needdelete: checkinresponse.NeedDelete, Needdelete: checkinresponse.NeedDelete,
Needconfigupdate: checkinresponse.NeedConfigUpdate, Needconfigupdate: checkinresponse.NeedConfigUpdate,
Needkeyupdate: checkinresponse.NeedKeyUpdate, Needkeyupdate: checkinresponse.NeedKeyUpdate,
Nodemessage: checkinresponse.NodeMessage, Nodemessage: checkinresponse.NodeMessage,
Ispending: checkinresponse.IsPending, Ispending: checkinresponse.IsPending,
}, },
} }
return response, nil return response, nil
} }
func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNodeReq) (*nodepb.UpdateNodeRes, error) { func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNodeReq) (*nodepb.UpdateNodeRes, error) {
// Get the node data from the request // Get the node data from the request
data := req.GetNode() data := req.GetNode()
// Now we have to convert this into a NodeItem type to convert into BSON // Now we have to convert this into a NodeItem type to convert into BSON
nodechange := models.Node{ nodechange := models.NodeUpdate{
// ID: primitive.NilObjectID, // ID: primitive.NilObjectID,
MacAddress: data.GetMacaddress(), MacAddress: data.GetMacaddress(),
Name: data.GetName(), Name: data.GetName(),
Address: data.GetAddress(), Address: data.GetAddress(),
LocalAddress: data.GetLocaladdress(), LocalAddress: data.GetLocaladdress(),
Endpoint: data.GetEndpoint(), Endpoint: data.GetEndpoint(),
Password: data.GetPassword(), Password: data.GetPassword(),
PersistentKeepalive: data.GetKeepalive(), PersistentKeepalive: data.GetKeepalive(),
Network: data.GetNodenetwork(), Network: data.GetNodenetwork(),
Interface: data.GetInterface(), Interface: data.GetInterface(),
PostDown: data.GetPostdown(), PostDown: data.GetPostdown(),
PostUp: data.GetPostup(), PostUp: data.GetPostup(),
IsPending: data.GetIspending(), IsPending: data.GetIspending(),
PublicKey: data.GetPublickey(), PublicKey: data.GetPublickey(),
ListenPort: data.GetListenport(), ListenPort: data.GetListenport(),
} }
// Convert the Id string to a MongoDB ObjectId // Convert the Id string to a MongoDB ObjectId
macaddress := nodechange.MacAddress macaddress := nodechange.MacAddress
networkName := nodechange.Network networkName := nodechange.Network
network, _ := functions.GetParentNetwork(networkName) network, _ := functions.GetParentNetwork(networkName)
err := ValidateNodeUpdate(networkName, nodechange) err := ValidateNodeUpdate(networkName, nodechange)
if err != nil { if err != nil {
return nil, err return nil, err
}
node, err := functions.GetNodeByMacAddress(networkName, macaddress)
if err != nil {
return nil, status.Errorf(
codes.NotFound,
fmt.Sprintf("Could not find node with supplied Mac Address: %v", err),
)
} }
node, err := functions.GetNodeByMacAddress(networkName, macaddress)
if err != nil {
return nil, status.Errorf(
codes.NotFound,
fmt.Sprintf("Could not find node with supplied Mac Address: %v", err),
)
}
newnode, err := UpdateNode(nodechange, node) newnode, err := UpdateNode(nodechange, node)
@@ -251,23 +245,22 @@ func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNo
} }
return &nodepb.UpdateNodeRes{ return &nodepb.UpdateNodeRes{
Node: &nodepb.Node{ Node: &nodepb.Node{
Macaddress: newnode.MacAddress, Macaddress: newnode.MacAddress,
Localaddress: newnode.LocalAddress, Localaddress: newnode.LocalAddress,
Name: newnode.Name, Name: newnode.Name,
Address: newnode.Address, Address: newnode.Address,
Endpoint: newnode.Endpoint, Endpoint: newnode.Endpoint,
Password: newnode.Password, Password: newnode.Password,
Interface: newnode.Interface, Interface: newnode.Interface,
Postdown: newnode.PostDown, Postdown: newnode.PostDown,
Postup: newnode.PostUp, Postup: newnode.PostUp,
Nodenetwork: newnode.Network, Nodenetwork: newnode.Network,
Ispending: newnode.IsPending, Ispending: newnode.IsPending,
Publickey: newnode.PublicKey, Publickey: newnode.PublicKey,
Listenport: newnode.ListenPort, Listenport: newnode.ListenPort,
Keepalive: newnode.PersistentKeepalive, Keepalive: newnode.PersistentKeepalive,
Islocal: *network.IsLocal, Islocal: *network.IsLocal,
Localrange: network.LocalRange, Localrange: network.LocalRange,
}, },
}, nil }, nil
} }
@@ -287,12 +280,11 @@ func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.DeleteNo
fmt.Println("updating network last modified of " + req.GetNetworkName()) fmt.Println("updating network last modified of " + req.GetNetworkName())
err = SetNetworkNodesLastModified(req.GetNetworkName()) err = SetNetworkNodesLastModified(req.GetNetworkName())
if err != nil { if err != nil {
fmt.Println("Error updating Network") fmt.Println("Error updating Network")
fmt.Println(err) fmt.Println(err)
return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err)) return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
} }
return &nodepb.DeleteNodeRes{ return &nodepb.DeleteNodeRes{
Success: true, Success: true,
@@ -310,34 +302,32 @@ func (s *NodeServiceServer) GetPeers(req *nodepb.GetPeersReq, stream nodepb.Node
return status.Errorf(codes.Internal, fmt.Sprintf("Unknown internal error: %v", err)) return status.Errorf(codes.Internal, fmt.Sprintf("Unknown internal error: %v", err))
} }
// cursor.Next() returns a boolean, if false there are no more items and loop will break // cursor.Next() returns a boolean, if false there are no more items and loop will break
for i := 0; i < len(peers); i++ { for i := 0; i < len(peers); i++ {
// If no error is found send node over stream // If no error is found send node over stream
stream.Send(&nodepb.GetPeersRes{ stream.Send(&nodepb.GetPeersRes{
Peers: &nodepb.PeersResponse{ Peers: &nodepb.PeersResponse{
Address: peers[i].Address, Address: peers[i].Address,
Endpoint: peers[i].Endpoint, Endpoint: peers[i].Endpoint,
Gatewayrange: peers[i].GatewayRange, Gatewayrange: peers[i].GatewayRange,
Isgateway: peers[i].IsGateway, Isgateway: peers[i].IsGateway,
Publickey: peers[i].PublicKey, Publickey: peers[i].PublicKey,
Keepalive: peers[i].KeepAlive, Keepalive: peers[i].KeepAlive,
Listenport: peers[i].ListenPort, Listenport: peers[i].ListenPort,
Localaddress: peers[i].LocalAddress, Localaddress: peers[i].LocalAddress,
}, },
}) })
} }
node, err := functions.GetNodeByMacAddress(req.GetNetwork(), req.GetMacaddress()) node, err := functions.GetNodeByMacAddress(req.GetNetwork(), req.GetMacaddress())
if err != nil { if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("Could not get node: %v", err)) return status.Errorf(codes.Internal, fmt.Sprintf("Could not get node: %v", err))
} }
err = TimestampNode(node, false, true, false) err = TimestampNode(node, false, true, false)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("Internal error occurred: %v", err)) return status.Errorf(codes.Internal, fmt.Sprintf("Internal error occurred: %v", err))
} }
return nil return nil
} }

View File

@@ -689,7 +689,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
return return
} }
var nodechange models.Node var nodechange models.NodeUpdate
// we decode our body request params // we decode our body request params
_ = json.NewDecoder(r.Body).Decode(&nodechange) _ = json.NewDecoder(r.Body).Decode(&nodechange)

View File

@@ -18,6 +18,38 @@ var seededRand *rand.Rand = rand.New(
//node struct //node struct
type Node struct { type Node struct {
ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"`
Address string `json:"address" bson:"address" validate:"omitempty,ipv4"`
Address6 string `json:"address6" bson:"address6" validate:"omitempty,ipv6"`
LocalAddress string `json:"localaddress" bson:"localaddress" validate:"omitempty,ip"`
Name string `json:"name" bson:"name" validate:"omitempty,alphanum,max=12"`
ListenPort int32 `json:"listenport" bson:"listenport" validate:"omitempty,numeric,min=1024,max=65535"`
PublicKey string `json:"publickey" bson:"publickey" validate:"required,base64"`
Endpoint string `json:"endpoint" bson:"endpoint" validate:"required,ip"`
PostUp string `json:"postup" bson:"postup"`
PostDown string `json:"postdown" bson:"postdown"`
AllowedIPs string `json:"allowedips" bson:"allowedips"`
PersistentKeepalive int32 `json:"persistentkeepalive" bson:"persistentkeepalive" validate:"omitempty,numeric,max=1000"`
SaveConfig *bool `json:"saveconfig" bson:"saveconfig"`
AccessKey string `json:"accesskey" bson:"accesskey"`
Interface string `json:"interface" bson:"interface"`
LastModified int64 `json:"lastmodified" bson:"lastmodified"`
KeyUpdateTimeStamp int64 `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"`
ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime"`
LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate"`
LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin"`
MacAddress string `json:"macaddress" bson:"macaddress" validate:"required,mac,macaddress_unique"`
CheckInInterval int32 `json:"checkininterval" bson:"checkininterval"`
Password string `json:"password" bson:"password" validate:"required,min=6"`
Network string `json:"network" bson:"network" validate:"network_exists"`
IsPending bool `json:"ispending" bson:"ispending"`
IsGateway bool `json:"isgateway" bson:"isgateway"`
GatewayRange string `json:"gatewayrange" bson:"gatewayrange"`
PostChanges string `json:"postchanges" bson:"postchanges"`
}
//node update struct --- only validations are different
type NodeUpdate struct {
ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"` ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"`
Address string `json:"address" bson:"address" validate:"address_check"` Address string `json:"address" bson:"address" validate:"address_check"`
Address6 string `json:"address6" bson:"address6" validate:"address6_check"` Address6 string `json:"address6" bson:"address6" validate:"address6_check"`
@@ -48,6 +80,29 @@ type Node struct {
PostChanges string `json:"postchanges" bson:"postchanges"` PostChanges string `json:"postchanges" bson:"postchanges"`
} }
//Duplicated function for NodeUpdates
func (node *NodeUpdate) GetNetwork() (Network, error) {
var network Network
collection := mongoconn.NetworkDB
//collection := mongoconn.Client.Database("netmaker").Collection("networks")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
filter := bson.M{"netid": node.Network}
err := collection.FindOne(ctx, filter).Decode(&network)
defer cancel()
if err != nil {
//log.Fatal(err)
return network, err
}
return network, err
}
//TODO: Contains a fatal error return. Need to change //TODO: Contains a fatal error return. Need to change
//Used in contexts where it's not the Parent network. //Used in contexts where it's not the Parent network.
func (node *Node) GetNetwork() (Network, error) { func (node *Node) GetNetwork() (Network, error) {