diff --git a/controllers/controller.go b/controllers/controller.go index f43897c3..015ffdb9 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -25,7 +25,6 @@ var HttpHandlers = []interface{}{ serverHandlers, extClientHandlers, ipHandlers, - metricHandlers, loggerHandlers, userGroupsHandlers, networkUsersHandlers, diff --git a/controllers/dns.go b/controllers/dns.go index c5ae5afa..f9e7a3e5 100644 --- a/controllers/dns.go +++ b/controllers/dns.go @@ -16,13 +16,13 @@ import ( func dnsHandlers(r *mux.Router) { - r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET") - r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET") - r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET") - r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET") - r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST") - r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST") - r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE") + r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET") + r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET") + r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET") + r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(false, http.HandlerFunc(getDNS))).Methods("GET") + r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(false, http.HandlerFunc(createDNS))).Methods("POST") + r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST") + r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE") } // swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS @@ -44,7 +44,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -68,7 +68,7 @@ func getAllDNS(w http.ResponseWriter, r *http.Request) { dns, err := logic.GetAllDNS() if err != nil { logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -98,7 +98,7 @@ func getCustomDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -128,7 +128,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -160,7 +160,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("invalid DNS entry %+v: %v", entry, err)) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -168,14 +168,14 @@ func createDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } err = logic.SetDNS() if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(1, "new DNS record added:", entry.Name) @@ -221,7 +221,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, "failed to delete dns entry: ", entrytext) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(1, "deleted dns entry: ", entrytext) @@ -229,7 +229,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } json.NewEncoder(w).Encode(entrytext + " deleted.") @@ -287,7 +287,7 @@ func pushDNS(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver") diff --git a/controllers/ext_client.go b/controllers/ext_client.go index 05cdad33..3ac9a05f 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -21,13 +21,13 @@ import ( func extClientHandlers(r *mux.Router) { - r.HandleFunc("/api/extclients", securityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET") - r.HandleFunc("/api/extclients/{network}", securityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET") - r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET") - r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", netUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET") - r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT") - r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE") - r.HandleFunc("/api/extclients/{network}/{nodeid}", netUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST") + r.HandleFunc("/api/extclients", logic.SecurityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET") + r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET") + r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET") + r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET") + r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT") + r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE") + r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.NetUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST") } func checkIngressExists(nodeID string) bool { @@ -62,7 +62,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -96,16 +96,16 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) { if marshalErr != nil { logger.Log(0, "error unmarshalling networks: ", marshalErr.Error()) - returnErrorResponse(w, r, formatError(marshalErr, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "internal")) return } clients := []models.ExtClient{} var err error - if networksSlice[0] == ALL_NETWORK_ACCESS { + if networksSlice[0] == logic.ALL_NETWORK_ACCESS { clients, err = functions.GetAllExtClients() if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, "failed to get all extclients: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } else { @@ -146,7 +146,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", clientid, network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -177,7 +177,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", clientid, networkid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -185,14 +185,14 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } network, err := logic.GetParentNetwork(client.Network) if err != nil { logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -258,7 +258,7 @@ Endpoint = %s bytes, err := qrcode.Encode(config, qrcode.Medium, 220) if err != nil { logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.Header().Set("Content-Type", "image/png") @@ -266,7 +266,7 @@ Endpoint = %s _, err = w.Write(bytes) if err != nil { logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } return @@ -280,7 +280,7 @@ Endpoint = %s _, err := fmt.Fprint(w, config) if err != nil { logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) } return } @@ -310,7 +310,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { err := errors.New("ingress does not exist") logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -329,7 +329,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10) @@ -345,7 +345,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -355,7 +355,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil { logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access") logic.DeleteExtClient(networkName, extclient.ClientID) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if !isAdmin { @@ -400,7 +400,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } clientid := params["clientid"] @@ -410,7 +410,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v", clientid, network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) @@ -418,13 +418,13 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v", key, clientid, network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil { logger.Log(0, "error unmarshalling extclient: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -435,7 +435,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { userID := r.Header.Get("user") _, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName) if !doesOwn { - returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal")) return } } @@ -457,7 +457,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update ext client [%s], network [%s]: %v", clientid, network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID) @@ -497,14 +497,14 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { err = errors.New("Could not delete extclient " + params["clientid"]) logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -513,7 +513,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"] _, doesOwn := doesUserOwnClient(userID, clientID, networkName) if !doesOwn { - returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal")) return } } @@ -531,7 +531,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) err = errors.New("Could not delete extclient " + params["clientid"]) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -542,7 +542,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), "Deleted extclient client", params["clientid"], "from network", params["network"]) - returnSuccessResponse(w, r, params["clientid"]+" deleted.") + logic.ReturnSuccessResponse(w, r, params["clientid"]+" deleted.") } func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) { diff --git a/controllers/limits.go b/controllers/limits.go index ddbff298..55ccb313 100644 --- a/controllers/limits.go +++ b/controllers/limits.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" ) @@ -23,32 +22,32 @@ func checkFreeTierLimits(limit_choice int, next http.Handler) http.HandlerFunc { Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks", } - if ee.Limits.FreeTier { // check that free tier limits not exceeded + if logic.Free_Tier && logic.Is_EE { // check that free tier limits not exceeded if limit_choice == networks_l { currentNetworks, err := logic.GetNetworks() - if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= ee.Limits.Networks { - returnErrorResponse(w, r, errorResponse) + if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= logic.Networks_Limit { + logic.ReturnErrorResponse(w, r, errorResponse) return } } else if limit_choice == node_l { nodes, err := logic.GetAllNodes() - if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= ee.Limits.Nodes { + if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= logic.Node_Limit { errorResponse.Message = "free tier limits exceeded on nodes" - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } } else if limit_choice == users_l { users, err := logic.GetUsers() - if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= ee.Limits.Users { + if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= logic.Users_Limit { errorResponse.Message = "free tier limits exceeded on users" - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } } else if limit_choice == clients_l { clients, err := logic.GetAllExtClients() - if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= ee.Limits.Clients { + if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= logic.Clients_Limit { errorResponse.Message = "free tier limits exceeded on external clients" - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } } diff --git a/controllers/logger.go b/controllers/logger.go index 316783fc..09d740d1 100644 --- a/controllers/logger.go +++ b/controllers/logger.go @@ -7,10 +7,11 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" ) func loggerHandlers(r *mux.Router) { - r.HandleFunc("/api/logs", securityCheck(true, http.HandlerFunc(getLogs))).Methods("GET") + r.HandleFunc("/api/logs", logic.SecurityCheck(true, http.HandlerFunc(getLogs))).Methods("GET") } func getLogs(w http.ResponseWriter, r *http.Request) { diff --git a/controllers/network.go b/controllers/network.go index 7c9c71f1..6a5f3f99 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -17,26 +17,20 @@ import ( "github.com/gravitl/netmaker/servercfg" ) -// ALL_NETWORK_ACCESS - represents all networks -const ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL" - -// NO_NETWORKS_PRESENT - represents no networks -const NO_NETWORKS_PRESENT = "THIS_USER_HAS_NONE" - func networkHandlers(r *mux.Router) { - r.HandleFunc("/api/networks", securityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET") - r.HandleFunc("/api/networks", securityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST") - r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET") - r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT") - r.HandleFunc("/api/networks/{networkname}/nodelimit", securityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT") - r.HandleFunc("/api/networks/{networkname}", securityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE") - r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST") - r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST") - r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET") - r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE") + r.HandleFunc("/api/networks", logic.SecurityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET") + r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST") + r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET") + r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT") + r.HandleFunc("/api/networks/{networkname}/nodelimit", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT") + r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE") + r.HandleFunc("/api/networks/{networkname}/keyupdate", logic.SecurityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST") + r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST") + r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET") + r.HandleFunc("/api/networks/{networkname}/keys/{name}", logic.SecurityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE") // ACLs - r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT") - r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET") + r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT") + r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET") } // swagger:route GET /api/networks networks getNetworks @@ -58,16 +52,16 @@ func getNetworks(w http.ResponseWriter, r *http.Request) { if marshalErr != nil { logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ", marshalErr.Error()) - returnErrorResponse(w, r, formatError(marshalErr, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "badrequest")) return } allnetworks := []models.Network{} var err error - if networksSlice[0] == ALL_NETWORK_ACCESS { + if networksSlice[0] == logic.ALL_NETWORK_ACCESS { allnetworks, err = logic.GetNetworks() if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } else { @@ -110,7 +104,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v", netname, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if !servercfg.IsDisplayKeys() { @@ -140,7 +134,7 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v", netname, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(2, r.Header.Get("user"), "updated key on network", netname) @@ -182,7 +176,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "failed to get network info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } var newNetwork models.Network @@ -190,7 +184,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -203,7 +197,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "failed to update network: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -231,7 +225,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -241,7 +235,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -251,7 +245,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update network [%s] local addresses: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -261,7 +255,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update network [%s] hole punching: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -271,7 +265,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get network [%s] nodes: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } for _, node := range nodes { @@ -305,7 +299,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get network [%s] nodes: %v", network.NetID, err.Error())) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -315,7 +309,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if networkChange.NodeLimit != 0 { @@ -324,7 +318,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error marshalling resp: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME) @@ -354,21 +348,21 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } err = json.NewDecoder(r.Body).Decode(&networkACLChange) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } newNetACL, err := networkACLChange.Save(acls.ContainerID(netname)) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err)) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname) @@ -412,7 +406,7 @@ func getNetworkACL(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname) @@ -445,7 +439,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) { } logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete network [%s]: %v", network, err)) - returnErrorResponse(w, r, formatError(err, errtype)) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype)) return } logger.Log(1, r.Header.Get("user"), "deleted network", network) @@ -475,7 +469,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -483,7 +477,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { err := errors.New("IPv4 or IPv6 CIDR required") logger.Log(0, r.Header.Get("user"), "failed to create network: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -491,7 +485,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "failed to create network: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -504,7 +498,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { } logger.Log(0, r.Header.Get("user"), "failed to create network: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -537,28 +531,28 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "failed to get network info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } err = json.NewDecoder(r.Body).Decode(&accesskey) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } key, err := logic.CreateAccessKey(accesskey, network) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to create access key: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } // do not allow access key creations view API with user names if _, err = logic.GetUser(key.Name); err == nil { logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user")) - returnErrorResponse(w, r, formatError(fmt.Errorf("cannot create access key with user name"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("cannot create access key with user name"), "badrequest")) logic.DeleteKey(key.Name, network.NetID) return } @@ -587,7 +581,7 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v", network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if !servercfg.IsDisplayKeys() { @@ -621,7 +615,7 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v", keyname, netname, err)) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname) diff --git a/controllers/network_test.go b/controllers/network_test.go index a85b03e3..03b11759 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -182,24 +182,24 @@ func TestSecurityCheck(t *testing.T) { initialize() os.Setenv("MASTER_KEY", "secretkey") t.Run("NoNetwork", func(t *testing.T) { - networks, username, err := SecurityCheck(false, "", "Bearer secretkey") + networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey") assert.Nil(t, err) t.Log(networks, username) }) t.Run("WithNetwork", func(t *testing.T) { - networks, username, err := SecurityCheck(false, "skynet", "Bearer secretkey") + networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey") assert.Nil(t, err) t.Log(networks, username) }) t.Run("BadNet", func(t *testing.T) { t.Skip() - networks, username, err := SecurityCheck(false, "badnet", "Bearer secretkey") + networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey") assert.NotNil(t, err) t.Log(err) t.Log(networks, username) }) t.Run("BadToken", func(t *testing.T) { - networks, username, err := SecurityCheck(false, "skynet", "Bearer badkey") + networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey") assert.NotNil(t, err) t.Log(err) t.Log(networks, username) diff --git a/controllers/networkusers.go b/controllers/networkusers.go index 861ba0f4..8ac4e24f 100644 --- a/controllers/networkusers.go +++ b/controllers/networkusers.go @@ -14,13 +14,13 @@ import ( ) func networkUsersHandlers(r *mux.Router) { - r.HandleFunc("/api/networkusers", securityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET") - r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET") - r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET") - r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST") - r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT") - r.HandleFunc("/api/networkusers/data/{networkuser}/me", netUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET") - r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE") + r.HandleFunc("/api/networkusers", logic.SecurityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST") + r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT") + r.HandleFunc("/api/networkusers/data/{networkuser}/me", logic.NetUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE") } // == RETURN TYPES == @@ -52,18 +52,18 @@ func getNetworkUserData(w http.ResponseWriter, r *http.Request) { networks, err := logic.GetNetworks() if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if networkUserName == "" { - returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest")) return } u, err := logic.GetUser(networkUserName) if err != nil { - returnErrorResponse(w, r, formatError(errors.New("could not find user"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("could not find user"), "badrequest")) return } @@ -151,7 +151,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) { networks, err := logic.GetNetworks() if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -160,7 +160,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) { for i := range networks { netusers, err := pro.GetNetworkUsers(networks[i].NetID) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } for _, v := range netusers { @@ -181,13 +181,13 @@ func getNetworkUsers(w http.ResponseWriter, r *http.Request) { _, err := logic.GetNetwork(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } netusers, err := pro.GetNetworkUsers(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -203,19 +203,19 @@ func getNetworkUser(w http.ResponseWriter, r *http.Request) { _, err := logic.GetNetwork(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } netuserToGet := params["networkuser"] if netuserToGet == "" { - returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest")) return } netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet)) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } w.WriteHeader(http.StatusOK) @@ -230,7 +230,7 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) { network, err := logic.GetNetwork(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } var networkuser promodels.NetworkUser @@ -238,13 +238,13 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) { // we decode our body request params err = json.NewDecoder(r.Body).Decode(&networkuser) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } err = pro.CreateNetworkUser(&network, &networkuser) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -260,7 +260,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) { network, err := logic.GetNetwork(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } var networkuser promodels.NetworkUser @@ -268,38 +268,38 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) { // we decode our body request params err = json.NewDecoder(r.Body).Decode(&networkuser) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) { - returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) return } if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS { - returnErrorResponse(w, r, formatError(errors.New("invalid user access level provided"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user access level provided"), "badrequest")) return } if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 { - returnErrorResponse(w, r, formatError(errors.New("negative user limit provided"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("negative user limit provided"), "badrequest")) return } u, err := logic.GetUser(string(networkuser.ID)) if err != nil { - returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) return } if !pro.IsUserAllowed(&network, u.UserName, u.Groups) { - returnErrorResponse(w, r, formatError(errors.New("user must be in allowed groups or users"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user must be in allowed groups or users"), "badrequest")) return } if networkuser.AccessLevel == pro.NET_ADMIN { currentUser, err := logic.GetUser(string(networkuser.ID)) if err != nil { - returnErrorResponse(w, r, formatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest")) return } @@ -316,7 +316,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) { UserName: currentUser.UserName, }, ); err != nil { - returnErrorResponse(w, r, formatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest")) return } } @@ -324,7 +324,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) { err = pro.UpdateNetworkUser(netname, &networkuser) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -340,18 +340,18 @@ func deleteNetworkUser(w http.ResponseWriter, r *http.Request) { _, err := logic.GetNetwork(netname) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } netuserToDelete := params["networkuser"] if netuserToDelete == "" { - returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest")) return } if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } diff --git a/controllers/node.go b/controllers/node.go index 14075dbe..efc75d62 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -8,7 +8,6 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/pro" @@ -30,8 +29,8 @@ func nodeHandlers(r *mux.Router) { r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE") - r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") - r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") + r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") + r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST") r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST") r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET") @@ -66,19 +65,19 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Message = decoderErr.Error() logger.Log(0, request.Header.Get("user"), "error decoding request body: ", decoderErr.Error()) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } else { errorResponse.Code = http.StatusBadRequest if authRequest.ID == "" { errorResponse.Message = "W1R3: ID can't be empty" logger.Log(0, request.Header.Get("user"), errorResponse.Message) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } else if authRequest.Password == "" { errorResponse.Message = "W1R3: Password can't be empty" logger.Log(0, request.Header.Get("user"), errorResponse.Message) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } else { var err error @@ -89,7 +88,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Message = err.Error() logger.Log(0, request.Header.Get("user"), fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err)) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } @@ -99,7 +98,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Message = err.Error() logger.Log(0, request.Header.Get("user"), "error validating user password: ", err.Error()) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } else { tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network) @@ -109,7 +108,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Message = "Could not create Token" logger.Log(0, request.Header.Get("user"), fmt.Sprintf("%s: %v", errorResponse.Message, err)) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } @@ -128,7 +127,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) { errorResponse.Message = err.Error() logger.Log(0, request.Header.Get("user"), "error marshalling resp: ", err.Error()) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } response.WriteHeader(http.StatusOK) @@ -149,7 +148,7 @@ func nodeauth(next http.Handler) http.HandlerFunc { errorResponse := models.ErrorResponse{ Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } else { token = tokenSplit[1] @@ -161,7 +160,7 @@ func nodeauth(next http.Handler) http.HandlerFunc { errorResponse := models.ErrorResponse{ Code: http.StatusNotFound, Message: "no networks", } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } for _, network := range networks { @@ -177,7 +176,7 @@ func nodeauth(next http.Handler) http.HandlerFunc { errorResponse := models.ErrorResponse{ Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.", } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } next.ServeHTTP(w, r) @@ -194,16 +193,16 @@ func nodeauth(next http.Handler) http.HandlerFunc { func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var errorResponse = models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: unauthorized_msg, + Code: http.StatusUnauthorized, Message: logic.Unauthorized_Msg, } var params = mux.Vars(r) - networkexists, _ := functions.NetworkExists(params["network"]) + networkexists, _ := logic.NetworkExists(params["network"]) //check that the request is for a valid network //if (networkCheck && !networkexists) || err != nil { if networkCheck && !networkexists { - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } else { w.Header().Set("Content-Type", "application/json") @@ -220,7 +219,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha if len(tokenSplit) > 1 { authToken = tokenSplit[1] } else { - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } //check if node instead of user @@ -236,7 +235,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha var nodeID = "" username, networks, isadmin, errN := logic.VerifyUserToken(authToken) if errN != nil { - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } @@ -269,7 +268,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha } else { node, err := logic.GetNodeByID(nodeID) if err != nil { - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } isAuthorized = (node.Network == params["network"]) @@ -287,7 +286,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha } } if !isAuthorized { - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } else { //If authorized, this function passes along it's request and output to the appropriate route function. @@ -324,7 +323,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -358,7 +357,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) { if err != nil && r.Header.Get("ismasterkey") != "yes" { logger.Log(0, r.Header.Get("user"), "error fetching user info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } var nodes []models.Node @@ -366,7 +365,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) { nodes, err = logic.GetAllNodes() if err != nil { logger.Log(0, "error fetching all nodes info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } else { @@ -374,7 +373,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), "error fetching nodes: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -418,7 +417,7 @@ func getNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -426,7 +425,7 @@ func getNode(w http.ResponseWriter, r *http.Request) { if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -470,7 +469,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching network [%s] info: %v", networkName, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(2, r.Header.Get("user"), "called last modified") @@ -498,12 +497,12 @@ func createNode(w http.ResponseWriter, r *http.Request) { Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", } networkName := params["network"] - networkexists, err := functions.NetworkExists(networkName) + networkexists, err := logic.NetworkExists(networkName) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } else if !networkexists { errorResponse = models.ErrorResponse{ @@ -511,7 +510,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { } logger.Log(0, r.Header.Get("user"), fmt.Sprintf("network [%s] does not exist", networkName)) - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } @@ -521,7 +520,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { err = json.NewDecoder(r.Body).Decode(&node) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -531,14 +530,14 @@ func createNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } node.NetworkSettings, err = logic.GetNetworkSettings(node.Network) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey) @@ -554,7 +553,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create node on network [%s]: %s", node.Network, errorResponse.Message)) - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } } @@ -569,17 +568,17 @@ func createNode(w http.ResponseWriter, r *http.Request) { key, keyErr := logic.RetrievePublicTrafficKey() if keyErr != nil { logger.Log(0, "error retrieving key: ", keyErr.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if key == nil { logger.Log(0, "error: server traffic key is nil") - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if node.TrafficKeys.Mine == nil { logger.Log(0, "error: node traffic key is nil") - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } node.TrafficKeys = models.TrafficKeys{ @@ -592,7 +591,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create node on network [%s]: %s", node.Network, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -609,7 +608,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { if !updatedUserNode { // user was found but not updated, so delete node logger.Log(0, "failed to add node to user", keyName) logic.DeleteNodeByID(&node, true) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } @@ -618,7 +617,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -656,7 +655,7 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name) @@ -686,7 +685,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&gateway) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } gateway.NetID = params["network"] @@ -696,7 +695,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v", gateway.NodeID, gateway.NetID, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -728,7 +727,7 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v", nodeid, netid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -762,7 +761,7 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v", nodeid, netid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -794,7 +793,7 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v", nodeid, netid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -828,7 +827,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -837,7 +836,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) { err = json.NewDecoder(r.Body).Decode(&newNode) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } relayupdate := false @@ -885,7 +884,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if relayupdate { @@ -932,20 +931,20 @@ func deleteNode(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if isServer(&node) { err := fmt.Errorf("cannot delete server node") logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err)) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if r.Header.Get("ismaster") != "yes" { username := r.Header.Get("user") if username != "" && !doesUserOwnNode(username, params["network"], nodeid) { - returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "badrequest")) return } } @@ -954,11 +953,11 @@ func deleteNode(w http.ResponseWriter, r *http.Request) { err = logic.DeleteNodeByID(&node, false) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - returnSuccessResponse(w, r, nodeid+" deleted.") + logic.ReturnSuccessResponse(w, r, nodeid+" deleted.") logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"]) runUpdates(&node, false) diff --git a/controllers/relay.go b/controllers/relay.go index 28b95506..9f724e36 100644 --- a/controllers/relay.go +++ b/controllers/relay.go @@ -30,7 +30,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&relay) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } relay.NetID = params["network"] @@ -39,7 +39,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err)) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID) @@ -73,7 +73,7 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) { updatenodes, node, err := logic.DeleteRelay(netid, nodeid) if err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid) diff --git a/controllers/response_test.go b/controllers/response_test.go index b3948b4a..f2d8fd4d 100644 --- a/controllers/response_test.go +++ b/controllers/response_test.go @@ -7,12 +7,13 @@ import ( "net/http/httptest" "testing" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" ) func TestFormatError(t *testing.T) { - response := formatError(errors.New("this is a sample error"), "badrequest") + response := logic.FormatError(errors.New("this is a sample error"), "badrequest") assert.Equal(t, http.StatusBadRequest, response.Code) assert.Equal(t, "this is a sample error", response.Message) } @@ -20,7 +21,7 @@ func TestFormatError(t *testing.T) { func TestReturnSuccessResponse(t *testing.T) { var response models.SuccessResponse handler := func(rw http.ResponseWriter, r *http.Request) { - returnSuccessResponse(rw, r, "This is a test message") + logic.ReturnSuccessResponse(rw, r, "This is a test message") } req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) w := httptest.NewRecorder() @@ -42,7 +43,7 @@ func TestReturnErrorResponse(t *testing.T) { errMessage.Code = http.StatusUnauthorized errMessage.Message = "You are not authorized to access this endpoint" handler := func(rw http.ResponseWriter, r *http.Request) { - returnErrorResponse(rw, r, errMessage) + logic.ReturnErrorResponse(rw, r, errMessage) } req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) w := httptest.NewRecorder() diff --git a/controllers/server.go b/controllers/server.go index 94e2d366..3c72198c 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/gorilla/mux" - "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -22,82 +21,35 @@ import ( func serverHandlers(r *mux.Router) { // r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST") - r.HandleFunc("/api/server/getconfig", securityCheckServer(false, http.HandlerFunc(getConfig))).Methods("GET") - r.HandleFunc("/api/server/removenetwork/{network}", securityCheckServer(true, http.HandlerFunc(removeNetwork))).Methods("DELETE") + r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))).Methods("GET") r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST") r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET") } -//Security check is middleware for every function and just checks to make sure that its the master calling -//Only admin should have access to all these network-level actions -//or maybe some Users once implemented -func securityCheckServer(adminonly bool, next http.Handler) http.HandlerFunc { +// allowUsers - allow all authenticated (valid) users - only used by getConfig, may be able to remove during refactor +func allowUsers(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var errorResponse = models.ErrorResponse{ - Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", + Code: http.StatusInternalServerError, Message: logic.Unauthorized_Msg, } - bearerToken := r.Header.Get("Authorization") - var tokenSplit = strings.Split(bearerToken, " ") var authToken = "" if len(tokenSplit) < 2 { - errorResponse = models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", - } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } else { authToken = tokenSplit[1] } - //all endpoints here require master so not as complicated - //still might not be a good way of doing this - user, _, isadmin, err := logic.VerifyUserToken(authToken) - errorResponse = models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", - } - if !adminonly && (err != nil || user == "") { - returnErrorResponse(w, r, errorResponse) - return - } - if adminonly && !isadmin && !authenticateMaster(authToken) { - returnErrorResponse(w, r, errorResponse) + user, _, _, err := logic.VerifyUserToken(authToken) + if err != nil || user == "" { + logic.ReturnErrorResponse(w, r, errorResponse) return } next.ServeHTTP(w, r) } } -// swagger:route DELETE /api/server/removenetwork/{network} nodes removeNetwork -// -// Remove a network from the server. -// -// Schemes: https -// -// Security: -// oauth -// -// Responses: -// 200: stringJSONResponse -func removeNetwork(w http.ResponseWriter, r *http.Request) { - // Set header - w.Header().Set("Content-Type", "application/json") - - // get params - var params = mux.Vars(r) - network := params["network"] - err := logic.DeleteNetwork(network) - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("failed to delete network [%s]: %v", network, err)) - json.NewEncoder(w).Encode(fmt.Sprintf("could not remove network %s from server", network)) - return - } - logger.Log(1, r.Header.Get("user"), - fmt.Sprintf("deleted network [%s]: %v", network, err)) - json.NewEncoder(w).Encode(fmt.Sprintf("network %s removed from server", network)) -} - // swagger:route GET /api/server/getserverinfo nodes getServerInfo // // Get the server configuration. @@ -138,7 +90,7 @@ func getConfig(w http.ResponseWriter, r *http.Request) { scfg := servercfg.GetServerConfig() scfg.IsEE = "no" - if ee.IsEnterprise() { + if logic.Is_EE { scfg.IsEE = "yes" } json.NewEncoder(w).Encode(scfg) @@ -166,7 +118,7 @@ func register(w http.ResponseWriter, r *http.Request) { errorResponse := models.ErrorResponse{ Code: http.StatusBadRequest, Message: err.Error(), } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } cert, ca, err := genCerts(&request.Key, &request.CommonName) @@ -175,7 +127,7 @@ func register(w http.ResponseWriter, r *http.Request) { errorResponse := models.ErrorResponse{ Code: http.StatusNotFound, Message: err.Error(), } - returnErrorResponse(w, r, errorResponse) + logic.ReturnErrorResponse(w, r, errorResponse) return } //x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte diff --git a/controllers/user.go b/controllers/user.go index 9c6dbb57..f53ab2f1 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -25,13 +25,13 @@ func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET") r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST") r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST") - r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT") - r.HandleFunc("/api/users/networks/{username}", securityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT") - r.HandleFunc("/api/users/{username}/adm", securityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT") - r.HandleFunc("/api/users/{username}", securityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST") - r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE") - r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET") - r.HandleFunc("/api/users", securityCheck(true, http.HandlerFunc(getUsers))).Methods("GET") + r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT") + r.HandleFunc("/api/users/networks/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT") + r.HandleFunc("/api/users/{username}/adm", logic.SecurityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT") + r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST") + r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE") + r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET") + r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods("GET") r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") r.HandleFunc("/api/oauth/node-handler", socketHandler) @@ -59,7 +59,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { } if !servercfg.IsBasicAuthEnabled() { - returnErrorResponse(response, request, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) + logic.ReturnErrorResponse(response, request, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest")) return } @@ -69,7 +69,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { if decoderErr != nil { logger.Log(0, "error decoding request body: ", decoderErr.Error()) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } username := authRequest.UserName @@ -77,14 +77,14 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { if err != nil { logger.Log(0, username, "user validation failed: ", err.Error()) - returnErrorResponse(response, request, formatError(err, "badrequest")) + logic.ReturnErrorResponse(response, request, logic.FormatError(err, "badrequest")) return } if jwt == "" { // very unlikely that err is !nil and no jwt returned, but handle it anyways. logger.Log(0, username, "jwt token is empty") - returnErrorResponse(response, request, formatError(errors.New("no token returned"), "internal")) + logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("no token returned"), "internal")) return } @@ -102,7 +102,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { if jsonError != nil { logger.Log(0, username, "error marshalling resp: ", err.Error()) - returnErrorResponse(response, request, errorResponse) + logic.ReturnErrorResponse(response, request, errorResponse) return } logger.Log(2, username, "was authenticated") @@ -128,7 +128,7 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) { hasadmin, err := logic.HasAdmin() if err != nil { logger.Log(0, "failed to check for admin: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -171,7 +171,7 @@ func getUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched) @@ -197,7 +197,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, "failed to fetch users: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -226,12 +226,12 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { logger.Log(0, admin.UserName, "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if !servercfg.IsBasicAuthEnabled() { - returnErrorResponse(w, r, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest")) return } @@ -239,7 +239,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, admin.UserName, "failed to create admin: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -266,7 +266,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, user.UserName, "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -274,7 +274,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, user.UserName, "error creating new user: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, user.UserName, "was created") @@ -302,7 +302,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "failed to update user networks: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } var userchange models.User @@ -311,7 +311,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{ @@ -324,7 +324,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "failed to update user networks: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, username, "status was updated") @@ -352,13 +352,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "failed to update user info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if auth.IsOauthUser(&user) == nil { err := fmt.Errorf("cannot update user info for oauth user %s", username) logger.Log(0, err.Error()) - returnErrorResponse(w, r, formatError(err, "forbidden")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden")) return } var userchange models.User @@ -367,7 +367,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } userchange.Networks = nil @@ -375,7 +375,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "failed to update user info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, username, "was updated") @@ -401,13 +401,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { username := params["username"] user, err := GetUserInternal(username) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } if auth.IsOauthUser(&user) != nil { err := fmt.Errorf("cannot update user info for oauth user %s", username) logger.Log(0, err.Error()) - returnErrorResponse(w, r, formatError(err, "forbidden")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden")) return } var userchange models.User @@ -416,18 +416,18 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "error decoding request body: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if !user.IsAdmin { logger.Log(0, username, "not an admin user") - returnErrorResponse(w, r, formatError(errors.New("not a admin user"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not a admin user"), "badrequest")) } user, err = logic.UpdateUser(userchange, user) if err != nil { logger.Log(0, username, "failed to update user (admin) info: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } logger.Log(1, username, "was updated (admin)") @@ -458,12 +458,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(0, username, "failed to delete user: ", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } else if !success { err := errors.New("delete unsuccessful") logger.Log(0, username, err.Error()) - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } diff --git a/controllers/usergroups.go b/controllers/usergroups.go index a73dc1f8..4ade6f29 100644 --- a/controllers/usergroups.go +++ b/controllers/usergroups.go @@ -3,18 +3,20 @@ package controller import ( "encoding/json" "errors" - "github.com/gravitl/netmaker/logger" "net/http" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gorilla/mux" "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models/promodels" ) func userGroupsHandlers(r *mux.Router) { - r.HandleFunc("/api/usergroups", securityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET") - r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST") - r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE") + r.HandleFunc("/api/usergroups", logic.SecurityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET") + r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST") + r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE") } func getUserGroups(w http.ResponseWriter, r *http.Request) { @@ -23,7 +25,7 @@ func getUserGroups(w http.ResponseWriter, r *http.Request) { userGroups, err := pro.GetUserGroups() if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } // Returns all the groups in JSON format @@ -39,13 +41,13 @@ func createUserGroup(w http.ResponseWriter, r *http.Request) { logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup) if newGroup == "" { - returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest")) return } err := pro.InsertUserGroup(promodels.UserGroupName(newGroup)) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -58,12 +60,12 @@ func deleteUserGroup(w http.ResponseWriter, r *http.Request) { logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete) if groupToDelete == "" { - returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest")) return } if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } diff --git a/controllers/metrics.go b/ee/ee_controllers/metrics.go similarity index 81% rename from controllers/metrics.go rename to ee/ee_controllers/metrics.go index 1c08350f..409682a2 100644 --- a/controllers/metrics.go +++ b/ee/ee_controllers/metrics.go @@ -1,4 +1,4 @@ -package controller +package ee_controllers import ( "encoding/json" @@ -10,10 +10,11 @@ import ( "github.com/gravitl/netmaker/models" ) -func metricHandlers(r *mux.Router) { - r.HandleFunc("/api/metrics/{network}/{nodeid}", securityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET") - r.HandleFunc("/api/metrics/{network}", securityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET") - r.HandleFunc("/api/metrics", securityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET") +// MetricHandlers - How we handle EE Metrics +func MetricHandlers(r *mux.Router) { + r.HandleFunc("/api/metrics/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET") + r.HandleFunc("/api/metrics/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET") + r.HandleFunc("/api/metrics", logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET") } // get the metrics of a given node @@ -28,7 +29,7 @@ func getNodeMetrics(w http.ResponseWriter, r *http.Request) { metrics, err := logic.GetMetrics(nodeID) if err != nil { logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -49,7 +50,7 @@ func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) { networkNodes, err := logic.GetNetworkNodes(network) if err != nil { logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -79,7 +80,7 @@ func getAllMetrics(w http.ResponseWriter, r *http.Request) { allNodes, err := logic.GetAllNodes() if err != nil { logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error()) - returnErrorResponse(w, r, formatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } diff --git a/ee/initialize.go b/ee/initialize.go new file mode 100644 index 00000000..665e7729 --- /dev/null +++ b/ee/initialize.go @@ -0,0 +1,54 @@ +//go:build ee +// +build ee + +package ee + +import ( + controller "github.com/gravitl/netmaker/controllers" + "github.com/gravitl/netmaker/ee/ee_controllers" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" +) + +// InitEE - Initialize EE Logic +func InitEE() { + SetIsEnterprise() + models.SetLogo(retrieveEELogo()) + controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers) + logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { + // == License Handling == + ValidateLicense() + if Limits.FreeTier { + logger.Log(0, "proceeding with Free Tier license") + } else { + logger.Log(0, "proceeding with Paid Tier license") + } + // == End License Handling == + AddLicenseHooks() + }) +} + +func setControllerLimits() { + logic.Node_Limit = Limits.Nodes + logic.Users_Limit = Limits.Users + logic.Clients_Limit = Limits.Clients + logic.Free_Tier = Limits.FreeTier + logic.Is_EE = true +} + +func retrieveEELogo() string { + return ` + __ __ ______ ______ __ __ ______ __ __ ______ ______ +/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \ +\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __< + \ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\ + \/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/ + + ___ ___ ____ + ____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____ + /___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/ + /___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/ + +` +} diff --git a/ee/license.go b/ee/license.go index f34e50f8..ff4d291c 100644 --- a/ee/license.go +++ b/ee/license.go @@ -1,7 +1,11 @@ +//go:build ee +// +build ee + package ee import ( "bytes" + "crypto/rand" "encoding/json" "fmt" "io/ioutil" @@ -11,11 +15,20 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" + "golang.org/x/crypto/nacl/box" ) +const ( + db_license_key = "netmaker-id-key-pair" +) + +type apiServerConf struct { + PrivateKey []byte `json:"private_key" binding:"required"` + PublicKey []byte `json:"public_key" binding:"required"` +} + // AddLicenseHooks - adds the validation and cache clear hooks func AddLicenseHooks() { logic.AddHook(ValidateLicense) @@ -39,7 +52,7 @@ func ValidateLicense() error { logger.FatalLog(errValidation.Error()) } - tempPubKey, tempPrivKey, err := pro.FetchApiServerKeys() + tempPubKey, tempPrivKey, err := FetchApiServerKeys() if err != nil { logger.FatalLog(errValidation.Error()) } @@ -88,11 +101,59 @@ func ValidateLicense() error { if Limits.FreeTier { Limits.Networks = 3 } + setControllerLimits() logger.Log(0, "License validation succeeded!") return nil } +// FetchApiServerKeys - fetches netmaker license keys for identification +// as well as secure communication with API +// if none present, it generates a new pair +func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { + var returnData = apiServerConf{} + currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key) + if err != nil && !database.IsEmptyRecord(err) { + return nil, nil, err + } else if database.IsEmptyRecord(err) { // need to generate a new identifier pair + pub, priv, err = box.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + pubBytes, err := ncutils.ConvertKeyToBytes(pub) + if err != nil { + return nil, nil, err + } + privBytes, err := ncutils.ConvertKeyToBytes(priv) + if err != nil { + return nil, nil, err + } + returnData.PrivateKey = privBytes + returnData.PublicKey = pubBytes + record, err := json.Marshal(&returnData) + if err != nil { + return nil, nil, err + } + if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil { + return nil, nil, err + } + } else { + if err = json.Unmarshal([]byte(currentData), &returnData); err != nil { + return nil, nil, err + } + priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey) + if err != nil { + return nil, nil, err + } + pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey) + if err != nil { + return nil, nil, err + } + } + + return pub, priv, nil +} + func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { decodedPubKey := base64decode(licensePubKeyEncoded) return ncutils.ConvertBytesToKey(decodedPubKey) @@ -179,32 +240,6 @@ func ClearLicenseCache() error { return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key) } -// AddServerIDIfNotPresent - add's current server ID to DB if not present -func AddServerIDIfNotPresent() error { - currentNodeID := servercfg.GetNodeID() - currentServerIDs := serverIDs{} - - record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key) - if err != nil && !database.IsEmptyRecord(err) { - return err - } else if err == nil { - if err = json.Unmarshal([]byte(record), ¤tServerIDs); err != nil { - return err - } - } - - if !logic.StringSliceContains(currentServerIDs.ServerIDs, currentNodeID) { - currentServerIDs.ServerIDs = append(currentServerIDs.ServerIDs, currentNodeID) - data, err := json.Marshal(¤tServerIDs) - if err != nil { - return err - } - return database.Insert(server_id_key, string(data), database.SERVERCONF_TABLE_NAME) - } - - return nil -} - func getServerCount() int { if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil { currentServerIDs := serverIDs{} diff --git a/ee/util.go b/ee/util.go index 26e6262f..90705a0b 100644 --- a/ee/util.go +++ b/ee/util.go @@ -49,6 +49,6 @@ func getCurrentServerLimit() (limits LicenseLimits) { if err == nil { limits.Users = len(users) } - limits.Servers = getServerCount() + limits.Servers = logic.GetServerCount() return } diff --git a/functions/helpers.go b/functions/helpers.go index e455d088..d45d3292 100644 --- a/functions/helpers.go +++ b/functions/helpers.go @@ -8,17 +8,6 @@ import ( "github.com/gravitl/netmaker/models" ) -// NetworkExists - check if network exists -func NetworkExists(name string) (bool, error) { - - var network string - var err error - if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil { - return false, err - } - return len(network) > 0, nil -} - // NameInDNSCharSet - name in dns char set func NameInDNSCharSet(name string) bool { diff --git a/functions/helpers_test.go b/functions/helpers_test.go index 601747c3..e2737f48 100644 --- a/functions/helpers_test.go +++ b/functions/helpers_test.go @@ -26,7 +26,7 @@ func TestNetworkExists(t *testing.T) { } database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) defer database.CloseDB() - exists, err := NetworkExists(testNetwork.NetID) + exists, err := logic.NetworkExists(testNetwork.NetID) if err == nil { t.Fatalf("expected error, received nil") } @@ -38,7 +38,7 @@ func TestNetworkExists(t *testing.T) { if err != nil { t.Fatalf("failed to save test network in databse: %s", err) } - exists, err = NetworkExists(testNetwork.NetID) + exists, err = logic.NetworkExists(testNetwork.NetID) if err != nil { t.Fatalf("expected nil, received err: %s", err) } diff --git a/logic/auth.go b/logic/auth.go index f282d534..fb99ba7e 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -99,7 +99,7 @@ func CreateUser(user models.User) (models.User, error) { tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin) if tokenString == "" { - // returnErrorResponse(w, r, errorResponse) + // logic.ReturnErrorResponse(w, r, errorResponse) return user, err } diff --git a/controllers/response.go b/logic/errors.go similarity index 77% rename from controllers/response.go rename to logic/errors.go index 5783c002..8259d586 100644 --- a/controllers/response.go +++ b/logic/errors.go @@ -1,4 +1,4 @@ -package controller +package logic import ( "encoding/json" @@ -8,7 +8,8 @@ import ( "github.com/gravitl/netmaker/models" ) -func formatError(err error, errType string) models.ErrorResponse { +// FormatError - takes ErrorResponse and uses correct code +func FormatError(err error, errType string) models.ErrorResponse { var status = http.StatusInternalServerError switch errType { @@ -33,7 +34,8 @@ func formatError(err error, errType string) models.ErrorResponse { return response } -func returnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) { +// ReturnSuccessResponse - processes message and adds header +func ReturnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) { var httpResponse models.SuccessResponse httpResponse.Code = http.StatusOK httpResponse.Message = message @@ -42,7 +44,8 @@ func returnSuccessResponse(response http.ResponseWriter, request *http.Request, json.NewEncoder(response).Encode(httpResponse) } -func returnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) { +// ReturnErrorResponse - processes error and adds header +func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) { httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message} jsonResponse, err := json.Marshal(httpResponse) if err != nil { diff --git a/logic/pro/metrics/metrics.go b/logic/metrics/metrics.go similarity index 100% rename from logic/pro/metrics/metrics.go rename to logic/metrics/metrics.go diff --git a/logic/networks.go b/logic/networks.go index f82f22a0..2103a8fa 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -96,7 +96,7 @@ func CreateNetwork(network models.Network) (models.Network, error) { err := ValidateNetwork(&network, false) if err != nil { - //returnErrorResponse(w, r, formatError(err, "badrequest")) + //logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return models.Network{}, err } @@ -656,6 +656,17 @@ func SaveNetwork(network *models.Network) error { return nil } +// NetworkExists - check if network exists +func NetworkExists(name string) (bool, error) { + + var network string + var err error + if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil { + return false, err + } + return len(network) > 0, nil +} + // == Private == func networkNodesUpdateAction(networkName string, action string) error { diff --git a/logic/nodes.go b/logic/nodes.go index 7b0f6fe9..1b3584c4 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -311,7 +311,7 @@ func CreateNode(node *models.Node) error { //Create a JWT for the node tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network) if tokenString == "" { - //returnErrorResponse(w, r, errorResponse) + //logic.ReturnErrorResponse(w, r, errorResponse) return err } err = ValidateNode(node, false) diff --git a/logic/pro/license.go b/logic/pro/license.go deleted file mode 100644 index 2ca96d50..00000000 --- a/logic/pro/license.go +++ /dev/null @@ -1,66 +0,0 @@ -package pro - -import ( - "crypto/rand" - "encoding/json" - - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/netclient/ncutils" - "golang.org/x/crypto/nacl/box" -) - -const ( - db_license_key = "netmaker-id-key-pair" -) - -type apiServerConf struct { - PrivateKey []byte `json:"private_key" binding:"required"` - PublicKey []byte `json:"public_key" binding:"required"` -} - -// FetchApiServerKeys - fetches netmaker license keys for identification -// as well as secure communication with API -// if none present, it generates a new pair -func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { - var returnData = apiServerConf{} - currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key) - if err != nil && !database.IsEmptyRecord(err) { - return nil, nil, err - } else if database.IsEmptyRecord(err) { // need to generate a new identifier pair - pub, priv, err = box.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - pubBytes, err := ncutils.ConvertKeyToBytes(pub) - if err != nil { - return nil, nil, err - } - privBytes, err := ncutils.ConvertKeyToBytes(priv) - if err != nil { - return nil, nil, err - } - returnData.PrivateKey = privBytes - returnData.PublicKey = pubBytes - record, err := json.Marshal(&returnData) - if err != nil { - return nil, nil, err - } - if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil { - return nil, nil, err - } - } else { - if err = json.Unmarshal([]byte(currentData), &returnData); err != nil { - return nil, nil, err - } - priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey) - if err != nil { - return nil, nil, err - } - pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey) - if err != nil { - return nil, nil, err - } - } - - return pub, priv, nil -} diff --git a/logic/pro/networks_test.go b/logic/pro/networks_test.go index 68915a3b..2f674e03 100644 --- a/logic/pro/networks_test.go +++ b/logic/pro/networks_test.go @@ -58,7 +58,7 @@ func TestNetworkProSettings(t *testing.T) { } AddProNetDefaults(&network) assert.NotNil(t, network.ProSettings) - assert.Nil(t, network.ProSettings.AllowedGroups) - assert.Nil(t, network.ProSettings.AllowedUsers) + assert.Equal(t, len(network.ProSettings.AllowedGroups), 1) + assert.Equal(t, len(network.ProSettings.AllowedUsers), 0) }) } diff --git a/controllers/security.go b/logic/security.go similarity index 71% rename from controllers/security.go rename to logic/security.go index f793da5d..2f013804 100644 --- a/controllers/security.go +++ b/logic/security.go @@ -1,4 +1,4 @@ -package controller +package logic import ( "encoding/json" @@ -7,8 +7,6 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/functions" - "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models/promodels" @@ -16,16 +14,20 @@ import ( ) const ( + // ALL_NETWORK_ACCESS - represents all networks + ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL" + master_uname = "masteradministrator" - unauthorized_msg = "unauthorized" - unauthorized_err = models.Error(unauthorized_msg) + Unauthorized_Msg = "unauthorized" + Unauthorized_Err = models.Error(Unauthorized_Msg) ) -func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { +// SecurityCheck - Check if user has appropriate permissions +func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var errorResponse = models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: unauthorized_msg, + Code: http.StatusUnauthorized, Message: Unauthorized_Msg, } var params = mux.Vars(r) @@ -44,14 +46,14 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { if len(networkName) == 0 { networkName = params["network"] } - networks, username, err := SecurityCheck(reqAdmin, networkName, bearerToken) + networks, username, err := UserPermissions(reqAdmin, networkName, bearerToken) if err != nil { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } networksJson, err := json.Marshal(&networks) if err != nil { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } r.Header.Set("user", username) @@ -60,7 +62,8 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { } } -func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc { +// NetUserSecurityCheck - Check if network user has appropriate permissions +func NetUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var errorResponse = models.ErrorResponse{ Code: http.StatusUnauthorized, Message: "unauthorized", @@ -77,7 +80,7 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl var authToken = "" if len(tokenSplit) < 2 { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } else { authToken = tokenSplit[1] @@ -91,9 +94,9 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl return } - userName, _, isadmin, err := logic.VerifyUserToken(authToken) + userName, _, isadmin, err := VerifyUserToken(authToken) if err != nil { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } r.Header.Set("user", userName) @@ -113,15 +116,15 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl } u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName)) if err != nil { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } if u.AccessLevel > necessaryAccess { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } } else if netUserName != userName { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } @@ -129,14 +132,14 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl } } -// SecurityCheck - checks token stuff -func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) { +// UserPermissions - checks token stuff +func UserPermissions(reqAdmin bool, netname string, token string) ([]string, string, error) { var tokenSplit = strings.Split(token, " ") var authToken = "" userNetworks := []string{} if len(tokenSplit) < 2 { - return userNetworks, "", unauthorized_err + return userNetworks, "", Unauthorized_Err } else { authToken = tokenSplit[1] } @@ -144,12 +147,12 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin if authenticateMaster(authToken) { return []string{ALL_NETWORK_ACCESS}, master_uname, nil } - username, networks, isadmin, err := logic.VerifyUserToken(authToken) + username, networks, isadmin, err := VerifyUserToken(authToken) if err != nil { - return nil, username, unauthorized_err + return nil, username, Unauthorized_Err } if !isadmin && reqAdmin { - return nil, username, unauthorized_err + return nil, username, Unauthorized_Err } userNetworks = networks if isadmin { @@ -157,10 +160,10 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin } // check network admin access if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) { - return nil, username, unauthorized_err + return nil, username, Unauthorized_Err } if !pro.IsUserNetAdmin(netname, username) { - return nil, "", unauthorized_err + return nil, "", Unauthorized_Err } return userNetworks, username, nil } @@ -171,11 +174,11 @@ func authenticateMaster(tokenString string) bool { } func authenticateNetworkUser(network string, userNetworks []string) bool { - networkexists, err := functions.NetworkExists(network) + networkexists, err := NetworkExists(network) if (err != nil && !database.IsEmptyRecord(err)) || !networkexists { return false } - return logic.StringSliceContains(userNetworks, network) + return StringSliceContains(userNetworks, network) } //Consider a more secure way of setting master key @@ -187,15 +190,15 @@ func authenticateDNSToken(tokenString string) bool { return tokens[1] == servercfg.GetDNSKey() } -func continueIfUserMatch(next http.Handler) http.HandlerFunc { +func ContinueIfUserMatch(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var errorResponse = models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: unauthorized_msg, + Code: http.StatusUnauthorized, Message: Unauthorized_Msg, } var params = mux.Vars(r) var requestedUser = params["username"] if requestedUser != r.Header.Get("user") { - returnErrorResponse(w, r, errorResponse) + ReturnErrorResponse(w, r, errorResponse) return } next.ServeHTTP(w, r) diff --git a/logic/server.go b/logic/server.go index 56897351..6e892bfb 100644 --- a/logic/server.go +++ b/logic/server.go @@ -18,6 +18,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +var EnterpriseCheckFuncs []interface{} + // == Join, Checkin, and Leave for Server == // KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range @@ -164,6 +166,13 @@ func ServerJoin(networkSettings *models.Network) (models.Node, error) { return *node, nil } +// EnterpriseCheck - Runs enterprise functions if presented +func EnterpriseCheck() { + for _, check := range EnterpriseCheckFuncs { + check.(func())() + } +} + // ServerUpdate - updates the server // replaces legacy Checkin code func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error { diff --git a/logic/serverconf.go b/logic/serverconf.go index 68469663..fbd5faf5 100644 --- a/logic/serverconf.go +++ b/logic/serverconf.go @@ -6,6 +6,21 @@ import ( "github.com/gravitl/netmaker/database" ) +var ( + // Node_Limit - dummy var for community + Node_Limit = 1000000000 + // Networks_Limit - dummy var for community + Networks_Limit = 1000000000 + // Users_Limit - dummy var for community + Users_Limit = 1000000000 + // Clients_Limit - dummy var for community + Clients_Limit = 1000000000 + // Free_Tier - specifies if free tier + Free_Tier = false + // Is_EE - specifies if enterprise + Is_EE = false +) + // constant for database key for storing server ids const server_id_key = "nm-server-id" diff --git a/main.go b/main.go index 5c50e0b0..a82a07f8 100644 --- a/main.go +++ b/main.go @@ -20,7 +20,6 @@ import ( "github.com/gravitl/netmaker/config" controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -76,7 +75,7 @@ func initialize() { // Client Mode Prereq Check logger.FatalLog("Error connecting to database") } logger.Log(0, "database successfully connected") - if err = ee.AddServerIDIfNotPresent(); err != nil { + if err = logic.AddServerIDIfNotPresent(); err != nil { logger.Log(1, "failed to save server ID") } @@ -91,18 +90,7 @@ func initialize() { // Client Mode Prereq Check logger.Log(1, "Timer error occurred: ", err.Error()) } - if ee.IsEnterprise() { - // == License Handling == - ee.ValidateLicense() - if ee.Limits.FreeTier { - logger.Log(0, "proceeding with Free Tier license") - } else { - logger.Log(0, "proceeding with Paid Tier license") - } - // == End License Handling == - - ee.AddLicenseHooks() - } + logic.EnterpriseCheck() var authProvider = auth.InitializeAuthProvider() if authProvider != "" { diff --git a/main_ee.go b/main_ee.go index ba40a39a..dd4bb3a4 100644 --- a/main_ee.go +++ b/main_ee.go @@ -5,26 +5,8 @@ package main import ( "github.com/gravitl/netmaker/ee" - "github.com/gravitl/netmaker/models" ) func init() { - ee.SetIsEnterprise() - models.SetLogo(retrieveEELogo()) -} - -func retrieveEELogo() string { - return ` - __ __ ______ ______ __ __ ______ __ __ ______ ______ -/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \ -\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __< - \ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\ - \/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/ - - ___ ___ ____ - ____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____ - /___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/ - /___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/ - -` + ee.InitEE() } diff --git a/mq/handlers.go b/mq/handlers.go index 99695eb7..822955a3 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -7,7 +7,6 @@ import ( mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -99,7 +98,7 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) { // UpdateMetrics message Handler -- handles updates from client nodes for metrics func UpdateMetrics(client mqtt.Client, msg mqtt.Message) { - if ee.IsEnterprise() { + if logic.Is_EE { go func() { id, err := getID(msg.Topic()) if err != nil { diff --git a/mq/publishers.go b/mq/publishers.go index fd48529f..9e6cc521 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -6,10 +6,9 @@ import ( "fmt" "time" - "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/pro/metrics" + "github.com/gravitl/netmaker/logic/metrics" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/serverctl" @@ -185,7 +184,7 @@ func ServerStartNotify() error { // function to collect and store metrics for server nodes func collectServerMetrics(networks []models.Network) { - if !ee.IsEnterprise() { + if !logic.Is_EE { return } if len(networks) > 0 { diff --git a/netclient/functions/mqpublish.go b/netclient/functions/mqpublish.go index 5f1fdae6..766c19ed 100644 --- a/netclient/functions/mqpublish.go +++ b/netclient/functions/mqpublish.go @@ -15,7 +15,7 @@ import ( "github.com/cloverstd/tcping/ping" "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic/pro/metrics" + "github.com/gravitl/netmaker/logic/metrics" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config"