added error on attempt to update NetID

This commit is contained in:
Matthew R Kasun
2021-04-15 14:48:10 -04:00
parent b4afb373bf
commit 66fb590be2
2 changed files with 388 additions and 387 deletions

View File

@@ -1,22 +1,23 @@
package controller package controller
import ( import (
"gopkg.in/go-playground/validator.v9"
"github.com/gravitl/netmaker/models"
"errors"
"encoding/base64"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/mongoconn"
"time"
"strings"
"fmt"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt"
"net/http" "net/http"
"strings"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/config"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mongoconn"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/gravitl/netmaker/config" "gopkg.in/go-playground/validator.v9"
) )
func networkHandlers(r *mux.Router) { func networkHandlers(r *mux.Router) {
@@ -79,6 +80,7 @@ func securityCheck(next http.Handler) http.HandlerFunc {
} }
} }
} }
//Consider a more secure way of setting master key //Consider a more secure way of setting master key
func authenticateMaster(tokenString string) bool { func authenticateMaster(tokenString string) bool {
if tokenString == config.Config.Server.MasterKey { if tokenString == config.Config.Server.MasterKey {
@@ -119,7 +121,9 @@ func validateNetwork(operation string, network models.Network) error {
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool { _ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
isFieldUnique := false isFieldUnique := false
inCharSet := false inCharSet := false
if operation == "update" { isFieldUnique = true } else{ if operation == "update" {
isFieldUnique = true
} else {
isFieldUnique, _ = functions.IsNetworkNameUnique(fl.Field().String()) isFieldUnique, _ = functions.IsNetworkNameUnique(fl.Field().String())
inCharSet = functions.NameInNetworkCharSet(fl.Field().String()) inCharSet = functions.NameInNetworkCharSet(fl.Field().String())
} }
@@ -161,7 +165,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -178,11 +182,10 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
network, err := functions.GetParentNetwork(params["networkname"]) network, err := functions.GetParentNetwork(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
network.KeyUpdateTimeStamp = time.Now().Unix() network.KeyUpdateTimeStamp = time.Now().Unix()
collection := mongoconn.Client.Database("netmaker").Collection("networks") collection := mongoconn.Client.Database("netmaker").Collection("networks")
@@ -214,7 +217,7 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -223,13 +226,12 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
} }
//Update a network //Update a network
func AlertNetwork(netid string) error{ func AlertNetwork(netid string) error {
collection := mongoconn.Client.Database("netmaker").Collection("networks") collection := mongoconn.Client.Database("netmaker").Collection("networks")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
filter := bson.M{"netid": netid} filter := bson.M{"netid": netid}
var network models.Network var network models.Network
network, err := functions.GetParentNetwork(netid) network, err := functions.GetParentNetwork(netid)
@@ -261,7 +263,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
network, err := functions.GetParentNetwork(params["networkname"]) network, err := functions.GetParentNetwork(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -280,15 +282,18 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
networkChange.NetID = network.NetID networkChange.NetID = network.NetID
} }
//err = validateNetwork("update", networkChange) //err = validateNetwork("update", networkChange)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
//NOTE: Network.NetID is intentionally NOT editable. It acts as a static ID for the network. //NOTE: Network.NetID is intentionally NOT editable. It acts as a static ID for the network.
//DisplayName can be changed instead, which is what shows on the front end //DisplayName can be changed instead, which is what shows on the front end
if networkChange.NetID != network.NetID {
returnErrorResponse(w, r, formatError(errors.New("NetID is not editable"), "badrequest"))
return
}
if networkChange.AddressRange != "" { if networkChange.AddressRange != "" {
@@ -297,7 +302,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
var isAddressOK bool = functions.IsIpv4CIDR(networkChange.AddressRange) var isAddressOK bool = functions.IsIpv4CIDR(networkChange.AddressRange)
if !isAddressOK { if !isAddressOK {
err := errors.New("Invalid Range of " + networkChange.AddressRange + " for addresses.") err := errors.New("Invalid Range of " + networkChange.AddressRange + " for addresses.")
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
haschange = true haschange = true
@@ -310,7 +315,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
var isAddressOK bool = functions.IsIpv4CIDR(networkChange.LocalRange) var isAddressOK bool = functions.IsIpv4CIDR(networkChange.LocalRange)
if !isAddressOK { if !isAddressOK {
err := errors.New("Invalid Range of " + networkChange.LocalRange + " for internal addresses.") err := errors.New("Invalid Range of " + networkChange.LocalRange + " for internal addresses.")
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
haschange = true haschange = true
@@ -384,7 +389,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -393,20 +398,20 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if hasrangeupdate { if hasrangeupdate {
err = functions.UpdateNetworkNodeAddresses(params["networkname"]) err = functions.UpdateNetworkNodeAddresses(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
} }
if haslocalrangeupdate { if haslocalrangeupdate {
err = functions.UpdateNetworkPrivateAddresses(params["networkname"]) err = functions.UpdateNetworkPrivateAddresses(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
} }
returnnetwork, err := functions.GetParentNetwork(network.NetID) returnnetwork, err := functions.GetParentNetwork(network.NetID)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -445,7 +450,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -464,7 +469,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
// we decode our body request params // we decode our body request params
err := json.NewDecoder(r.Body).Decode(&network) err := json.NewDecoder(r.Body).Decode(&network)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -477,7 +482,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
//err = validateNetwork("create", network) //err = validateNetwork("create", network)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
network.SetDefaults() network.SetDefaults()
@@ -488,14 +493,13 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
collection := mongoconn.Client.Database("netmaker").Collection("networks") collection := mongoconn.Client.Database("netmaker").Collection("networks")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// insert our network into the network table // insert our network into the network table
result, err := collection.InsertOne(ctx, network) result, err := collection.InsertOne(ctx, network)
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -504,7 +508,6 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
// BEGIN KEY MANAGEMENT SECTION // BEGIN KEY MANAGEMENT SECTION
//TODO: Very little error handling //TODO: Very little error handling
//accesskey is created as a json string inside the Network collection item in mongo //accesskey is created as a json string inside the Network collection item in mongo
func createAccessKey(w http.ResponseWriter, r *http.Request) { func createAccessKey(w http.ResponseWriter, r *http.Request) {
@@ -519,13 +522,13 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
//start here //start here
network, err := functions.GetParentNetwork(params["networkname"]) network, err := functions.GetParentNetwork(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
err = json.NewDecoder(r.Body).Decode(&accesskey) err = json.NewDecoder(r.Body).Decode(&accesskey)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -540,7 +543,7 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
} }
gconf, err := functions.GetGlobalConfig() gconf, err := functions.GetGlobalConfig()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -549,7 +552,6 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
privAddr = network.LocalRange privAddr = network.LocalRange
} }
netID := params["networkname"] netID := params["networkname"]
address := gconf.ServerGRPC + gconf.PortGRPC address := gconf.ServerGRPC + gconf.PortGRPC
@@ -580,7 +582,7 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -609,13 +611,13 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
keydata, err := json.Marshal(network.AccessKeys) keydata, err := json.Marshal(network.AccessKeys)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
@@ -625,7 +627,6 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(keys) json.NewEncoder(w).Encode(keys)
} }
//delete key. Has to do a little funky logic since it's not a collection item //delete key. Has to do a little funky logic since it's not a collection item
func deleteAccessKey(w http.ResponseWriter, r *http.Request) { func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
@@ -639,14 +640,14 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
//start here //start here
network, err := functions.GetParentNetwork(params["networkname"]) network, err := functions.GetParentNetwork(params["networkname"])
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
//basically, turn the list of access keys into the list of access keys before and after the item //basically, turn the list of access keys into the list of access keys before and after the item
//have not done any error handling for if there's like...1 item. I think it works? need to test. //have not done any error handling for if there's like...1 item. I think it works? need to test.
for i := len(network.AccessKeys) - 1; i >= 0; i-- { for i := len(network.AccessKeys) - 1; i >= 0; i-- {
currentkey:= network.AccessKeys[i] currentkey := network.AccessKeys[i]
if currentkey.Name == keyname { if currentkey.Name == keyname {
network.AccessKeys = append(network.AccessKeys[:i], network.AccessKeys = append(network.AccessKeys[:i],
network.AccessKeys[i+1:]...) network.AccessKeys[i+1:]...)
@@ -672,13 +673,13 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
var keys []models.AccessKey var keys []models.AccessKey
keydata, err := json.Marshal(network.AccessKeys) keydata, err := json.Marshal(network.AccessKeys)
if err != nil { if err != nil {
returnErrorResponse(w,r,formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }

View File

@@ -341,13 +341,13 @@ func TestUpdatenetwork(t *testing.T) {
network.NetID = "wirecat" network.NetID = "wirecat"
response, err := api(t, network, http.MethodPut, baseURL+"/api/networks/skynet", "secretkey") response, err := api(t, network, http.MethodPut, baseURL+"/api/networks/skynet", "secretkey")
assert.Nil(t, err, err) assert.Nil(t, err, err)
assert.Equal(t, http.StatusOK, response.StatusCode) assert.Equal(t, http.StatusBadRequest, response.StatusCode)
defer response.Body.Close() defer response.Body.Close()
err = json.NewDecoder(response.Body).Decode(&returnedNetwork) var message models.ErrorResponse
err = json.NewDecoder(response.Body).Decode(&message)
assert.Nil(t, err, err) assert.Nil(t, err, err)
//returns previous value not the updated value assert.Equal(t, http.StatusBadRequest, message.Code)
// ----- needs fixing ----- assert.Equal(t, "NetID is not editable", message.message)
//assert.Equal(t, network.NetID, returnedNetwork.NetID)
}) })
t.Run("NetIDInvalidCredentials", func(t *testing.T) { t.Run("NetIDInvalidCredentials", func(t *testing.T) {
type Network struct { type Network struct {