refactoring for ee

This commit is contained in:
afeiszli
2022-09-14 13:26:31 -04:00
parent 8a1ba674a7
commit b670755cce
35 changed files with 473 additions and 504 deletions

View File

@@ -25,7 +25,6 @@ var HttpHandlers = []interface{}{
serverHandlers, serverHandlers,
extClientHandlers, extClientHandlers,
ipHandlers, ipHandlers,
metricHandlers,
loggerHandlers, loggerHandlers,
userGroupsHandlers, userGroupsHandlers,
networkUsersHandlers, networkUsersHandlers,

View File

@@ -16,13 +16,13 @@ import (
func dnsHandlers(r *mux.Router) { func dnsHandlers(r *mux.Router) {
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET") r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST") r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST") r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE") r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
} }
// swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS // swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS
@@ -44,7 +44,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err)) fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -68,7 +68,7 @@ func getAllDNS(w http.ResponseWriter, r *http.Request) {
dns, err := logic.GetAllDNS() dns, err := logic.GetAllDNS()
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error()) logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -98,7 +98,7 @@ func getCustomDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error())) fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -128,7 +128,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error())) fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -160,7 +160,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("invalid DNS entry %+v: %v", entry, err)) fmt.Sprintf("invalid DNS entry %+v: %v", entry, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -168,14 +168,14 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err)) fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = logic.SetDNS() err = logic.SetDNS()
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, "new DNS record added:", entry.Name) logger.Log(1, "new DNS record added:", entry.Name)
@@ -221,7 +221,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, "failed to delete dns entry: ", entrytext) logger.Log(0, "failed to delete dns entry: ", entrytext)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, "deleted dns entry: ", entrytext) logger.Log(1, "deleted dns entry: ", entrytext)
@@ -229,7 +229,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
json.NewEncoder(w).Encode(entrytext + " deleted.") json.NewEncoder(w).Encode(entrytext + " deleted.")
@@ -287,7 +287,7 @@ func pushDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver") logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")

View File

@@ -21,13 +21,13 @@ import (
func extClientHandlers(r *mux.Router) { func extClientHandlers(r *mux.Router) {
r.HandleFunc("/api/extclients", securityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET") r.HandleFunc("/api/extclients", logic.SecurityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET")
r.HandleFunc("/api/extclients/{network}", securityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET") r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", netUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET") r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT")
r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE")
r.HandleFunc("/api/extclients/{network}/{nodeid}", netUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST") r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.NetUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST")
} }
func checkIngressExists(nodeID string) bool { func checkIngressExists(nodeID string) bool {
@@ -62,7 +62,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err)) fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -96,16 +96,16 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
if marshalErr != nil { if marshalErr != nil {
logger.Log(0, "error unmarshalling networks: ", logger.Log(0, "error unmarshalling networks: ",
marshalErr.Error()) marshalErr.Error())
returnErrorResponse(w, r, formatError(marshalErr, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "internal"))
return return
} }
clients := []models.ExtClient{} clients := []models.ExtClient{}
var err error var err error
if networksSlice[0] == ALL_NETWORK_ACCESS { if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
clients, err = functions.GetAllExtClients() clients, err = functions.GetAllExtClients()
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, "failed to get all extclients: ", err.Error()) logger.Log(0, "failed to get all extclients: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -146,7 +146,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -177,7 +177,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
clientid, networkid, err)) clientid, networkid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -185,14 +185,14 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
network, err := logic.GetParentNetwork(client.Network) network, err := logic.GetParentNetwork(client.Network)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network) logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -258,7 +258,7 @@ Endpoint = %s
bytes, err := qrcode.Encode(config, qrcode.Medium, 220) bytes, err := qrcode.Encode(config, qrcode.Medium, 220)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error()) logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.Header().Set("Content-Type", "image/png") w.Header().Set("Content-Type", "image/png")
@@ -266,7 +266,7 @@ Endpoint = %s
_, err = w.Write(bytes) _, err = w.Write(bytes)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error()) logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
return return
@@ -280,7 +280,7 @@ Endpoint = %s
_, err := fmt.Fprint(w, config) _, err := fmt.Fprint(w, config)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error()) logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
} }
return return
} }
@@ -310,7 +310,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
err := errors.New("ingress does not exist") err := errors.New("ingress does not exist")
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err)) fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -329,7 +329,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10) extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10)
@@ -345,7 +345,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err)) fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -355,7 +355,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil { if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil {
logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access") logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access")
logic.DeleteExtClient(networkName, extclient.ClientID) logic.DeleteExtClient(networkName, extclient.ClientID)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if !isAdmin { if !isAdmin {
@@ -400,7 +400,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
clientid := params["clientid"] clientid := params["clientid"]
@@ -410,7 +410,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v", fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
@@ -418,13 +418,13 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v", fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v",
key, clientid, network, err)) key, clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil { if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil {
logger.Log(0, "error unmarshalling extclient: ", logger.Log(0, "error unmarshalling extclient: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -435,7 +435,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get("user") userID := r.Header.Get("user")
_, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName) _, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName)
if !doesOwn { if !doesOwn {
returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
return return
} }
} }
@@ -457,7 +457,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update ext client [%s], network [%s]: %v", fmt.Sprintf("failed to update ext client [%s], network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID) logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID)
@@ -497,14 +497,14 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
err = errors.New("Could not delete extclient " + params["clientid"]) err = errors.New("Could not delete extclient " + params["clientid"])
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID) ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -513,7 +513,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"] userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"]
_, doesOwn := doesUserOwnClient(userID, clientID, networkName) _, doesOwn := doesUserOwnClient(userID, clientID, networkName)
if !doesOwn { if !doesOwn {
returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
return return
} }
} }
@@ -531,7 +531,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
err = errors.New("Could not delete extclient " + params["clientid"]) err = errors.New("Could not delete extclient " + params["clientid"])
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -542,7 +542,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"Deleted extclient client", params["clientid"], "from network", params["network"]) "Deleted extclient client", params["clientid"], "from network", params["network"])
returnSuccessResponse(w, r, params["clientid"]+" deleted.") logic.ReturnSuccessResponse(w, r, params["clientid"]+" deleted.")
} }
func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) { func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) {

View File

@@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
@@ -23,32 +22,32 @@ func checkFreeTierLimits(limit_choice int, next http.Handler) http.HandlerFunc {
Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks", Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks",
} }
if ee.Limits.FreeTier { // check that free tier limits not exceeded if logic.Free_Tier && logic.Is_EE { // check that free tier limits not exceeded
if limit_choice == networks_l { if limit_choice == networks_l {
currentNetworks, err := logic.GetNetworks() currentNetworks, err := logic.GetNetworks()
if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= ee.Limits.Networks { if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= logic.Networks_Limit {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} else if limit_choice == node_l { } else if limit_choice == node_l {
nodes, err := logic.GetAllNodes() nodes, err := logic.GetAllNodes()
if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= ee.Limits.Nodes { if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= logic.Node_Limit {
errorResponse.Message = "free tier limits exceeded on nodes" errorResponse.Message = "free tier limits exceeded on nodes"
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} else if limit_choice == users_l { } else if limit_choice == users_l {
users, err := logic.GetUsers() users, err := logic.GetUsers()
if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= ee.Limits.Users { if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= logic.Users_Limit {
errorResponse.Message = "free tier limits exceeded on users" errorResponse.Message = "free tier limits exceeded on users"
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} else if limit_choice == clients_l { } else if limit_choice == clients_l {
clients, err := logic.GetAllExtClients() clients, err := logic.GetAllExtClients()
if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= ee.Limits.Clients { if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= logic.Clients_Limit {
errorResponse.Message = "free tier limits exceeded on external clients" errorResponse.Message = "free tier limits exceeded on external clients"
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} }

View File

@@ -7,10 +7,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
) )
func loggerHandlers(r *mux.Router) { func loggerHandlers(r *mux.Router) {
r.HandleFunc("/api/logs", securityCheck(true, http.HandlerFunc(getLogs))).Methods("GET") r.HandleFunc("/api/logs", logic.SecurityCheck(true, http.HandlerFunc(getLogs))).Methods("GET")
} }
func getLogs(w http.ResponseWriter, r *http.Request) { func getLogs(w http.ResponseWriter, r *http.Request) {

View File

@@ -17,26 +17,20 @@ import (
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
) )
// ALL_NETWORK_ACCESS - represents all networks
const ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
// NO_NETWORKS_PRESENT - represents no networks
const NO_NETWORKS_PRESENT = "THIS_USER_HAS_NONE"
func networkHandlers(r *mux.Router) { func networkHandlers(r *mux.Router) {
r.HandleFunc("/api/networks", securityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET") r.HandleFunc("/api/networks", logic.SecurityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET")
r.HandleFunc("/api/networks", securityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST") r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}/nodelimit", securityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}/nodelimit", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}", securityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE")
r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST") r.HandleFunc("/api/networks/{networkname}/keyupdate", logic.SecurityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST") r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET") r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE") r.HandleFunc("/api/networks/{networkname}/keys/{name}", logic.SecurityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
// ACLs // ACLs
r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET") r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET")
} }
// swagger:route GET /api/networks networks getNetworks // swagger:route GET /api/networks networks getNetworks
@@ -58,16 +52,16 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
if marshalErr != nil { if marshalErr != nil {
logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ", logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ",
marshalErr.Error()) marshalErr.Error())
returnErrorResponse(w, r, formatError(marshalErr, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "badrequest"))
return return
} }
allnetworks := []models.Network{} allnetworks := []models.Network{}
var err error var err error
if networksSlice[0] == ALL_NETWORK_ACCESS { if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
allnetworks, err = logic.GetNetworks() allnetworks, err = logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error()) logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -110,7 +104,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v",
netname, err)) netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if !servercfg.IsDisplayKeys() { if !servercfg.IsDisplayKeys() {
@@ -140,7 +134,7 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v",
netname, err)) netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "updated key on network", netname) logger.Log(2, r.Header.Get("user"), "updated key on network", netname)
@@ -182,7 +176,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get network info: ", logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var newNetwork models.Network var newNetwork models.Network
@@ -190,7 +184,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -203,7 +197,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update network: ", logger.Log(0, r.Header.Get("user"), "failed to update network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -231,7 +225,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v", fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -241,7 +235,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v", fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -251,7 +245,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] local addresses: %v", fmt.Sprintf("failed to update network [%s] local addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -261,7 +255,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] hole punching: %v", fmt.Sprintf("failed to update network [%s] hole punching: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -271,7 +265,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] nodes: %v", fmt.Sprintf("failed to get network [%s] nodes: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
for _, node := range nodes { for _, node := range nodes {
@@ -305,7 +299,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] nodes: %v", fmt.Sprintf("failed to get network [%s] nodes: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -315,7 +309,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if networkChange.NodeLimit != 0 { if networkChange.NodeLimit != 0 {
@@ -324,7 +318,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME) database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME)
@@ -354,21 +348,21 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = json.NewDecoder(r.Body).Decode(&networkACLChange) err = json.NewDecoder(r.Body).Decode(&networkACLChange)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
newNetACL, err := networkACLChange.Save(acls.ContainerID(netname)) newNetACL, err := networkACLChange.Save(acls.ContainerID(netname))
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname) logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname)
@@ -412,7 +406,7 @@ func getNetworkACL(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname) logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname)
@@ -445,7 +439,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete network [%s]: %v", network, err)) fmt.Sprintf("failed to delete network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, errtype)) logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted network", network) logger.Log(1, r.Header.Get("user"), "deleted network", network)
@@ -475,7 +469,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -483,7 +477,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
err := errors.New("IPv4 or IPv6 CIDR required") err := errors.New("IPv4 or IPv6 CIDR required")
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -491,7 +485,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -504,7 +498,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -537,28 +531,28 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get network info: ", logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = json.NewDecoder(r.Body).Decode(&accesskey) err = json.NewDecoder(r.Body).Decode(&accesskey)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
key, err := logic.CreateAccessKey(accesskey, network) key, err := logic.CreateAccessKey(accesskey, network)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create access key: ", logger.Log(0, r.Header.Get("user"), "failed to create access key: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
// do not allow access key creations view API with user names // do not allow access key creations view API with user names
if _, err = logic.GetUser(key.Name); err == nil { if _, err = logic.GetUser(key.Name); err == nil {
logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user")) logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user"))
returnErrorResponse(w, r, formatError(fmt.Errorf("cannot create access key with user name"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("cannot create access key with user name"), "badrequest"))
logic.DeleteKey(key.Name, network.NetID) logic.DeleteKey(key.Name, network.NetID)
return return
} }
@@ -587,7 +581,7 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v",
network, err)) network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if !servercfg.IsDisplayKeys() { if !servercfg.IsDisplayKeys() {
@@ -621,7 +615,7 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v",
keyname, netname, err)) keyname, netname, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname) logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)

View File

@@ -182,24 +182,24 @@ func TestSecurityCheck(t *testing.T) {
initialize() initialize()
os.Setenv("MASTER_KEY", "secretkey") os.Setenv("MASTER_KEY", "secretkey")
t.Run("NoNetwork", func(t *testing.T) { t.Run("NoNetwork", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("WithNetwork", func(t *testing.T) { t.Run("WithNetwork", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "skynet", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadNet", func(t *testing.T) { t.Run("BadNet", func(t *testing.T) {
t.Skip() t.Skip()
networks, username, err := SecurityCheck(false, "badnet", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadToken", func(t *testing.T) { t.Run("BadToken", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "skynet", "Bearer badkey") networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)

View File

@@ -14,13 +14,13 @@ import (
) )
func networkUsersHandlers(r *mux.Router) { func networkUsersHandlers(r *mux.Router) {
r.HandleFunc("/api/networkusers", securityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET") r.HandleFunc("/api/networkusers", logic.SecurityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET") r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET") r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST") r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST")
r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT") r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT")
r.HandleFunc("/api/networkusers/data/{networkuser}/me", netUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET") r.HandleFunc("/api/networkusers/data/{networkuser}/me", logic.NetUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE") r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE")
} }
// == RETURN TYPES == // == RETURN TYPES ==
@@ -52,18 +52,18 @@ func getNetworkUserData(w http.ResponseWriter, r *http.Request) {
networks, err := logic.GetNetworks() networks, err := logic.GetNetworks()
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if networkUserName == "" { if networkUserName == "" {
returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
return return
} }
u, err := logic.GetUser(networkUserName) u, err := logic.GetUser(networkUserName)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(errors.New("could not find user"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("could not find user"), "badrequest"))
return return
} }
@@ -151,7 +151,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) {
networks, err := logic.GetNetworks() networks, err := logic.GetNetworks()
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -160,7 +160,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) {
for i := range networks { for i := range networks {
netusers, err := pro.GetNetworkUsers(networks[i].NetID) netusers, err := pro.GetNetworkUsers(networks[i].NetID)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
for _, v := range netusers { for _, v := range netusers {
@@ -181,13 +181,13 @@ func getNetworkUsers(w http.ResponseWriter, r *http.Request) {
_, err := logic.GetNetwork(netname) _, err := logic.GetNetwork(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
netusers, err := pro.GetNetworkUsers(netname) netusers, err := pro.GetNetworkUsers(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -203,19 +203,19 @@ func getNetworkUser(w http.ResponseWriter, r *http.Request) {
_, err := logic.GetNetwork(netname) _, err := logic.GetNetwork(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
netuserToGet := params["networkuser"] netuserToGet := params["networkuser"]
if netuserToGet == "" { if netuserToGet == "" {
returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
return return
} }
netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet)) netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet))
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -230,7 +230,7 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) {
network, err := logic.GetNetwork(netname) network, err := logic.GetNetwork(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var networkuser promodels.NetworkUser var networkuser promodels.NetworkUser
@@ -238,13 +238,13 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) {
// we decode our body request params // we decode our body request params
err = json.NewDecoder(r.Body).Decode(&networkuser) err = json.NewDecoder(r.Body).Decode(&networkuser)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = pro.CreateNetworkUser(&network, &networkuser) err = pro.CreateNetworkUser(&network, &networkuser)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -260,7 +260,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
network, err := logic.GetNetwork(netname) network, err := logic.GetNetwork(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var networkuser promodels.NetworkUser var networkuser promodels.NetworkUser
@@ -268,38 +268,38 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
// we decode our body request params // we decode our body request params
err = json.NewDecoder(r.Body).Decode(&networkuser) err = json.NewDecoder(r.Body).Decode(&networkuser)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) { if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) {
returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
return return
} }
if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS { if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS {
returnErrorResponse(w, r, formatError(errors.New("invalid user access level provided"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user access level provided"), "badrequest"))
return return
} }
if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 { if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 {
returnErrorResponse(w, r, formatError(errors.New("negative user limit provided"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("negative user limit provided"), "badrequest"))
return return
} }
u, err := logic.GetUser(string(networkuser.ID)) u, err := logic.GetUser(string(networkuser.ID))
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
return return
} }
if !pro.IsUserAllowed(&network, u.UserName, u.Groups) { if !pro.IsUserAllowed(&network, u.UserName, u.Groups) {
returnErrorResponse(w, r, formatError(errors.New("user must be in allowed groups or users"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user must be in allowed groups or users"), "badrequest"))
return return
} }
if networkuser.AccessLevel == pro.NET_ADMIN { if networkuser.AccessLevel == pro.NET_ADMIN {
currentUser, err := logic.GetUser(string(networkuser.ID)) currentUser, err := logic.GetUser(string(networkuser.ID))
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest"))
return return
} }
@@ -316,7 +316,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
UserName: currentUser.UserName, UserName: currentUser.UserName,
}, },
); err != nil { ); err != nil {
returnErrorResponse(w, r, formatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest"))
return return
} }
} }
@@ -324,7 +324,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
err = pro.UpdateNetworkUser(netname, &networkuser) err = pro.UpdateNetworkUser(netname, &networkuser)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -340,18 +340,18 @@ func deleteNetworkUser(w http.ResponseWriter, r *http.Request) {
_, err := logic.GetNetwork(netname) _, err := logic.GetNetwork(netname)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
netuserToDelete := params["networkuser"] netuserToDelete := params["networkuser"]
if netuserToDelete == "" { if netuserToDelete == "" {
returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return return
} }
if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil { if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }

View File

@@ -8,7 +8,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/logic/pro"
@@ -30,8 +29,8 @@ func nodeHandlers(r *mux.Router) {
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST")
r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST")
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST")
r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST") r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST")
r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET") r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET")
@@ -66,19 +65,19 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = decoderErr.Error() errorResponse.Message = decoderErr.Error()
logger.Log(0, request.Header.Get("user"), "error decoding request body: ", logger.Log(0, request.Header.Get("user"), "error decoding request body: ",
decoderErr.Error()) decoderErr.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
errorResponse.Code = http.StatusBadRequest errorResponse.Code = http.StatusBadRequest
if authRequest.ID == "" { if authRequest.ID == "" {
errorResponse.Message = "W1R3: ID can't be empty" errorResponse.Message = "W1R3: ID can't be empty"
logger.Log(0, request.Header.Get("user"), errorResponse.Message) logger.Log(0, request.Header.Get("user"), errorResponse.Message)
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else if authRequest.Password == "" { } else if authRequest.Password == "" {
errorResponse.Message = "W1R3: Password can't be empty" errorResponse.Message = "W1R3: Password can't be empty"
logger.Log(0, request.Header.Get("user"), errorResponse.Message) logger.Log(0, request.Header.Get("user"), errorResponse.Message)
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
var err error var err error
@@ -89,7 +88,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err)) fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err))
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
@@ -99,7 +98,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error validating user password: ", err.Error()) "error validating user password: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network) tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network)
@@ -109,7 +108,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = "Could not create Token" errorResponse.Message = "Could not create Token"
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
fmt.Sprintf("%s: %v", errorResponse.Message, err)) fmt.Sprintf("%s: %v", errorResponse.Message, err))
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
@@ -128,7 +127,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
response.WriteHeader(http.StatusOK) response.WriteHeader(http.StatusOK)
@@ -149,7 +148,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
token = tokenSplit[1] token = tokenSplit[1]
@@ -161,7 +160,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusNotFound, Message: "no networks", Code: http.StatusNotFound, Message: "no networks",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
for _, network := range networks { for _, network := range networks {
@@ -177,7 +176,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.", Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -194,16 +193,16 @@ func nodeauth(next http.Handler) http.HandlerFunc {
func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc { func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg, Code: http.StatusUnauthorized, Message: logic.Unauthorized_Msg,
} }
var params = mux.Vars(r) var params = mux.Vars(r)
networkexists, _ := functions.NetworkExists(params["network"]) networkexists, _ := logic.NetworkExists(params["network"])
//check that the request is for a valid network //check that the request is for a valid network
//if (networkCheck && !networkexists) || err != nil { //if (networkCheck && !networkexists) || err != nil {
if networkCheck && !networkexists { if networkCheck && !networkexists {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -220,7 +219,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
if len(tokenSplit) > 1 { if len(tokenSplit) > 1 {
authToken = tokenSplit[1] authToken = tokenSplit[1]
} else { } else {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
//check if node instead of user //check if node instead of user
@@ -236,7 +235,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
var nodeID = "" var nodeID = ""
username, networks, isadmin, errN := logic.VerifyUserToken(authToken) username, networks, isadmin, errN := logic.VerifyUserToken(authToken)
if errN != nil { if errN != nil {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
@@ -269,7 +268,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
} else { } else {
node, err := logic.GetNodeByID(nodeID) node, err := logic.GetNodeByID(nodeID)
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
isAuthorized = (node.Network == params["network"]) isAuthorized = (node.Network == params["network"])
@@ -287,7 +286,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
} }
} }
if !isAuthorized { if !isAuthorized {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
//If authorized, this function passes along it's request and output to the appropriate route function. //If authorized, this function passes along it's request and output to the appropriate route function.
@@ -324,7 +323,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err)) fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -358,7 +357,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
if err != nil && r.Header.Get("ismasterkey") != "yes" { if err != nil && r.Header.Get("ismasterkey") != "yes" {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error fetching user info: ", err.Error()) "error fetching user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var nodes []models.Node var nodes []models.Node
@@ -366,7 +365,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
nodes, err = logic.GetAllNodes() nodes, err = logic.GetAllNodes()
if err != nil { if err != nil {
logger.Log(0, "error fetching all nodes info: ", err.Error()) logger.Log(0, "error fetching all nodes info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -374,7 +373,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error fetching nodes: ", err.Error()) "error fetching nodes: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -418,7 +417,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -426,7 +425,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err)) fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -470,7 +469,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching network [%s] info: %v", networkName, err)) fmt.Sprintf("error fetching network [%s] info: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "called last modified") logger.Log(2, r.Header.Get("user"), "called last modified")
@@ -498,12 +497,12 @@ func createNode(w http.ResponseWriter, r *http.Request) {
Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.",
} }
networkName := params["network"] networkName := params["network"]
networkexists, err := functions.NetworkExists(networkName) networkexists, err := logic.NetworkExists(networkName)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err)) fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} else if !networkexists { } else if !networkexists {
errorResponse = models.ErrorResponse{ errorResponse = models.ErrorResponse{
@@ -511,7 +510,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("network [%s] does not exist", networkName)) fmt.Sprintf("network [%s] does not exist", networkName))
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
@@ -521,7 +520,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
err = json.NewDecoder(r.Body).Decode(&node) err = json.NewDecoder(r.Body).Decode(&node)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -531,14 +530,14 @@ func createNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err)) fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
node.NetworkSettings, err = logic.GetNetworkSettings(node.Network) node.NetworkSettings, err = logic.GetNetworkSettings(node.Network)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err)) fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey) keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey)
@@ -554,7 +553,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create node on network [%s]: %s", fmt.Sprintf("failed to create node on network [%s]: %s",
node.Network, errorResponse.Message)) node.Network, errorResponse.Message))
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} }
@@ -569,17 +568,17 @@ func createNode(w http.ResponseWriter, r *http.Request) {
key, keyErr := logic.RetrievePublicTrafficKey() key, keyErr := logic.RetrievePublicTrafficKey()
if keyErr != nil { if keyErr != nil {
logger.Log(0, "error retrieving key: ", keyErr.Error()) logger.Log(0, "error retrieving key: ", keyErr.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if key == nil { if key == nil {
logger.Log(0, "error: server traffic key is nil") logger.Log(0, "error: server traffic key is nil")
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if node.TrafficKeys.Mine == nil { if node.TrafficKeys.Mine == nil {
logger.Log(0, "error: node traffic key is nil") logger.Log(0, "error: node traffic key is nil")
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
node.TrafficKeys = models.TrafficKeys{ node.TrafficKeys = models.TrafficKeys{
@@ -592,7 +591,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create node on network [%s]: %s", fmt.Sprintf("failed to create node on network [%s]: %s",
node.Network, err)) node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -609,7 +608,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
if !updatedUserNode { // user was found but not updated, so delete node if !updatedUserNode { // user was found but not updated, so delete node
logger.Log(0, "failed to add node to user", keyName) logger.Log(0, "failed to add node to user", keyName)
logic.DeleteNodeByID(&node, true) logic.DeleteNodeByID(&node, true)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -618,7 +617,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err)) fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -656,7 +655,7 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err)) fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name) logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name)
@@ -686,7 +685,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&gateway) err := json.NewDecoder(r.Body).Decode(&gateway)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
gateway.NetID = params["network"] gateway.NetID = params["network"]
@@ -696,7 +695,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v",
gateway.NodeID, gateway.NetID, err)) gateway.NodeID, gateway.NetID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -728,7 +727,7 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -762,7 +761,7 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -794,7 +793,7 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -828,7 +827,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -837,7 +836,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
err = json.NewDecoder(r.Body).Decode(&newNode) err = json.NewDecoder(r.Body).Decode(&newNode)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
relayupdate := false relayupdate := false
@@ -885,7 +884,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err)) fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if relayupdate { if relayupdate {
@@ -932,20 +931,20 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if isServer(&node) { if isServer(&node) {
err := fmt.Errorf("cannot delete server node") err := fmt.Errorf("cannot delete server node")
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err)) fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if r.Header.Get("ismaster") != "yes" { if r.Header.Get("ismaster") != "yes" {
username := r.Header.Get("user") username := r.Header.Get("user")
if username != "" && !doesUserOwnNode(username, params["network"], nodeid) { if username != "" && !doesUserOwnNode(username, params["network"], nodeid) {
returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "badrequest"))
return return
} }
} }
@@ -954,11 +953,11 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
err = logic.DeleteNodeByID(&node, false) err = logic.DeleteNodeByID(&node, false)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
returnSuccessResponse(w, r, nodeid+" deleted.") logic.ReturnSuccessResponse(w, r, nodeid+" deleted.")
logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"]) logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"])
runUpdates(&node, false) runUpdates(&node, false)

View File

@@ -30,7 +30,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&relay) err := json.NewDecoder(r.Body).Decode(&relay)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
relay.NetID = params["network"] relay.NetID = params["network"]
@@ -39,7 +39,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err)) fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID) logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID)
@@ -73,7 +73,7 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
updatenodes, node, err := logic.DeleteRelay(netid, nodeid) updatenodes, node, err := logic.DeleteRelay(netid, nodeid)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid) logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid)

View File

@@ -7,12 +7,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestFormatError(t *testing.T) { func TestFormatError(t *testing.T) {
response := formatError(errors.New("this is a sample error"), "badrequest") response := logic.FormatError(errors.New("this is a sample error"), "badrequest")
assert.Equal(t, http.StatusBadRequest, response.Code) assert.Equal(t, http.StatusBadRequest, response.Code)
assert.Equal(t, "this is a sample error", response.Message) assert.Equal(t, "this is a sample error", response.Message)
} }
@@ -20,7 +21,7 @@ func TestFormatError(t *testing.T) {
func TestReturnSuccessResponse(t *testing.T) { func TestReturnSuccessResponse(t *testing.T) {
var response models.SuccessResponse var response models.SuccessResponse
handler := func(rw http.ResponseWriter, r *http.Request) { handler := func(rw http.ResponseWriter, r *http.Request) {
returnSuccessResponse(rw, r, "This is a test message") logic.ReturnSuccessResponse(rw, r, "This is a test message")
} }
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -42,7 +43,7 @@ func TestReturnErrorResponse(t *testing.T) {
errMessage.Code = http.StatusUnauthorized errMessage.Code = http.StatusUnauthorized
errMessage.Message = "You are not authorized to access this endpoint" errMessage.Message = "You are not authorized to access this endpoint"
handler := func(rw http.ResponseWriter, r *http.Request) { handler := func(rw http.ResponseWriter, r *http.Request) {
returnErrorResponse(rw, r, errMessage) logic.ReturnErrorResponse(rw, r, errMessage)
} }
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -10,7 +10,6 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
@@ -22,82 +21,35 @@ import (
func serverHandlers(r *mux.Router) { func serverHandlers(r *mux.Router) {
// r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST") // r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST")
r.HandleFunc("/api/server/getconfig", securityCheckServer(false, http.HandlerFunc(getConfig))).Methods("GET") r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))).Methods("GET")
r.HandleFunc("/api/server/removenetwork/{network}", securityCheckServer(true, http.HandlerFunc(removeNetwork))).Methods("DELETE")
r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST") r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST")
r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET") r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET")
} }
//Security check is middleware for every function and just checks to make sure that its the master calling // allowUsers - allow all authenticated (valid) users - only used by getConfig, may be able to remove during refactor
//Only admin should have access to all these network-level actions func allowUsers(next http.Handler) http.HandlerFunc {
//or maybe some Users once implemented
func securityCheckServer(adminonly bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", Code: http.StatusInternalServerError, Message: logic.Unauthorized_Msg,
} }
bearerToken := r.Header.Get("Authorization") bearerToken := r.Header.Get("Authorization")
var tokenSplit = strings.Split(bearerToken, " ") var tokenSplit = strings.Split(bearerToken, " ")
var authToken = "" var authToken = ""
if len(tokenSplit) < 2 { if len(tokenSplit) < 2 {
errorResponse = models.ErrorResponse{ logic.ReturnErrorResponse(w, r, errorResponse)
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
}
returnErrorResponse(w, r, errorResponse)
return return
} else { } else {
authToken = tokenSplit[1] authToken = tokenSplit[1]
} }
//all endpoints here require master so not as complicated user, _, _, err := logic.VerifyUserToken(authToken)
//still might not be a good way of doing this if err != nil || user == "" {
user, _, isadmin, err := logic.VerifyUserToken(authToken) logic.ReturnErrorResponse(w, r, errorResponse)
errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
}
if !adminonly && (err != nil || user == "") {
returnErrorResponse(w, r, errorResponse)
return
}
if adminonly && !isadmin && !authenticateMaster(authToken) {
returnErrorResponse(w, r, errorResponse)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
} }
// swagger:route DELETE /api/server/removenetwork/{network} nodes removeNetwork
//
// Remove a network from the server.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: stringJSONResponse
func removeNetwork(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
// get params
var params = mux.Vars(r)
network := params["network"]
err := logic.DeleteNetwork(network)
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete network [%s]: %v", network, err))
json.NewEncoder(w).Encode(fmt.Sprintf("could not remove network %s from server", network))
return
}
logger.Log(1, r.Header.Get("user"),
fmt.Sprintf("deleted network [%s]: %v", network, err))
json.NewEncoder(w).Encode(fmt.Sprintf("network %s removed from server", network))
}
// swagger:route GET /api/server/getserverinfo nodes getServerInfo // swagger:route GET /api/server/getserverinfo nodes getServerInfo
// //
// Get the server configuration. // Get the server configuration.
@@ -138,7 +90,7 @@ func getConfig(w http.ResponseWriter, r *http.Request) {
scfg := servercfg.GetServerConfig() scfg := servercfg.GetServerConfig()
scfg.IsEE = "no" scfg.IsEE = "no"
if ee.IsEnterprise() { if logic.Is_EE {
scfg.IsEE = "yes" scfg.IsEE = "yes"
} }
json.NewEncoder(w).Encode(scfg) json.NewEncoder(w).Encode(scfg)
@@ -166,7 +118,7 @@ func register(w http.ResponseWriter, r *http.Request) {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusBadRequest, Message: err.Error(), Code: http.StatusBadRequest, Message: err.Error(),
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
cert, ca, err := genCerts(&request.Key, &request.CommonName) cert, ca, err := genCerts(&request.Key, &request.CommonName)
@@ -175,7 +127,7 @@ func register(w http.ResponseWriter, r *http.Request) {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusNotFound, Message: err.Error(), Code: http.StatusNotFound, Message: err.Error(),
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
//x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte //x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte

View File

@@ -25,13 +25,13 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET") r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET")
r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST") r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST")
r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST") r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST")
r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT")
r.HandleFunc("/api/users/networks/{username}", securityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT") r.HandleFunc("/api/users/networks/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT")
r.HandleFunc("/api/users/{username}/adm", securityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT") r.HandleFunc("/api/users/{username}/adm", logic.SecurityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT")
r.HandleFunc("/api/users/{username}", securityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST")
r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE")
r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET")
r.HandleFunc("/api/users", securityCheck(true, http.HandlerFunc(getUsers))).Methods("GET") r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods("GET")
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
r.HandleFunc("/api/oauth/node-handler", socketHandler) r.HandleFunc("/api/oauth/node-handler", socketHandler)
@@ -59,7 +59,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
} }
if !servercfg.IsBasicAuthEnabled() { if !servercfg.IsBasicAuthEnabled() {
returnErrorResponse(response, request, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) logic.ReturnErrorResponse(response, request, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
return return
} }
@@ -69,7 +69,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
if decoderErr != nil { if decoderErr != nil {
logger.Log(0, "error decoding request body: ", logger.Log(0, "error decoding request body: ",
decoderErr.Error()) decoderErr.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
username := authRequest.UserName username := authRequest.UserName
@@ -77,14 +77,14 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "user validation failed: ", logger.Log(0, username, "user validation failed: ",
err.Error()) err.Error())
returnErrorResponse(response, request, formatError(err, "badrequest")) logic.ReturnErrorResponse(response, request, logic.FormatError(err, "badrequest"))
return return
} }
if jwt == "" { if jwt == "" {
// very unlikely that err is !nil and no jwt returned, but handle it anyways. // very unlikely that err is !nil and no jwt returned, but handle it anyways.
logger.Log(0, username, "jwt token is empty") logger.Log(0, username, "jwt token is empty")
returnErrorResponse(response, request, formatError(errors.New("no token returned"), "internal")) logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("no token returned"), "internal"))
return return
} }
@@ -102,7 +102,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
if jsonError != nil { if jsonError != nil {
logger.Log(0, username, logger.Log(0, username,
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
logger.Log(2, username, "was authenticated") logger.Log(2, username, "was authenticated")
@@ -128,7 +128,7 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
hasadmin, err := logic.HasAdmin() hasadmin, err := logic.HasAdmin()
if err != nil { if err != nil {
logger.Log(0, "failed to check for admin: ", err.Error()) logger.Log(0, "failed to check for admin: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -171,7 +171,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error()) logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched) logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched)
@@ -197,7 +197,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, "failed to fetch users: ", err.Error()) logger.Log(0, "failed to fetch users: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -226,12 +226,12 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
logger.Log(0, admin.UserName, "error decoding request body: ", logger.Log(0, admin.UserName, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if !servercfg.IsBasicAuthEnabled() { if !servercfg.IsBasicAuthEnabled() {
returnErrorResponse(w, r, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
return return
} }
@@ -239,7 +239,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, admin.UserName, "failed to create admin: ", logger.Log(0, admin.UserName, "failed to create admin: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -266,7 +266,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, user.UserName, "error decoding request body: ", logger.Log(0, user.UserName, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -274,7 +274,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, user.UserName, "error creating new user: ", logger.Log(0, user.UserName, "error creating new user: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, user.UserName, "was created") logger.Log(1, user.UserName, "was created")
@@ -302,7 +302,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user networks: ", err.Error()) "failed to update user networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var userchange models.User var userchange models.User
@@ -311,7 +311,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{ err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{
@@ -324,7 +324,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user networks: ", err.Error()) "failed to update user networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "status was updated") logger.Log(1, username, "status was updated")
@@ -352,13 +352,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user info: ", err.Error()) "failed to update user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if auth.IsOauthUser(&user) == nil { if auth.IsOauthUser(&user) == nil {
err := fmt.Errorf("cannot update user info for oauth user %s", username) err := fmt.Errorf("cannot update user info for oauth user %s", username)
logger.Log(0, err.Error()) logger.Log(0, err.Error())
returnErrorResponse(w, r, formatError(err, "forbidden")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
return return
} }
var userchange models.User var userchange models.User
@@ -367,7 +367,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
userchange.Networks = nil userchange.Networks = nil
@@ -375,7 +375,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user info: ", err.Error()) "failed to update user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "was updated") logger.Log(1, username, "was updated")
@@ -401,13 +401,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
username := params["username"] username := params["username"]
user, err := GetUserInternal(username) user, err := GetUserInternal(username)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if auth.IsOauthUser(&user) != nil { if auth.IsOauthUser(&user) != nil {
err := fmt.Errorf("cannot update user info for oauth user %s", username) err := fmt.Errorf("cannot update user info for oauth user %s", username)
logger.Log(0, err.Error()) logger.Log(0, err.Error())
returnErrorResponse(w, r, formatError(err, "forbidden")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
return return
} }
var userchange models.User var userchange models.User
@@ -416,18 +416,18 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if !user.IsAdmin { if !user.IsAdmin {
logger.Log(0, username, "not an admin user") logger.Log(0, username, "not an admin user")
returnErrorResponse(w, r, formatError(errors.New("not a admin user"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not a admin user"), "badrequest"))
} }
user, err = logic.UpdateUser(userchange, user) user, err = logic.UpdateUser(userchange, user)
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user (admin) info: ", err.Error()) "failed to update user (admin) info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "was updated (admin)") logger.Log(1, username, "was updated (admin)")
@@ -458,12 +458,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to delete user: ", err.Error()) "failed to delete user: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} else if !success { } else if !success {
err := errors.New("delete unsuccessful") err := errors.New("delete unsuccessful")
logger.Log(0, username, err.Error()) logger.Log(0, username, err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }

View File

@@ -3,18 +3,20 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/gravitl/netmaker/logger"
"net/http" "net/http"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/models/promodels"
) )
func userGroupsHandlers(r *mux.Router) { func userGroupsHandlers(r *mux.Router) {
r.HandleFunc("/api/usergroups", securityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET") r.HandleFunc("/api/usergroups", logic.SecurityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET")
r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST") r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST")
r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE") r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE")
} }
func getUserGroups(w http.ResponseWriter, r *http.Request) { func getUserGroups(w http.ResponseWriter, r *http.Request) {
@@ -23,7 +25,7 @@ func getUserGroups(w http.ResponseWriter, r *http.Request) {
userGroups, err := pro.GetUserGroups() userGroups, err := pro.GetUserGroups()
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
// Returns all the groups in JSON format // Returns all the groups in JSON format
@@ -39,13 +41,13 @@ func createUserGroup(w http.ResponseWriter, r *http.Request) {
logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup) logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup)
if newGroup == "" { if newGroup == "" {
returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return return
} }
err := pro.InsertUserGroup(promodels.UserGroupName(newGroup)) err := pro.InsertUserGroup(promodels.UserGroupName(newGroup))
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -58,12 +60,12 @@ func deleteUserGroup(w http.ResponseWriter, r *http.Request) {
logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete) logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete)
if groupToDelete == "" { if groupToDelete == "" {
returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return return
} }
if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil { if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }

View File

@@ -1,4 +1,4 @@
package controller package ee_controllers
import ( import (
"encoding/json" "encoding/json"
@@ -10,10 +10,11 @@ import (
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
func metricHandlers(r *mux.Router) { // MetricHandlers - How we handle EE Metrics
r.HandleFunc("/api/metrics/{network}/{nodeid}", securityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET") func MetricHandlers(r *mux.Router) {
r.HandleFunc("/api/metrics/{network}", securityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET") r.HandleFunc("/api/metrics/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET")
r.HandleFunc("/api/metrics", securityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET") r.HandleFunc("/api/metrics/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET")
r.HandleFunc("/api/metrics", logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET")
} }
// get the metrics of a given node // get the metrics of a given node
@@ -28,7 +29,7 @@ func getNodeMetrics(w http.ResponseWriter, r *http.Request) {
metrics, err := logic.GetMetrics(nodeID) metrics, err := logic.GetMetrics(nodeID)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error()) logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -49,7 +50,7 @@ func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) {
networkNodes, err := logic.GetNetworkNodes(network) networkNodes, err := logic.GetNetworkNodes(network)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error()) logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -79,7 +80,7 @@ func getAllMetrics(w http.ResponseWriter, r *http.Request) {
allNodes, err := logic.GetAllNodes() allNodes, err := logic.GetAllNodes()
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error()) logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }

54
ee/initialize.go Normal file
View File

@@ -0,0 +1,54 @@
//go:build ee
// +build ee
package ee
import (
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/ee/ee_controllers"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// InitEE - Initialize EE Logic
func InitEE() {
SetIsEnterprise()
models.SetLogo(retrieveEELogo())
controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers)
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling ==
ValidateLicense()
if Limits.FreeTier {
logger.Log(0, "proceeding with Free Tier license")
} else {
logger.Log(0, "proceeding with Paid Tier license")
}
// == End License Handling ==
AddLicenseHooks()
})
}
func setControllerLimits() {
logic.Node_Limit = Limits.Nodes
logic.Users_Limit = Limits.Users
logic.Clients_Limit = Limits.Clients
logic.Free_Tier = Limits.FreeTier
logic.Is_EE = true
}
func retrieveEELogo() string {
return `
__ __ ______ ______ __ __ ______ __ __ ______ ______
/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \
\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __<
\ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\
\/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/
___ ___ ____
____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____
/___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/
/___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/
`
}

View File

@@ -1,7 +1,11 @@
//go:build ee
// +build ee
package ee package ee
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@@ -11,11 +15,20 @@ import (
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/nacl/box"
) )
const (
db_license_key = "netmaker-id-key-pair"
)
type apiServerConf struct {
PrivateKey []byte `json:"private_key" binding:"required"`
PublicKey []byte `json:"public_key" binding:"required"`
}
// AddLicenseHooks - adds the validation and cache clear hooks // AddLicenseHooks - adds the validation and cache clear hooks
func AddLicenseHooks() { func AddLicenseHooks() {
logic.AddHook(ValidateLicense) logic.AddHook(ValidateLicense)
@@ -39,7 +52,7 @@ func ValidateLicense() error {
logger.FatalLog(errValidation.Error()) logger.FatalLog(errValidation.Error())
} }
tempPubKey, tempPrivKey, err := pro.FetchApiServerKeys() tempPubKey, tempPrivKey, err := FetchApiServerKeys()
if err != nil { if err != nil {
logger.FatalLog(errValidation.Error()) logger.FatalLog(errValidation.Error())
} }
@@ -88,11 +101,59 @@ func ValidateLicense() error {
if Limits.FreeTier { if Limits.FreeTier {
Limits.Networks = 3 Limits.Networks = 3
} }
setControllerLimits()
logger.Log(0, "License validation succeeded!") logger.Log(0, "License validation succeeded!")
return nil return nil
} }
// FetchApiServerKeys - fetches netmaker license keys for identification
// as well as secure communication with API
// if none present, it generates a new pair
func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
var returnData = apiServerConf{}
currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key)
if err != nil && !database.IsEmptyRecord(err) {
return nil, nil, err
} else if database.IsEmptyRecord(err) { // need to generate a new identifier pair
pub, priv, err = box.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
pubBytes, err := ncutils.ConvertKeyToBytes(pub)
if err != nil {
return nil, nil, err
}
privBytes, err := ncutils.ConvertKeyToBytes(priv)
if err != nil {
return nil, nil, err
}
returnData.PrivateKey = privBytes
returnData.PublicKey = pubBytes
record, err := json.Marshal(&returnData)
if err != nil {
return nil, nil, err
}
if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil {
return nil, nil, err
}
} else {
if err = json.Unmarshal([]byte(currentData), &returnData); err != nil {
return nil, nil, err
}
priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey)
if err != nil {
return nil, nil, err
}
pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey)
if err != nil {
return nil, nil, err
}
}
return pub, priv, nil
}
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
decodedPubKey := base64decode(licensePubKeyEncoded) decodedPubKey := base64decode(licensePubKeyEncoded)
return ncutils.ConvertBytesToKey(decodedPubKey) return ncutils.ConvertBytesToKey(decodedPubKey)
@@ -179,32 +240,6 @@ func ClearLicenseCache() error {
return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key) return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key)
} }
// AddServerIDIfNotPresent - add's current server ID to DB if not present
func AddServerIDIfNotPresent() error {
currentNodeID := servercfg.GetNodeID()
currentServerIDs := serverIDs{}
record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key)
if err != nil && !database.IsEmptyRecord(err) {
return err
} else if err == nil {
if err = json.Unmarshal([]byte(record), &currentServerIDs); err != nil {
return err
}
}
if !logic.StringSliceContains(currentServerIDs.ServerIDs, currentNodeID) {
currentServerIDs.ServerIDs = append(currentServerIDs.ServerIDs, currentNodeID)
data, err := json.Marshal(&currentServerIDs)
if err != nil {
return err
}
return database.Insert(server_id_key, string(data), database.SERVERCONF_TABLE_NAME)
}
return nil
}
func getServerCount() int { func getServerCount() int {
if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil { if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil {
currentServerIDs := serverIDs{} currentServerIDs := serverIDs{}

View File

@@ -49,6 +49,6 @@ func getCurrentServerLimit() (limits LicenseLimits) {
if err == nil { if err == nil {
limits.Users = len(users) limits.Users = len(users)
} }
limits.Servers = getServerCount() limits.Servers = logic.GetServerCount()
return return
} }

View File

@@ -8,17 +8,6 @@ import (
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
// NetworkExists - check if network exists
func NetworkExists(name string) (bool, error) {
var network string
var err error
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err
}
return len(network) > 0, nil
}
// NameInDNSCharSet - name in dns char set // NameInDNSCharSet - name in dns char set
func NameInDNSCharSet(name string) bool { func NameInDNSCharSet(name string) bool {

View File

@@ -26,7 +26,7 @@ func TestNetworkExists(t *testing.T) {
} }
database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
defer database.CloseDB() defer database.CloseDB()
exists, err := NetworkExists(testNetwork.NetID) exists, err := logic.NetworkExists(testNetwork.NetID)
if err == nil { if err == nil {
t.Fatalf("expected error, received nil") t.Fatalf("expected error, received nil")
} }
@@ -38,7 +38,7 @@ func TestNetworkExists(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to save test network in databse: %s", err) t.Fatalf("failed to save test network in databse: %s", err)
} }
exists, err = NetworkExists(testNetwork.NetID) exists, err = logic.NetworkExists(testNetwork.NetID)
if err != nil { if err != nil {
t.Fatalf("expected nil, received err: %s", err) t.Fatalf("expected nil, received err: %s", err)
} }

View File

@@ -99,7 +99,7 @@ func CreateUser(user models.User) (models.User, error) {
tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin) tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin)
if tokenString == "" { if tokenString == "" {
// returnErrorResponse(w, r, errorResponse) // logic.ReturnErrorResponse(w, r, errorResponse)
return user, err return user, err
} }

View File

@@ -1,4 +1,4 @@
package controller package logic
import ( import (
"encoding/json" "encoding/json"
@@ -8,7 +8,8 @@ import (
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
func formatError(err error, errType string) models.ErrorResponse { // FormatError - takes ErrorResponse and uses correct code
func FormatError(err error, errType string) models.ErrorResponse {
var status = http.StatusInternalServerError var status = http.StatusInternalServerError
switch errType { switch errType {
@@ -33,7 +34,8 @@ func formatError(err error, errType string) models.ErrorResponse {
return response return response
} }
func returnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) { // ReturnSuccessResponse - processes message and adds header
func ReturnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) {
var httpResponse models.SuccessResponse var httpResponse models.SuccessResponse
httpResponse.Code = http.StatusOK httpResponse.Code = http.StatusOK
httpResponse.Message = message httpResponse.Message = message
@@ -42,7 +44,8 @@ func returnSuccessResponse(response http.ResponseWriter, request *http.Request,
json.NewEncoder(response).Encode(httpResponse) json.NewEncoder(response).Encode(httpResponse)
} }
func returnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) { // ReturnErrorResponse - processes error and adds header
func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) {
httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message} httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message}
jsonResponse, err := json.Marshal(httpResponse) jsonResponse, err := json.Marshal(httpResponse)
if err != nil { if err != nil {

View File

@@ -96,7 +96,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
err := ValidateNetwork(&network, false) err := ValidateNetwork(&network, false)
if err != nil { if err != nil {
//returnErrorResponse(w, r, formatError(err, "badrequest")) //logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return models.Network{}, err return models.Network{}, err
} }
@@ -656,6 +656,17 @@ func SaveNetwork(network *models.Network) error {
return nil return nil
} }
// NetworkExists - check if network exists
func NetworkExists(name string) (bool, error) {
var network string
var err error
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err
}
return len(network) > 0, nil
}
// == Private == // == Private ==
func networkNodesUpdateAction(networkName string, action string) error { func networkNodesUpdateAction(networkName string, action string) error {

View File

@@ -311,7 +311,7 @@ func CreateNode(node *models.Node) error {
//Create a JWT for the node //Create a JWT for the node
tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network) tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network)
if tokenString == "" { if tokenString == "" {
//returnErrorResponse(w, r, errorResponse) //logic.ReturnErrorResponse(w, r, errorResponse)
return err return err
} }
err = ValidateNode(node, false) err = ValidateNode(node, false)

View File

@@ -1,66 +0,0 @@
package pro
import (
"crypto/rand"
"encoding/json"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/netclient/ncutils"
"golang.org/x/crypto/nacl/box"
)
const (
db_license_key = "netmaker-id-key-pair"
)
type apiServerConf struct {
PrivateKey []byte `json:"private_key" binding:"required"`
PublicKey []byte `json:"public_key" binding:"required"`
}
// FetchApiServerKeys - fetches netmaker license keys for identification
// as well as secure communication with API
// if none present, it generates a new pair
func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
var returnData = apiServerConf{}
currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key)
if err != nil && !database.IsEmptyRecord(err) {
return nil, nil, err
} else if database.IsEmptyRecord(err) { // need to generate a new identifier pair
pub, priv, err = box.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
pubBytes, err := ncutils.ConvertKeyToBytes(pub)
if err != nil {
return nil, nil, err
}
privBytes, err := ncutils.ConvertKeyToBytes(priv)
if err != nil {
return nil, nil, err
}
returnData.PrivateKey = privBytes
returnData.PublicKey = pubBytes
record, err := json.Marshal(&returnData)
if err != nil {
return nil, nil, err
}
if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil {
return nil, nil, err
}
} else {
if err = json.Unmarshal([]byte(currentData), &returnData); err != nil {
return nil, nil, err
}
priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey)
if err != nil {
return nil, nil, err
}
pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey)
if err != nil {
return nil, nil, err
}
}
return pub, priv, nil
}

View File

@@ -58,7 +58,7 @@ func TestNetworkProSettings(t *testing.T) {
} }
AddProNetDefaults(&network) AddProNetDefaults(&network)
assert.NotNil(t, network.ProSettings) assert.NotNil(t, network.ProSettings)
assert.Nil(t, network.ProSettings.AllowedGroups) assert.Equal(t, len(network.ProSettings.AllowedGroups), 1)
assert.Nil(t, network.ProSettings.AllowedUsers) assert.Equal(t, len(network.ProSettings.AllowedUsers), 0)
}) })
} }

View File

@@ -1,4 +1,4 @@
package controller package logic
import ( import (
"encoding/json" "encoding/json"
@@ -7,8 +7,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/models/promodels"
@@ -16,16 +14,20 @@ import (
) )
const ( const (
// ALL_NETWORK_ACCESS - represents all networks
ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
master_uname = "masteradministrator" master_uname = "masteradministrator"
unauthorized_msg = "unauthorized" Unauthorized_Msg = "unauthorized"
unauthorized_err = models.Error(unauthorized_msg) Unauthorized_Err = models.Error(Unauthorized_Msg)
) )
func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { // SecurityCheck - Check if user has appropriate permissions
func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg, Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
} }
var params = mux.Vars(r) var params = mux.Vars(r)
@@ -44,14 +46,14 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
if len(networkName) == 0 { if len(networkName) == 0 {
networkName = params["network"] networkName = params["network"]
} }
networks, username, err := SecurityCheck(reqAdmin, networkName, bearerToken) networks, username, err := UserPermissions(reqAdmin, networkName, bearerToken)
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
networksJson, err := json.Marshal(&networks) networksJson, err := json.Marshal(&networks)
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
r.Header.Set("user", username) r.Header.Set("user", username)
@@ -60,7 +62,8 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
} }
} }
func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc { // NetUserSecurityCheck - Check if network user has appropriate permissions
func NetUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "unauthorized", Code: http.StatusUnauthorized, Message: "unauthorized",
@@ -77,7 +80,7 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
var authToken = "" var authToken = ""
if len(tokenSplit) < 2 { if len(tokenSplit) < 2 {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
authToken = tokenSplit[1] authToken = tokenSplit[1]
@@ -91,9 +94,9 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
return return
} }
userName, _, isadmin, err := logic.VerifyUserToken(authToken) userName, _, isadmin, err := VerifyUserToken(authToken)
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
r.Header.Set("user", userName) r.Header.Set("user", userName)
@@ -113,15 +116,15 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
} }
u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName)) u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName))
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
if u.AccessLevel > necessaryAccess { if u.AccessLevel > necessaryAccess {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
} else if netUserName != userName { } else if netUserName != userName {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
@@ -129,14 +132,14 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
} }
} }
// SecurityCheck - checks token stuff // UserPermissions - checks token stuff
func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) { func UserPermissions(reqAdmin bool, netname string, token string) ([]string, string, error) {
var tokenSplit = strings.Split(token, " ") var tokenSplit = strings.Split(token, " ")
var authToken = "" var authToken = ""
userNetworks := []string{} userNetworks := []string{}
if len(tokenSplit) < 2 { if len(tokenSplit) < 2 {
return userNetworks, "", unauthorized_err return userNetworks, "", Unauthorized_Err
} else { } else {
authToken = tokenSplit[1] authToken = tokenSplit[1]
} }
@@ -144,12 +147,12 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin
if authenticateMaster(authToken) { if authenticateMaster(authToken) {
return []string{ALL_NETWORK_ACCESS}, master_uname, nil return []string{ALL_NETWORK_ACCESS}, master_uname, nil
} }
username, networks, isadmin, err := logic.VerifyUserToken(authToken) username, networks, isadmin, err := VerifyUserToken(authToken)
if err != nil { if err != nil {
return nil, username, unauthorized_err return nil, username, Unauthorized_Err
} }
if !isadmin && reqAdmin { if !isadmin && reqAdmin {
return nil, username, unauthorized_err return nil, username, Unauthorized_Err
} }
userNetworks = networks userNetworks = networks
if isadmin { if isadmin {
@@ -157,10 +160,10 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin
} }
// check network admin access // check network admin access
if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) { if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) {
return nil, username, unauthorized_err return nil, username, Unauthorized_Err
} }
if !pro.IsUserNetAdmin(netname, username) { if !pro.IsUserNetAdmin(netname, username) {
return nil, "", unauthorized_err return nil, "", Unauthorized_Err
} }
return userNetworks, username, nil return userNetworks, username, nil
} }
@@ -171,11 +174,11 @@ func authenticateMaster(tokenString string) bool {
} }
func authenticateNetworkUser(network string, userNetworks []string) bool { func authenticateNetworkUser(network string, userNetworks []string) bool {
networkexists, err := functions.NetworkExists(network) networkexists, err := NetworkExists(network)
if (err != nil && !database.IsEmptyRecord(err)) || !networkexists { if (err != nil && !database.IsEmptyRecord(err)) || !networkexists {
return false return false
} }
return logic.StringSliceContains(userNetworks, network) return StringSliceContains(userNetworks, network)
} }
//Consider a more secure way of setting master key //Consider a more secure way of setting master key
@@ -187,15 +190,15 @@ func authenticateDNSToken(tokenString string) bool {
return tokens[1] == servercfg.GetDNSKey() return tokens[1] == servercfg.GetDNSKey()
} }
func continueIfUserMatch(next http.Handler) http.HandlerFunc { func ContinueIfUserMatch(next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg, Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
} }
var params = mux.Vars(r) var params = mux.Vars(r)
var requestedUser = params["username"] var requestedUser = params["username"]
if requestedUser != r.Header.Get("user") { if requestedUser != r.Header.Get("user") {
returnErrorResponse(w, r, errorResponse) ReturnErrorResponse(w, r, errorResponse)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View File

@@ -18,6 +18,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var EnterpriseCheckFuncs []interface{}
// == Join, Checkin, and Leave for Server == // == Join, Checkin, and Leave for Server ==
// KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range // KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range
@@ -164,6 +166,13 @@ func ServerJoin(networkSettings *models.Network) (models.Node, error) {
return *node, nil return *node, nil
} }
// EnterpriseCheck - Runs enterprise functions if presented
func EnterpriseCheck() {
for _, check := range EnterpriseCheckFuncs {
check.(func())()
}
}
// ServerUpdate - updates the server // ServerUpdate - updates the server
// replaces legacy Checkin code // replaces legacy Checkin code
func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error { func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error {

View File

@@ -6,6 +6,21 @@ import (
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
) )
var (
// Node_Limit - dummy var for community
Node_Limit = 1000000000
// Networks_Limit - dummy var for community
Networks_Limit = 1000000000
// Users_Limit - dummy var for community
Users_Limit = 1000000000
// Clients_Limit - dummy var for community
Clients_Limit = 1000000000
// Free_Tier - specifies if free tier
Free_Tier = false
// Is_EE - specifies if enterprise
Is_EE = false
)
// constant for database key for storing server ids // constant for database key for storing server ids
const server_id_key = "nm-server-id" const server_id_key = "nm-server-id"

16
main.go
View File

@@ -20,7 +20,6 @@ import (
"github.com/gravitl/netmaker/config" "github.com/gravitl/netmaker/config"
controller "github.com/gravitl/netmaker/controllers" controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
@@ -76,7 +75,7 @@ func initialize() { // Client Mode Prereq Check
logger.FatalLog("Error connecting to database") logger.FatalLog("Error connecting to database")
} }
logger.Log(0, "database successfully connected") logger.Log(0, "database successfully connected")
if err = ee.AddServerIDIfNotPresent(); err != nil { if err = logic.AddServerIDIfNotPresent(); err != nil {
logger.Log(1, "failed to save server ID") logger.Log(1, "failed to save server ID")
} }
@@ -91,18 +90,7 @@ func initialize() { // Client Mode Prereq Check
logger.Log(1, "Timer error occurred: ", err.Error()) logger.Log(1, "Timer error occurred: ", err.Error())
} }
if ee.IsEnterprise() { logic.EnterpriseCheck()
// == License Handling ==
ee.ValidateLicense()
if ee.Limits.FreeTier {
logger.Log(0, "proceeding with Free Tier license")
} else {
logger.Log(0, "proceeding with Paid Tier license")
}
// == End License Handling ==
ee.AddLicenseHooks()
}
var authProvider = auth.InitializeAuthProvider() var authProvider = auth.InitializeAuthProvider()
if authProvider != "" { if authProvider != "" {

View File

@@ -5,26 +5,8 @@ package main
import ( import (
"github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/models"
) )
func init() { func init() {
ee.SetIsEnterprise() ee.InitEE()
models.SetLogo(retrieveEELogo())
}
func retrieveEELogo() string {
return `
__ __ ______ ______ __ __ ______ __ __ ______ ______
/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \
\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __<
\ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\
\/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/
___ ___ ____
____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____
/___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/
/___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/
`
} }

View File

@@ -7,7 +7,6 @@ import (
mqtt "github.com/eclipse/paho.mqtt.golang" mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
@@ -99,7 +98,7 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
// UpdateMetrics message Handler -- handles updates from client nodes for metrics // UpdateMetrics message Handler -- handles updates from client nodes for metrics
func UpdateMetrics(client mqtt.Client, msg mqtt.Message) { func UpdateMetrics(client mqtt.Client, msg mqtt.Message) {
if ee.IsEnterprise() { if logic.Is_EE {
go func() { go func() {
id, err := getID(msg.Topic()) id, err := getID(msg.Topic())
if err != nil { if err != nil {

View File

@@ -6,10 +6,9 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/gravitl/netmaker/ee"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/metrics" "github.com/gravitl/netmaker/logic/metrics"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/serverctl" "github.com/gravitl/netmaker/serverctl"
@@ -185,7 +184,7 @@ func ServerStartNotify() error {
// function to collect and store metrics for server nodes // function to collect and store metrics for server nodes
func collectServerMetrics(networks []models.Network) { func collectServerMetrics(networks []models.Network) {
if !ee.IsEnterprise() { if !logic.Is_EE {
return return
} }
if len(networks) > 0 { if len(networks) > 0 {

View File

@@ -15,7 +15,7 @@ import (
"github.com/cloverstd/tcping/ping" "github.com/cloverstd/tcping/ping"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/pro/metrics" "github.com/gravitl/netmaker/logic/metrics"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/auth"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"