Merge branch 'develop' into bugfix_v0.16.0_swagger_sections

This commit is contained in:
Alex Feiszli
2022-09-19 12:55:32 -04:00
committed by GitHub
90 changed files with 4623 additions and 521 deletions

View File

@@ -1,11 +1,14 @@
#first stage - builder #first stage - builder
FROM gravitl/go-builder as builder FROM gravitl/go-builder as builder
ARG version ARG version
ARG tags
WORKDIR /app WORKDIR /app
COPY . . COPY . .
ENV GO111MODULE=auto ENV GO111MODULE=auto
RUN GOOS=linux CGO_ENABLED=1 go build -ldflags="-s -X 'main.version=${version}'" -o netmaker main.go RUN apk add git
RUN GOOS=linux CGO_ENABLED=1 go build ${tags} -ldflags="-s -X 'main.version=${version}'" .
# RUN go build -tags=ee . -o netmaker main.go
FROM alpine:3.15.2 FROM alpine:3.15.2
# add a c lib # add a c lib

14
Dockerfile-quick Normal file
View File

@@ -0,0 +1,14 @@
#first stage - builder
FROM alpine:3.15.2
ARG version
WORKDIR /app
COPY ./netmaker /root/netmaker
ENV GO111MODULE=auto
# add a c lib
RUN apk add gcompat iptables wireguard-tools
# set the working directory
WORKDIR /root/
RUN mkdir -p /etc/netclient/config
EXPOSE 8081
ENTRYPOINT ["./netmaker"]

View File

@@ -9,6 +9,7 @@ import (
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -27,8 +28,19 @@ const (
oidc_provider_name = "oidc" oidc_provider_name = "oidc"
verify_user = "verifyuser" verify_user = "verifyuser"
auth_key = "netmaker_auth" auth_key = "netmaker_auth"
user_signin_length = 16
node_signin_length = 64
) )
// OAuthUser - generic OAuth strategy user
type OAuthUser struct {
Name string `json:"name" bson:"name"`
Email string `json:"email" bson:"email"`
Login string `json:"login" bson:"login"`
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
var auth_provider *oauth2.Config var auth_provider *oauth2.Config
func getCurrentAuthFunctions() map[string]interface{} { func getCurrentAuthFunctions() map[string]interface{} {
@@ -94,7 +106,14 @@ func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
if functions == nil { if functions == nil {
return return
} }
state, _ := getStateAndCode(r)
_, err := netcache.Get(state) // if in netcache proceeed with node registration login
if err == nil || len(state) == node_signin_length || (err != nil && strings.Contains(err.Error(), "expired")) {
logger.Log(0, "proceeding with node SSO callback")
HandleNodeSSOCallback(w, r)
} else { // handle normal login
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r) functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
}
} }
// swagger:route GET /api/oauth/login nodes HandleAuthLogin // swagger:route GET /api/oauth/login nodes HandleAuthLogin
@@ -197,3 +216,35 @@ func fetchPassValue(newValue string) (string, error) {
} }
return string(b64CurrentValue), nil return string(b64CurrentValue), nil
} }
func getStateAndCode(r *http.Request) (string, string) {
var state, code string
if r.FormValue("state") != "" && r.FormValue("code") != "" {
state = r.FormValue("state")
code = r.FormValue("code")
} else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" {
state = r.URL.Query().Get("state")
code = r.URL.Query().Get("code")
}
return state, code
}
func (user *OAuthUser) getUserName() string {
var userName string
if user.Email != "" {
userName = user.Email
} else if user.Login != "" {
userName = user.Login
} else if user.UserPrincipalName != "" {
userName = user.UserPrincipalName
} else if user.Name != "" {
userName = user.Name
}
return userName
}
func isStateCached(state string) bool {
_, err := netcache.Get(state)
return err == nil || strings.Contains(err.Error(), "expired")
}

View File

@@ -23,11 +23,6 @@ var azure_ad_functions = map[string]interface{}{
verify_user: verifyAzureUser, verify_user: verifyAzureUser,
} }
type azureOauthUser struct {
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle azure ad authentication here == // == handle azure ad authentication here ==
func initAzureAD(redirectURL string, clientID string, clientSecret string) { func initAzureAD(redirectURL string, clientID string, clientSecret string) {
@@ -41,7 +36,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) {
} }
func handleAzureLogin(w http.ResponseWriter, r *http.Request) { func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
var oauth_state_string = logic.RandomString(16) var oauth_state_string = logic.RandomString(user_signin_length)
if auth_provider == nil && servercfg.GetFrontendURL() != "" { if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return return
@@ -61,7 +56,8 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
func handleAzureCallback(w http.ResponseWriter, r *http.Request) { func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getAzureUserInfo(r.FormValue("state"), r.FormValue("code")) var rState, rCode = getStateAndCode(r)
var content, err = getAzureUserInfo(rState, rCode)
if err != nil { if err != nil {
logger.Log(1, "error when getting user info from azure:", err.Error()) logger.Log(1, "error when getting user info from azure:", err.Error())
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
@@ -93,9 +89,9 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect)
} }
func getAzureUserInfo(state string, code string) (*azureOauthUser, error) { func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
oauth_state_string, isValid := logic.IsStateValid(state) oauth_state_string, isValid := logic.IsStateValid(state)
if !isValid || state != oauth_state_string { if (!isValid || state != oauth_state_string) && !isStateCached(state) {
return nil, fmt.Errorf("invalid oauth state") return nil, fmt.Errorf("invalid oauth state")
} }
var token, err = auth_provider.Exchange(context.Background(), code) var token, err = auth_provider.Exchange(context.Background(), code)
@@ -121,7 +117,7 @@ func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error()) return nil, fmt.Errorf("failed reading response body: %s", err.Error())
} }
var userInfo = &azureOauthUser{} var userInfo = &OAuthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil { if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
} }

View File

@@ -23,11 +23,6 @@ var github_functions = map[string]interface{}{
verify_user: verifyGithubUser, verify_user: verifyGithubUser,
} }
type githubOauthUser struct {
Login string `json:"login" bson:"login"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle github authentication here == // == handle github authentication here ==
func initGithub(redirectURL string, clientID string, clientSecret string) { func initGithub(redirectURL string, clientID string, clientSecret string) {
@@ -41,7 +36,7 @@ func initGithub(redirectURL string, clientID string, clientSecret string) {
} }
func handleGithubLogin(w http.ResponseWriter, r *http.Request) { func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
var oauth_state_string = logic.RandomString(16) var oauth_state_string = logic.RandomString(user_signin_length)
if auth_provider == nil && servercfg.GetFrontendURL() != "" { if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return return
@@ -61,7 +56,8 @@ func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
func handleGithubCallback(w http.ResponseWriter, r *http.Request) { func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getGithubUserInfo(r.URL.Query().Get("state"), r.URL.Query().Get("code")) var rState, rCode = getStateAndCode(r)
var content, err = getGithubUserInfo(rState, rCode)
if err != nil { if err != nil {
logger.Log(1, "error when getting user info from github:", err.Error()) logger.Log(1, "error when getting user info from github:", err.Error())
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
@@ -93,10 +89,10 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect)
} }
func getGithubUserInfo(state string, code string) (*githubOauthUser, error) { func getGithubUserInfo(state string, code string) (*OAuthUser, error) {
oauth_state_string, isValid := logic.IsStateValid(state) oauth_state_string, isValid := logic.IsStateValid(state)
if !isValid || state != oauth_state_string { if (!isValid || state != oauth_state_string) && !isStateCached(state) {
return nil, fmt.Errorf("invalid OAuth state") return nil, fmt.Errorf("invalid oauth state")
} }
var token, err = auth_provider.Exchange(context.Background(), code) var token, err = auth_provider.Exchange(context.Background(), code)
if err != nil { if err != nil {
@@ -125,7 +121,7 @@ func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error()) return nil, fmt.Errorf("failed reading response body: %s", err.Error())
} }
var userInfo = &githubOauthUser{} var userInfo = &OAuthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil { if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
} }

View File

@@ -24,11 +24,6 @@ var google_functions = map[string]interface{}{
verify_user: verifyGoogleUser, verify_user: verifyGoogleUser,
} }
type googleOauthUser struct {
Email string `json:"email" bson:"email"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle google authentication here == // == handle google authentication here ==
func initGoogle(redirectURL string, clientID string, clientSecret string) { func initGoogle(redirectURL string, clientID string, clientSecret string) {
@@ -42,7 +37,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) {
} }
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
var oauth_state_string = logic.RandomString(16) var oauth_state_string = logic.RandomString(user_signin_length)
if auth_provider == nil && servercfg.GetFrontendURL() != "" { if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return return
@@ -62,7 +57,9 @@ func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getGoogleUserInfo(r.FormValue("state"), r.FormValue("code")) var rState, rCode = getStateAndCode(r)
var content, err = getGoogleUserInfo(rState, rCode)
if err != nil { if err != nil {
logger.Log(1, "error when getting user info from google:", err.Error()) logger.Log(1, "error when getting user info from google:", err.Error())
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
@@ -91,13 +88,13 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(1, "completed google OAuth sigin in for", content.Email) logger.Log(1, "completed google OAuth sigin in for", content.Email)
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) http.Redirect(w, r, fmt.Sprintf("%s/login?login=%s&user=%s", servercfg.GetFrontendURL(), jwt, content.Email), http.StatusPermanentRedirect)
} }
func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) { func getGoogleUserInfo(state string, code string) (*OAuthUser, error) {
oauth_state_string, isValid := logic.IsStateValid(state) oauth_state_string, isValid := logic.IsStateValid(state)
if !isValid || state != oauth_state_string { if (!isValid || state != oauth_state_string) && !isStateCached(state) {
return nil, fmt.Errorf("invalid OAuth state") return nil, fmt.Errorf("invalid oauth state")
} }
var token, err = auth_provider.Exchange(context.Background(), code) var token, err = auth_provider.Exchange(context.Background(), code)
if err != nil { if err != nil {
@@ -120,7 +117,7 @@ func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error()) return nil, fmt.Errorf("failed reading response body: %s", err.Error())
} }
var userInfo = &googleOauthUser{} var userInfo = &OAuthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil { if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
} }

259
auth/nodecallback.go Normal file
View File

@@ -0,0 +1,259 @@
package auth
import (
"bytes"
"fmt"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/servercfg"
)
var (
redirectUrl string
)
// HandleNodeSSOCallback handles the callback from the sso endpoint
// It is the analogue of auth.handleNodeSSOCallback but takes care of the end point flow
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
func HandleNodeSSOCallback(w http.ResponseWriter, r *http.Request) {
var functions = getCurrentAuthFunctions()
if functions == nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("bad conf"))
logger.Log(0, "Missing Oauth config in HandleNodeSSOCallback")
return
}
state, code := getStateAndCode(r)
var userClaims, err = functions[get_user_info].(func(string, string) (*OAuthUser, error))(state, code)
if err != nil {
logger.Log(0, "error when getting user info from callback:", err.Error())
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return
}
if code == "" || state == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Wrong params"))
logger.Log(0, "Missing params in HandleSSOCallback")
return
}
// all responses should be in html format from here on out
w.Header().Add("content-type", "text/html; charset=utf-8")
// retrieve machinekey from state cache
reqKeyIf, machineKeyFoundErr := netcache.Get(state)
if machineKeyFoundErr != nil {
logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error())
reqKeyIf = &netcache.CValue{
Network: "invalid",
Value: state,
Pass: "",
User: "netmaker",
Expiration: time.Now(),
}
response := returnErrTemplate("", "requested machine state key expired before authorisation completed", state, reqKeyIf)
w.WriteHeader(http.StatusInternalServerError)
w.Write(response)
return
}
user, err := isUserIsAllowed(userClaims.getUserName(), reqKeyIf.Network, true)
if err != nil {
logger.Log(0, "error occurred during SSO node join for user", userClaims.getUserName(), "on network", reqKeyIf.Network, "-", err.Error())
response := returnErrTemplate(user.UserName, err.Error(), state, reqKeyIf)
w.WriteHeader(http.StatusNotAcceptable)
w.Write(response)
return
}
logger.Log(1, "registering new node for user:", user.UserName, "on network", reqKeyIf.Network)
// Send OK to user in the browser
var response bytes.Buffer
if err := ssoCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{
User: userClaims.getUserName(),
Verb: "Authenticated",
}); err != nil {
logger.Log(0, "Could not render SSO callback template ", err.Error())
response := returnErrTemplate(user.UserName, "Could not render SSO callback template", state, reqKeyIf)
w.WriteHeader(http.StatusInternalServerError)
w.Write(response)
} else {
w.WriteHeader(http.StatusOK)
w.Write(response.Bytes())
}
// Need to send access key to the client
logger.Log(1, "Handling new machine addition to network",
reqKeyIf.Network, "with key",
reqKeyIf.Value, " identity:", userClaims.getUserName(), "claims:", fmt.Sprintf("%+v", userClaims))
var answer string
// The registation logic is starting here:
// we request access key with 1 use for the required network
accessToken, err := requestAccessKey(reqKeyIf.Network, 1, userClaims.getUserName())
if err != nil {
answer = fmt.Sprintf("Error from the netmaker controller %s", err.Error())
} else {
answer = fmt.Sprintf("AccessToken: %s", accessToken)
}
logger.Log(0, "Updating the token for the client request ... ")
// Give the user the access token via Pass in the DB
reqKeyIf.Pass = answer
if err = netcache.Set(state, reqKeyIf); err != nil {
logger.Log(0, "machine failed to complete join on network,", reqKeyIf.Network, "-", err.Error())
return
}
}
func setNetcache(ncache *netcache.CValue, state string) error {
if ncache == nil {
return fmt.Errorf("cache miss")
}
var err error
if err = netcache.Set(state, ncache); err != nil {
logger.Log(0, "machine failed to complete join on network,", ncache.Network, "-", err.Error())
}
return err
}
func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) []byte {
var response bytes.Buffer
ncache.Pass = message
err := ssoErrCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{
User: uname,
Verb: message,
})
if err != nil {
return []byte(err.Error())
}
err = setNetcache(ncache, state)
if err != nil {
return []byte(err.Error())
}
return response.Bytes()
}
// RegisterNodeSSO redirects to the IDP for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:regKey.
func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) {
logger.Log(1, "RegisterNodeSSO\n")
vars := mux.Vars(r)
// machineKeyStr this is not key but state
machineKeyStr := vars["regKey"]
logger.Log(1, "requested key:", machineKeyStr)
if machineKeyStr == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Wrong params"))
logger.Log(0, "Wrong params ", machineKeyStr)
return
}
// machineKeyStr this not key but state
authURL := auth_provider.AuthCodeURL(machineKeyStr)
//authURL = authURL + "&connector_id=" + "google"
logger.Log(0, "Redirecting to ", authURL, " for authentication")
http.Redirect(w, r, authURL, http.StatusSeeOther)
}
// == private ==
// API to create an access key for a given network with a given name
func requestAccessKey(network string, uses int, name string) (accessKey string, err error) {
var sAccessKey models.AccessKey
var sNetwork models.Network
sNetwork, err = logic.GetParentNetwork(network)
if err != nil {
logger.Log(0, "err calling GetParentNetwork API=%s", err.Error())
return "", fmt.Errorf("internal controller error %s", err.Error())
}
// If a key already exists, we recreate it.
// @TODO Is that a preferred handling ? We could also trying to re-use.
// can happen if user started log in but did not finish
for _, currentkey := range sNetwork.AccessKeys {
if currentkey.Name == name {
logger.Log(0, "erasing existing AccessKey for: ", name)
err = logic.DeleteKey(currentkey.Name, network)
if err != nil {
logger.Log(0, "err calling CreateAccessKey API ", err.Error())
return "", fmt.Errorf("key already exists. Contact admin to resolve")
}
break
}
}
// Only one usage is needed - for the next time new access key will be required
// it will be created next time after another IdP approval
sAccessKey.Uses = 1
sAccessKey.Name = name
accessToken, err := logic.CreateAccessKey(sAccessKey, sNetwork)
if err != nil {
logger.Log(0, "err calling CreateAccessKey API ", err.Error())
return "", fmt.Errorf("error from the netmaker controller %s", err.Error())
} else {
logger.Log(1, "created access key", sAccessKey.Name, "on", network)
}
return accessToken.AccessString, nil
}
func isUserIsAllowed(username, network string, shouldAddUser bool) (*models.User, error) {
user, err := logic.GetUser(username)
if err != nil && shouldAddUser { // user must not exist, so try to make one
if err = addUser(username); err != nil {
logger.Log(0, "failed to add user", username, "during a node SSO network join on network", network)
// response := returnErrTemplate(user.UserName, "failed to add user", state, reqKeyIf)
// w.WriteHeader(http.StatusInternalServerError)
// w.Write(response)
return nil, fmt.Errorf("failed to add user to system")
}
logger.Log(0, "user", username, "was added during a node SSO network join on network", network)
user, _ = logic.GetUser(username)
}
if !user.IsAdmin { // perform check to see if user is allowed to join a node to network
netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(user.UserName))
if err != nil {
logger.Log(0, "failed to get net user details for user", user.UserName, "during node SSO")
return nil, fmt.Errorf("failed to verify network user")
}
if netUser.AccessLevel != pro.NET_ADMIN { // if user is a net admin on network, good to go
// otherwise, check if they have node access + haven't reached node limit on network
if netUser.AccessLevel == pro.NODE_ACCESS {
if len(netUser.Nodes) >= netUser.NodeLimit {
logger.Log(0, "user", user.UserName, "has reached their node limit on network", network)
return nil, fmt.Errorf("user node limit exceeded")
}
} else {
logger.Log(0, "user", user.UserName, "attempted to access network", network, "via node SSO")
return nil, fmt.Errorf("network user not allowed")
}
}
}
return &user, nil
}

154
auth/nodesession.go Normal file
View File

@@ -0,0 +1,154 @@
package auth
import (
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/servercfg"
)
// SessionHandler - called by the HTTP router when user
// is calling netclient with --login-server parameter in order to authenticate
// via SSO mechanism by OAuth2 protocol flow.
// This triggers a session start and it is managed by the flow implemented here and callback
// When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
func SessionHandler(conn *websocket.Conn) {
defer conn.Close()
logger.Log(1, "Running sessionHandler")
// If reached here we have a session from user to handle...
messageType, message, err := conn.ReadMessage()
if err != nil {
logger.Log(0, "Error during message reading:", err.Error())
return
}
var loginMessage promodels.LoginMsg
err = json.Unmarshal(message, &loginMessage)
if err != nil {
logger.Log(0, "Failed to unmarshall data err=", err.Error())
return
}
logger.Log(1, "SSO node join attempted with info network:", loginMessage.Network, "node identifier:", loginMessage.Mac, "user:", loginMessage.User)
req := new(netcache.CValue)
req.Value = string(loginMessage.Mac)
req.Network = loginMessage.Network
req.Pass = ""
req.User = ""
// Add any extra parameter provided in the configuration to the Authorize Endpoint request??
stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
if err := netcache.Set(stateStr, req); err != nil {
logger.Log(0, "Failed to process sso request -", err.Error())
return
}
// Wait for the user to finish his auth flow...
// TBD: what should be the timeout here ?
timeout := make(chan bool, 1)
answer := make(chan string, 1)
defer close(answer)
defer close(timeout)
if loginMessage.User != "" { // handle basic auth
// verify that server supports basic auth, then authorize the request with given credentials
// check if user is allowed to join via node sso
// i.e. user is admin or user has network permissions
if !servercfg.IsBasicAuthEnabled() {
err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled"))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
}
_, err := logic.VerifyAuthRequest(models.UserAuthParams{
UserName: loginMessage.User,
Password: loginMessage.Password,
})
if err != nil {
err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User)))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
return
}
user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
if err != nil {
err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User)))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
return
}
accessToken, err := requestAccessKey(loginMessage.Network, 1, user.UserName)
if err != nil {
req.Pass = fmt.Sprintf("Error from the netmaker controller %s", err.Error())
} else {
req.Pass = fmt.Sprintf("AccessToken: %s", accessToken)
}
// Give the user the access token via Pass in the DB
if err = netcache.Set(stateStr, req); err != nil {
logger.Log(0, "machine failed to complete join on network,", loginMessage.Network, "-", err.Error())
return
}
} else { // handle SSO / OAuth
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
err = conn.WriteMessage(messageType, []byte(redirectUrl))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
}
go func() {
for {
cachedReq, err := netcache.Get(stateStr)
if err != nil {
if strings.Contains(err.Error(), "expired") {
logger.Log(0, "timeout occurred while waiting for SSO on network", loginMessage.Network)
timeout <- true
break
}
continue
} else if cachedReq.Pass != "" {
logger.Log(0, "node SSO process completed for user", cachedReq.User, "on network", loginMessage.Network)
answer <- cachedReq.Pass
break
}
time.Sleep(500) // try it 2 times per second to see if auth is completed
}
}()
select {
case result := <-answer:
// a read from req.answerCh has occurred
err = conn.WriteMessage(messageType, []byte(result))
if err != nil {
logger.Log(0, "Error during message writing:", err.Error())
}
case <-timeout:
logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
// the read from req.answerCh has timed out
err = conn.WriteMessage(messageType, []byte("Authentication server time out"))
if err != nil {
logger.Log(0, "Error during message writing:", err.Error())
}
}
// The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
if err = netcache.Del(stateStr); err != nil {
logger.Log(0, "failed to remove node SSO cache entry", err.Error())
}
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
logger.Log(0, "write close:", err.Error())
return
}
}

View File

@@ -26,11 +26,6 @@ var oidc_functions = map[string]interface{}{
var oidc_verifier *oidc.IDTokenVerifier var oidc_verifier *oidc.IDTokenVerifier
type OIDCUser struct {
Name string `json:"name" bson:"name"`
Email string `json:"email" bson:"email"`
}
// == handle OIDC authentication here == // == handle OIDC authentication here ==
func initOIDC(redirectURL string, clientID string, clientSecret string, issuer string) { func initOIDC(redirectURL string, clientID string, clientSecret string, issuer string) {
@@ -54,7 +49,7 @@ func initOIDC(redirectURL string, clientID string, clientSecret string, issuer s
} }
func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
var oauth_state_string = logic.RandomString(16) var oauth_state_string = logic.RandomString(user_signin_length)
if auth_provider == nil && servercfg.GetFrontendURL() != "" { if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return return
@@ -67,14 +62,15 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
return return
} }
var url = auth_provider.AuthCodeURL(oauth_state_string) var url = auth_provider.AuthCodeURL(oauth_state_string)
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
} }
func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getOIDCUserInfo(r.FormValue("state"), r.FormValue("code")) var rState, rCode = getStateAndCode(r)
var content, err = getOIDCUserInfo(rState, rCode)
if err != nil { if err != nil {
logger.Log(1, "error when getting user info from callback:", err.Error()) logger.Log(1, "error when getting user info from callback:", err.Error())
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
@@ -98,7 +94,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest) var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
if jwtErr != nil { if jwtErr != nil {
logger.Log(1, "could not parse jwt for user", authRequest.UserName) logger.Log(1, "could not parse jwt for user", authRequest.UserName, jwtErr.Error())
return return
} }
@@ -106,10 +102,12 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
} }
func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) { func getOIDCUserInfo(state string, code string) (u *OAuthUser, e error) {
oauth_state_string, isValid := logic.IsStateValid(state) oauth_state_string, isValid := logic.IsStateValid(state)
if !isValid || state != oauth_state_string { logger.Log(3, "using oauth state string:,", oauth_state_string)
return nil, fmt.Errorf("invalid OAuth state") logger.Log(3, " state string:,", state)
if (!isValid || state != oauth_state_string) && !isStateCached(state) {
return nil, fmt.Errorf("invalid oauth state")
} }
defer func() { defer func() {
@@ -136,7 +134,7 @@ func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) {
return nil, fmt.Errorf("failed to verify raw id_token: \"%s\"", err.Error()) return nil, fmt.Errorf("failed to verify raw id_token: \"%s\"", err.Error())
} }
u = &OIDCUser{} u = &OAuthUser{}
if err := idToken.Claims(u); err != nil { if err := idToken.Claims(u); err != nil {
e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error()) e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error())
} }

81
auth/templates.go Normal file
View File

@@ -0,0 +1,81 @@
package auth
import "html/template"
type ssoCallbackTemplateConfig struct {
User string
Verb string
}
var ssoCallbackTemplate = template.Must(
template.New("ssocallback").Parse(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css"
integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<title>Netmaker</title>
</head>
<style>
.text-responsive {
font-size: calc(100% + 1vw + 1vh);
}
</style>
<body>
<div class="container">
<div class="row justify-content-center mt-5 p-5 align-items-center text-center">
<a href="https://netmaker.io">
<img src="https://raw.githubusercontent.com/gravitl/netmaker/master/img/netmaker-teal.png" alt="Netmaker"
width="75%" height="25%" class="img-fluid">
</a>
</div>
<div class="row justify-content-center mt-5 p-3 text-center">
<div class="col">
<h2 class="text-responsive">{{.User}} has been successfully {{.Verb}}</h2>
<br />
<h2 class="text-responsive">You may now close this window.</h2>
</div>
</div>
</div>
</body>
</html>`),
)
var ssoErrCallbackTemplate = template.Must(
template.New("ssocallback").Parse(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css"
integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<title>Netmaker</title>
</head>
<style>
.text-responsive {
font-size: calc(100% + 1vw + 1vh);
color: red;
}
</style>
<body>
<div class="container">
<div class="row justify-content-center mt-5 p-5 align-items-center text-center">
<a href="https://netmaker.io">
<img src="https://raw.githubusercontent.com/gravitl/netmaker/master/img/netmaker-teal.png" alt="Netmaker"
width="75%" height="25%" class="img-fluid">
</a>
</div>
<div class="row justify-content-center mt-5 p-3 text-center">
<div class="col">
<h2 class="text-responsive">{{.User}} unable to join network: {{.Verb}}</h2>
<br />
<h2 class="text-responsive">If you feel this is a mistake, please contact your network administrator.</h2>
</div>
</div>
</div>
</body>
</html>`),
)

View File

@@ -39,6 +39,7 @@ services:
VERBOSITY: "1" VERBOSITY: "1"
MANAGE_IPTABLES: "on" MANAGE_IPTABLES: "on"
PORT_FORWARD_SERVICES: "dns" PORT_FORWARD_SERVICES: "dns"
METRICS_EXPORTER: "on"
ports: ports:
- "51821-51830:51821-51830/udp" - "51821-51830:51821-51830/udp"
expose: expose:
@@ -111,6 +112,7 @@ services:
restart: unless-stopped restart: unless-stopped
volumes: volumes:
- /root/mosquitto.conf:/mosquitto/config/mosquitto.conf - /root/mosquitto.conf:/mosquitto/config/mosquitto.conf
- /root/mosquitto.passwords:/etc/mosquitto.passwords
- mosquitto_data:/mosquitto/data - mosquitto_data:/mosquitto/data
- mosquitto_logs:/mosquitto/log - mosquitto_logs:/mosquitto/log
- shared_certs:/mosquitto/certs - shared_certs:/mosquitto/certs
@@ -123,6 +125,66 @@ services:
- traefik.tcp.services.mqtts-svc.loadbalancer.server.port=8883 - traefik.tcp.services.mqtts-svc.loadbalancer.server.port=8883
- traefik.tcp.routers.mqtts.service=mqtts-svc - traefik.tcp.routers.mqtts.service=mqtts-svc
- traefik.tcp.routers.mqtts.entrypoints=websecure - traefik.tcp.routers.mqtts.entrypoints=websecure
prometheus:
container_name: prometheus
image: gravitl/netmaker-prometheus:latest
environment:
NETMAKER_METRICS_TARGET: "netmaker-exporter.NETMAKER_BASE_DOMAIN"
labels:
- traefik.enable=true
- traefik.http.routers.prometheus.entrypoints=websecure
- traefik.http.routers.prometheus.rule=Host(`prometheus.NETMAKER_BASE_DOMAIN`)
- traefik.http.services.prometheus.loadbalancer.server.port=9090
- traefik.http.routers.prometheus.service=prometheus
restart: always
volumes:
- prometheus_data:/prometheus
depends_on:
- netmaker
ports:
- 9090:9090
grafana:
container_name: grafana
image: gravitl/netmaker-grafana:latest
labels:
- traefik.enable=true
- traefik.http.routers.grafana.entrypoints=websecure
- traefik.http.routers.grafana.rule=Host(`grafana.NETMAKER_BASE_DOMAIN`)
- traefik.http.services.grafana.loadbalancer.server.port=3000
- traefik.http.routers.grafana.service=grafana
environment:
PROMETHEUS_HOST: "prometheus.NETMAKER_BASE_DOMAIN"
NETMAKER_METRICS_TARGET: "netmaker-exporter.NETMAKER_BASE_DOMAIN"
ports:
- 3000:3000
restart: always
links:
- prometheus
depends_on:
- prometheus
- netmaker
netmaker-exporter:
container_name: netmaker-exporter
image: gravitl/netmaker-exporter:latest
labels:
- traefik.enable=true
- traefik.http.routers.netmaker-exporter.entrypoints=websecure
- traefik.http.routers.netmaker-exporter.rule=Host(`netmaker-exporter.NETMAKER_BASE_DOMAIN`)
- traefik.http.services.netmaker-exporter.loadbalancer.server.port=8085
- traefik.http.routers.netmaker-exporter.service=netmaker-exporter
restart: always
depends_on:
- netmaker
environment:
MQ_HOST: "mq"
MQ_PORT: "443"
MQ_SERVER_PORT: "1884"
PROMETHEUS: "on"
VERBOSITY: "1"
API_PORT: "8085"
PROMETHEUS_HOST: https://prometheus.NETMAKER_BASE_DOMAIN
expose:
- "8085"
volumes: volumes:
traefik_certs: {} traefik_certs: {}
shared_certs: {} shared_certs: {}
@@ -130,3 +192,4 @@ volumes:
dnsconfig: {} dnsconfig: {}
mosquitto_data: {} mosquitto_data: {}
mosquitto_logs: {} mosquitto_logs: {}
prometheus_data: {}

View File

@@ -70,6 +70,11 @@ type ServerConfig struct {
MQServerPort string `yaml:"mqserverport"` MQServerPort string `yaml:"mqserverport"`
Server string `yaml:"server"` Server string `yaml:"server"`
PublicIPService string `yaml:"publicipservice"` PublicIPService string `yaml:"publicipservice"`
MetricsExporter string `yaml:"metrics_exporter"`
BasicAuth string `yaml:"basic_auth"`
LicenseValue string `yaml:"license_value"`
NetmakerAccountID string `yaml:"netmaker_account_id"`
IsEE string `yaml:"is_ee"`
} }
// SQLConfig - Generic SQL Config // SQLConfig - Generic SQL Config

View File

@@ -25,6 +25,9 @@ var HttpHandlers = []interface{}{
serverHandlers, serverHandlers,
extClientHandlers, extClientHandlers,
ipHandlers, ipHandlers,
loggerHandlers,
userGroupsHandlers,
networkUsersHandlers,
} }
// HandleRESTRequests - handles the rest requests // HandleRESTRequests - handles the rest requests

View File

@@ -16,13 +16,13 @@ import (
func dnsHandlers(r *mux.Router) { func dnsHandlers(r *mux.Router) {
r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET") r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET") r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST") r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST") r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE") r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
} }
// swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS // swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS
@@ -44,7 +44,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err)) fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -68,7 +68,7 @@ func getAllDNS(w http.ResponseWriter, r *http.Request) {
dns, err := logic.GetAllDNS() dns, err := logic.GetAllDNS()
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error()) logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -98,7 +98,7 @@ func getCustomDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error())) fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -128,7 +128,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error())) fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -160,7 +160,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("invalid DNS entry %+v: %v", entry, err)) fmt.Sprintf("invalid DNS entry %+v: %v", entry, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -168,14 +168,14 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err)) fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = logic.SetDNS() err = logic.SetDNS()
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, "new DNS record added:", entry.Name) logger.Log(1, "new DNS record added:", entry.Name)
@@ -221,7 +221,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, "failed to delete dns entry: ", entrytext) logger.Log(0, "failed to delete dns entry: ", entrytext)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, "deleted dns entry: ", entrytext) logger.Log(1, "deleted dns entry: ", entrytext)
@@ -229,7 +229,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
json.NewEncoder(w).Encode(entrytext + " deleted.") json.NewEncoder(w).Encode(entrytext + " deleted.")
@@ -287,7 +287,7 @@ func pushDNS(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("Failed to set DNS entries on file: %v", err)) fmt.Sprintf("Failed to set DNS entries on file: %v", err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver") logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")

View File

@@ -12,20 +12,22 @@ import (
"github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/mq"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
) )
func extClientHandlers(r *mux.Router) { func extClientHandlers(r *mux.Router) {
r.HandleFunc("/api/extclients", securityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET") r.HandleFunc("/api/extclients", logic.SecurityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET")
r.HandleFunc("/api/extclients/{network}", securityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET") r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", securityCheck(false, http.HandlerFunc(getExtClientConf))).Methods("GET") r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET")
r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(updateExtClient))).Methods("PUT") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT")
r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(deleteExtClient))).Methods("DELETE") r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE")
r.HandleFunc("/api/extclients/{network}/{nodeid}", securityCheck(false, http.HandlerFunc(createExtClient))).Methods("POST") r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.NetUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST")
} }
func checkIngressExists(nodeID string) bool { func checkIngressExists(nodeID string) bool {
@@ -60,7 +62,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err)) fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -94,16 +96,16 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
if marshalErr != nil { if marshalErr != nil {
logger.Log(0, "error unmarshalling networks: ", logger.Log(0, "error unmarshalling networks: ",
marshalErr.Error()) marshalErr.Error())
returnErrorResponse(w, r, formatError(marshalErr, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "internal"))
return return
} }
clients := []models.ExtClient{} clients := []models.ExtClient{}
var err error var err error
if networksSlice[0] == ALL_NETWORK_ACCESS { if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
clients, err = functions.GetAllExtClients() clients, err = functions.GetAllExtClients()
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, "failed to get all extclients: ", err.Error()) logger.Log(0, "failed to get all extclients: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -144,7 +146,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -175,7 +177,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
clientid, networkid, err)) clientid, networkid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -183,14 +185,14 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
network, err := logic.GetParentNetwork(client.Network) network, err := logic.GetParentNetwork(client.Network)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network) logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network)
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -256,7 +258,7 @@ Endpoint = %s
bytes, err := qrcode.Encode(config, qrcode.Medium, 220) bytes, err := qrcode.Encode(config, qrcode.Medium, 220)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error()) logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
w.Header().Set("Content-Type", "image/png") w.Header().Set("Content-Type", "image/png")
@@ -264,7 +266,7 @@ Endpoint = %s
_, err = w.Write(bytes) _, err = w.Write(bytes)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error()) logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
return return
@@ -278,7 +280,7 @@ Endpoint = %s
_, err := fmt.Fprint(w, config) _, err := fmt.Fprint(w, config)
if err != nil { if err != nil {
logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error()) logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
} }
return return
} }
@@ -308,7 +310,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
err := errors.New("ingress does not exist") err := errors.New("ingress does not exist")
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err)) fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -327,7 +329,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10) extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10)
@@ -337,13 +339,36 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if err == nil { // check if parent network default ACL is enabled (yes) or not (no) if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
extclient.Enabled = parentNetwork.DefaultACL == "yes" extclient.Enabled = parentNetwork.DefaultACL == "yes"
} }
// check pro settings
err = logic.CreateExtClient(&extclient) err = logic.CreateExtClient(&extclient)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err)) fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var isAdmin bool
if r.Header.Get("ismaster") != "yes" {
userID := r.Header.Get("user")
if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil {
logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access")
logic.DeleteExtClient(networkName, extclient.ClientID)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if !isAdmin {
if err = pro.AssociateNetworkUserClient(userID, networkName, extclient.ClientID); err != nil {
logger.Log(0, "failed to associate client", extclient.ClientID, "to user", userID)
}
extclient.OwnerID = userID
if _, err := logic.UpdateExtClient(extclient.ClientID, extclient.Network, extclient.Enabled, &extclient); err != nil {
logger.Log(0, "failed to add owner id", userID, "to client", extclient.ClientID)
}
}
}
logger.Log(0, r.Header.Get("user"), "created new ext client on network", networkName) logger.Log(0, r.Header.Get("user"), "created new ext client on network", networkName)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
err = mq.PublishExtPeerUpdate(&node) err = mq.PublishExtPeerUpdate(&node)
@@ -375,7 +400,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
clientid := params["clientid"] clientid := params["clientid"]
@@ -385,7 +410,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v", fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
@@ -393,22 +418,46 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v", fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v",
key, clientid, network, err)) key, clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil { if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil {
logger.Log(0, "error unmarshalling extclient: ", logger.Log(0, "error unmarshalling extclient: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
// == PRO ==
networkName := params["network"]
var changedID = newExtClient.ClientID != oldExtClient.ClientID
if r.Header.Get("ismaster") != "yes" {
userID := r.Header.Get("user")
_, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName)
if !doesOwn {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
return
}
}
if changedID && oldExtClient.OwnerID != "" {
if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, networkName, oldExtClient.ClientID); err != nil {
logger.Log(0, "failed to dissociate client", oldExtClient.ClientID, "from user", oldExtClient.OwnerID)
}
if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, networkName, newExtClient.ClientID); err != nil {
logger.Log(0, "failed to associate client", newExtClient.ClientID, "to user", oldExtClient.OwnerID)
}
}
// == END PRO ==
var changedEnabled = newExtClient.Enabled != oldExtClient.Enabled // indicates there was a change in enablement var changedEnabled = newExtClient.Enabled != oldExtClient.Enabled // indicates there was a change in enablement
newclient, err := logic.UpdateExtClient(newExtClient.ClientID, params["network"], newExtClient.Enabled, &oldExtClient) newclient, err := logic.UpdateExtClient(newExtClient.ClientID, params["network"], newExtClient.Enabled, &oldExtClient)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update ext client [%s], network [%s]: %v", fmt.Sprintf("failed to update ext client [%s], network [%s]: %v",
clientid, network, err)) clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID) logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID)
@@ -448,23 +497,41 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
err = errors.New("Could not delete extclient " + params["clientid"]) err = errors.New("Could not delete extclient " + params["clientid"])
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID) ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err)) fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
// == PRO ==
if r.Header.Get("ismaster") != "yes" {
userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"]
_, doesOwn := doesUserOwnClient(userID, clientID, networkName)
if !doesOwn {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
return
}
}
if extclient.OwnerID != "" {
if err = pro.DissociateNetworkUserClient(extclient.OwnerID, extclient.Network, extclient.ClientID); err != nil {
logger.Log(0, "failed to dissociate client", extclient.ClientID, "from user", extclient.OwnerID)
}
}
// == END PRO ==
err = logic.DeleteExtClient(params["network"], params["clientid"]) err = logic.DeleteExtClient(params["network"], params["clientid"])
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err)) fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
err = errors.New("Could not delete extclient " + params["clientid"]) err = errors.New("Could not delete extclient " + params["clientid"])
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -472,7 +539,65 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(1, "error setting ext peers on "+ingressnode.ID+": "+err.Error()) logger.Log(1, "error setting ext peers on "+ingressnode.ID+": "+err.Error())
} }
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"Deleted extclient client", params["clientid"], "from network", params["network"]) "Deleted extclient client", params["clientid"], "from network", params["network"])
returnSuccessResponse(w, r, params["clientid"]+" deleted.") logic.ReturnSuccessResponse(w, r, params["clientid"]+" deleted.")
}
func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) {
u, err := logic.GetUser(username)
if err != nil {
return false, err
}
if u.IsAdmin {
return true, nil
}
netUser, err := pro.GetNetworkUser(network.NetID, promodels.NetworkUserID(u.UserName))
if err != nil {
return false, err
}
if netUser.AccessLevel == pro.NET_ADMIN {
return false, nil
}
if netUser.AccessLevel == pro.NO_ACCESS {
return false, fmt.Errorf("user does not have access")
}
if !(len(netUser.Clients) < netUser.ClientLimit) {
return false, fmt.Errorf("user can not create more clients")
}
if netUser.AccessLevel < pro.NO_ACCESS {
netUser.Clients = append(netUser.Clients, clientID)
if err = pro.UpdateNetworkUser(network.NetID, netUser); err != nil {
return false, err
}
}
return false, nil
}
// checks if net user owns an ext client or is an admin
func doesUserOwnClient(username, clientID, network string) (bool, bool) {
u, err := logic.GetUser(username)
if err != nil {
return false, false
}
if u.IsAdmin {
return true, true
}
netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(u.UserName))
if err != nil {
return false, false
}
if netUser.AccessLevel == pro.NET_ADMIN {
return false, true
}
return false, logic.StringSliceContains(netUser.Clients, clientID)
} }

58
controllers/limits.go Normal file
View File

@@ -0,0 +1,58 @@
package controller
import (
"net/http"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// limit consts
const (
node_l = 0
networks_l = 1
users_l = 2
clients_l = 3
)
func checkFreeTierLimits(limit_choice int, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks",
}
if logic.Free_Tier && logic.Is_EE { // check that free tier limits not exceeded
if limit_choice == networks_l {
currentNetworks, err := logic.GetNetworks()
if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= logic.Networks_Limit {
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
} else if limit_choice == node_l {
nodes, err := logic.GetAllNodes()
if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= logic.Node_Limit {
errorResponse.Message = "free tier limits exceeded on nodes"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
} else if limit_choice == users_l {
users, err := logic.GetUsers()
if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= logic.Users_Limit {
errorResponse.Message = "free tier limits exceeded on users"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
} else if limit_choice == clients_l {
clients, err := logic.GetAllExtClients()
if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= logic.Clients_Limit {
errorResponse.Message = "free tier limits exceeded on external clients"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
}
}
next.ServeHTTP(w, r)
}
}

23
controllers/logger.go Normal file
View File

@@ -0,0 +1,23 @@
package controller
import (
"fmt"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
)
func loggerHandlers(r *mux.Router) {
r.HandleFunc("/api/logs", logic.SecurityCheck(true, http.HandlerFunc(getLogs))).Methods("GET")
}
func getLogs(w http.ResponseWriter, r *http.Request) {
var currentTime = time.Now().Format(logger.TimeFormatDay)
var currentFilePath = fmt.Sprintf("data/netmaker.log.%s", currentTime)
logger.DumpFile(currentFilePath)
w.WriteHeader(http.StatusOK)
w.Write([]byte(logger.Retrieve(currentFilePath)))
}

View File

@@ -17,26 +17,20 @@ import (
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
) )
// ALL_NETWORK_ACCESS - represents all networks
const ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
// NO_NETWORKS_PRESENT - represents no networks
const NO_NETWORKS_PRESENT = "THIS_USER_HAS_NONE"
func networkHandlers(r *mux.Router) { func networkHandlers(r *mux.Router) {
r.HandleFunc("/api/networks", securityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET") r.HandleFunc("/api/networks", logic.SecurityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET")
r.HandleFunc("/api/networks", securityCheck(true, http.HandlerFunc(createNetwork))).Methods("POST") r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}/nodelimit", securityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}/nodelimit", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}", securityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE") r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE")
r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST") r.HandleFunc("/api/networks/{networkname}/keyupdate", logic.SecurityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST") r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET") r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE") r.HandleFunc("/api/networks/{networkname}/keys/{name}", logic.SecurityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
// ACLs // ACLs
r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT")
r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET") r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET")
} }
// swagger:route GET /api/networks networks getNetworks // swagger:route GET /api/networks networks getNetworks
@@ -58,16 +52,16 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
if marshalErr != nil { if marshalErr != nil {
logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ", logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ",
marshalErr.Error()) marshalErr.Error())
returnErrorResponse(w, r, formatError(marshalErr, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "badrequest"))
return return
} }
allnetworks := []models.Network{} allnetworks := []models.Network{}
var err error var err error
if networksSlice[0] == ALL_NETWORK_ACCESS { if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
allnetworks, err = logic.GetNetworks() allnetworks, err = logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error()) logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -110,7 +104,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v",
netname, err)) netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if !servercfg.IsDisplayKeys() { if !servercfg.IsDisplayKeys() {
@@ -140,7 +134,7 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v",
netname, err)) netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "updated key on network", netname) logger.Log(2, r.Header.Get("user"), "updated key on network", netname)
@@ -182,7 +176,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get network info: ", logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var newNetwork models.Network var newNetwork models.Network
@@ -190,7 +184,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -199,21 +193,39 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
newNetwork.DefaultPostUp = network.DefaultPostUp newNetwork.DefaultPostUp = network.DefaultPostUp
} }
rangeupdate4, rangeupdate6, localrangeupdate, holepunchupdate, err := logic.UpdateNetwork(&network, &newNetwork) rangeupdate4, rangeupdate6, localrangeupdate, holepunchupdate, groupsDelta, userDelta, err := logic.UpdateNetwork(&network, &newNetwork)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update network: ", logger.Log(0, r.Header.Get("user"), "failed to update network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if len(groupsDelta) > 0 {
for _, g := range groupsDelta {
users, err := logic.GetGroupUsers(g)
if err == nil {
for _, user := range users {
logic.AdjustNetworkUserPermissions(&user, &newNetwork)
}
}
}
}
if len(userDelta) > 0 {
for _, uname := range userDelta {
user, err := logic.GetReturnUser(uname)
if err == nil {
logic.AdjustNetworkUserPermissions(&user, &newNetwork)
}
}
}
if rangeupdate4 { if rangeupdate4 {
err = logic.UpdateNetworkNodeAddresses(network.NetID) err = logic.UpdateNetworkNodeAddresses(network.NetID)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v", fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -223,7 +235,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v", fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -233,7 +245,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] local addresses: %v", fmt.Sprintf("failed to update network [%s] local addresses: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -243,7 +255,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update network [%s] hole punching: %v", fmt.Sprintf("failed to update network [%s] hole punching: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -253,7 +265,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] nodes: %v", fmt.Sprintf("failed to get network [%s] nodes: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
for _, node := range nodes { for _, node := range nodes {
@@ -287,7 +299,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] nodes: %v", fmt.Sprintf("failed to get network [%s] nodes: %v",
network.NetID, err.Error())) network.NetID, err.Error()))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -297,7 +309,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if networkChange.NodeLimit != 0 { if networkChange.NodeLimit != 0 {
@@ -306,7 +318,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME) database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME)
@@ -336,21 +348,21 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = json.NewDecoder(r.Body).Decode(&networkACLChange) err = json.NewDecoder(r.Body).Decode(&networkACLChange)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
newNetACL, err := networkACLChange.Save(acls.ContainerID(netname)) newNetACL, err := networkACLChange.Save(acls.ContainerID(netname))
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname) logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname)
@@ -394,7 +406,7 @@ func getNetworkACL(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err)) fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname) logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname)
@@ -427,7 +439,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete network [%s]: %v", network, err)) fmt.Sprintf("failed to delete network [%s]: %v", network, err))
returnErrorResponse(w, r, formatError(err, errtype)) logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted network", network) logger.Log(1, r.Header.Get("user"), "deleted network", network)
@@ -457,7 +469,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -465,7 +477,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
err := errors.New("IPv4 or IPv6 CIDR required") err := errors.New("IPv4 or IPv6 CIDR required")
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -473,7 +485,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -486,7 +498,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), "failed to create network: ", logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -519,23 +531,32 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get network info: ", logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
err = json.NewDecoder(r.Body).Decode(&accesskey) err = json.NewDecoder(r.Body).Decode(&accesskey)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
key, err := logic.CreateAccessKey(accesskey, network) key, err := logic.CreateAccessKey(accesskey, network)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create access key: ", logger.Log(0, r.Header.Get("user"), "failed to create access key: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
// do not allow access key creations view API with user names
if _, err = logic.GetUser(key.Name); err == nil {
logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user"))
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("cannot create access key with user name"), "badrequest"))
logic.DeleteKey(key.Name, network.NetID)
return
}
logger.Log(1, r.Header.Get("user"), "created access key", accesskey.Name, "on", netname) logger.Log(1, r.Header.Get("user"), "created access key", accesskey.Name, "on", netname)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(key) json.NewEncoder(w).Encode(key)
@@ -560,7 +581,7 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v",
network, err)) network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if !servercfg.IsDisplayKeys() { if !servercfg.IsDisplayKeys() {
@@ -594,7 +615,7 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v", logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v",
keyname, netname, err)) keyname, netname, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname) logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)

View File

@@ -17,7 +17,7 @@ type NetworkValidationTestCase struct {
} }
func TestCreateNetwork(t *testing.T) { func TestCreateNetwork(t *testing.T) {
database.InitializeDatabase() initialize()
deleteAllNetworks() deleteAllNetworks()
var network models.Network var network models.Network
@@ -30,7 +30,7 @@ func TestCreateNetwork(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestGetNetwork(t *testing.T) { func TestGetNetwork(t *testing.T) {
database.InitializeDatabase() initialize()
createNet() createNet()
t.Run("GetExistingNetwork", func(t *testing.T) { t.Run("GetExistingNetwork", func(t *testing.T) {
@@ -46,7 +46,7 @@ func TestGetNetwork(t *testing.T) {
} }
func TestDeleteNetwork(t *testing.T) { func TestDeleteNetwork(t *testing.T) {
database.InitializeDatabase() initialize()
createNet() createNet()
//create nodes //create nodes
t.Run("NetworkwithNodes", func(t *testing.T) { t.Run("NetworkwithNodes", func(t *testing.T) {
@@ -62,7 +62,7 @@ func TestDeleteNetwork(t *testing.T) {
} }
func TestCreateKey(t *testing.T) { func TestCreateKey(t *testing.T) {
database.InitializeDatabase() initialize()
createNet() createNet()
keys, _ := logic.GetKeys("skynet") keys, _ := logic.GetKeys("skynet")
for _, key := range keys { for _, key := range keys {
@@ -74,7 +74,7 @@ func TestCreateKey(t *testing.T) {
t.Run("NameTooLong", func(t *testing.T) { t.Run("NameTooLong", func(t *testing.T) {
network, err := logic.GetNetwork("skynet") network, err := logic.GetNetwork("skynet")
assert.Nil(t, err) assert.Nil(t, err)
accesskey.Name = "Thisisareallylongkeynamethatwillfail" accesskey.Name = "ThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfail"
_, err = logic.CreateAccessKey(accesskey, network) _, err = logic.CreateAccessKey(accesskey, network)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag") assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag")
@@ -134,7 +134,7 @@ func TestCreateKey(t *testing.T) {
} }
func TestGetKeys(t *testing.T) { func TestGetKeys(t *testing.T) {
database.InitializeDatabase() initialize()
deleteAllNetworks() deleteAllNetworks()
createNet() createNet()
network, err := logic.GetNetwork("skynet") network, err := logic.GetNetwork("skynet")
@@ -157,7 +157,7 @@ func TestGetKeys(t *testing.T) {
}) })
} }
func TestDeleteKey(t *testing.T) { func TestDeleteKey(t *testing.T) {
database.InitializeDatabase() initialize()
createNet() createNet()
network, err := logic.GetNetwork("skynet") network, err := logic.GetNetwork("skynet")
assert.Nil(t, err) assert.Nil(t, err)
@@ -179,27 +179,27 @@ func TestDeleteKey(t *testing.T) {
func TestSecurityCheck(t *testing.T) { func TestSecurityCheck(t *testing.T) {
//these seem to work but not sure it the tests are really testing the functionality //these seem to work but not sure it the tests are really testing the functionality
database.InitializeDatabase() initialize()
os.Setenv("MASTER_KEY", "secretkey") os.Setenv("MASTER_KEY", "secretkey")
t.Run("NoNetwork", func(t *testing.T) { t.Run("NoNetwork", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("WithNetwork", func(t *testing.T) { t.Run("WithNetwork", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "skynet", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadNet", func(t *testing.T) { t.Run("BadNet", func(t *testing.T) {
t.Skip() t.Skip()
networks, username, err := SecurityCheck(false, "badnet", "Bearer secretkey") networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadToken", func(t *testing.T) { t.Run("BadToken", func(t *testing.T) {
networks, username, err := SecurityCheck(false, "skynet", "Bearer badkey") networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)
@@ -210,7 +210,7 @@ func TestValidateNetwork(t *testing.T) {
//t.Skip() //t.Skip()
//This functions is not called by anyone //This functions is not called by anyone
//it panics as validation function 'display_name_valid' is not defined //it panics as validation function 'display_name_valid' is not defined
database.InitializeDatabase() initialize()
//yes := true //yes := true
//no := false //no := false
//deleteNet(t) //deleteNet(t)
@@ -295,7 +295,7 @@ func TestValidateNetwork(t *testing.T) {
func TestIpv6Network(t *testing.T) { func TestIpv6Network(t *testing.T) {
//these seem to work but not sure it the tests are really testing the functionality //these seem to work but not sure it the tests are really testing the functionality
database.InitializeDatabase() initialize()
os.Setenv("MASTER_KEY", "secretkey") os.Setenv("MASTER_KEY", "secretkey")
deleteAllNetworks() deleteAllNetworks()
createNet() createNet()
@@ -321,6 +321,21 @@ func deleteAllNetworks() {
} }
} }
func initialize() {
database.InitializeDatabase()
createAdminUser()
}
func createAdminUser() {
logic.CreateAdmin(models.User{
UserName: "admin",
Password: "password",
IsAdmin: true,
Networks: []string{},
Groups: []string{},
})
}
func createNet() { func createNet() {
var network models.Network var network models.Network
network.NetID = "skynet" network.NetID = "skynet"

365
controllers/networkusers.go Normal file
View File

@@ -0,0 +1,365 @@
package controller
import (
"encoding/json"
"errors"
"net/http"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
)
func networkUsersHandlers(r *mux.Router) {
r.HandleFunc("/api/networkusers", logic.SecurityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST")
r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT")
r.HandleFunc("/api/networkusers/data/{networkuser}/me", logic.NetUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET")
r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE")
}
// == RETURN TYPES ==
// NetworkName - represents a network name/ID
type NetworkName string
// NetworkUserDataMap - map of all data per network for a user
type NetworkUserDataMap map[NetworkName]NetworkUserData
// NetworkUserData - data struct for network users
type NetworkUserData struct {
Nodes []models.Node `json:"nodes" bson:"nodes" yaml:"nodes"`
Clients []models.ExtClient `json:"clients" bson:"clients" yaml:"clients"`
Vpn []models.Node `json:"vpns" bson:"vpns" yaml:"vpns"`
Networks []models.Network `json:"networks" bson:"networks" yaml:"networks"`
User promodels.NetworkUser `json:"user" bson:"user" yaml:"user"`
}
// == END RETURN TYPES ==
// returns a map of a network user's data across all networks
func getNetworkUserData(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
networkUserName := params["networkuser"]
logger.Log(1, r.Header.Get("user"), "requested fetching network user data for user", networkUserName)
networks, err := logic.GetNetworks()
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if networkUserName == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
return
}
u, err := logic.GetUser(networkUserName)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("could not find user"), "badrequest"))
return
}
// initialize the return data of network users
returnData := make(NetworkUserDataMap)
// go through each network and get that user's data
// if user has no access, give no data
// if user is a net admin, give all nodes
// if user has node access, give user's nodes if any
// if user has client access, git user's clients if any
for i := range networks {
netID := networks[i].NetID
newData := NetworkUserData{
Nodes: []models.Node{},
Clients: []models.ExtClient{},
Vpn: []models.Node{},
Networks: []models.Network{},
}
netUser, err := pro.GetNetworkUser(netID, promodels.NetworkUserID(networkUserName))
// check if user has access
if err == nil && netUser.AccessLevel != pro.NO_ACCESS {
newData.User = promodels.NetworkUser{
AccessLevel: netUser.AccessLevel,
ClientLimit: netUser.ClientLimit,
NodeLimit: netUser.NodeLimit,
Nodes: netUser.Nodes,
Clients: netUser.Clients,
}
newData.User.SetDefaults()
// check network level permissions
if doesNetworkAllow := pro.IsUserAllowed(&networks[i], networkUserName, u.Groups); doesNetworkAllow || netUser.AccessLevel == pro.NET_ADMIN {
netNodes, err := logic.GetNetworkNodes(netID)
if err != nil {
if database.IsEmptyRecord(err) && netUser.AccessLevel == pro.NET_ADMIN {
newData.Networks = append(newData.Networks, networks[i])
} else {
logger.Log(0, "failed to retrieve nodes on network", netID, "for user", string(netUser.ID))
}
} else {
if netUser.AccessLevel <= pro.NODE_ACCESS { // handle nodes
// if access level is NODE_ACCESS, filter nodes
if netUser.AccessLevel == pro.NODE_ACCESS {
for i := range netNodes {
if logic.StringSliceContains(netUser.Nodes, netNodes[i].ID) {
newData.Nodes = append(newData.Nodes, netNodes[i])
}
}
} else { // net admin so, get all nodes and ext clients on network...
newData.Nodes = netNodes
for i := range netNodes {
if netNodes[i].IsIngressGateway == "yes" {
newData.Vpn = append(newData.Vpn, netNodes[i])
if clients, err := logic.GetExtClientsByID(netNodes[i].ID, netID); err == nil {
newData.Clients = append(newData.Clients, clients...)
}
}
}
newData.Networks = append(newData.Networks, networks[i])
}
}
if netUser.AccessLevel <= pro.CLIENT_ACCESS && netUser.AccessLevel != pro.NET_ADMIN {
for _, c := range netUser.Clients {
if client, err := logic.GetExtClient(c, netID); err == nil {
newData.Clients = append(newData.Clients, client)
}
}
for i := range netNodes {
if netNodes[i].IsIngressGateway == "yes" {
newData.Vpn = append(newData.Vpn, netNodes[i])
}
}
}
}
}
returnData[NetworkName(netID)] = newData
}
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(returnData)
}
// returns a map of all network users mapped to each network
func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
logger.Log(1, r.Header.Get("user"), "requested fetching all network users")
type allNetworkUsers = map[string][]promodels.NetworkUser
networks, err := logic.GetNetworks()
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
var allNetUsers = make(allNetworkUsers, len(networks))
for i := range networks {
netusers, err := pro.GetNetworkUsers(networks[i].NetID)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
for _, v := range netusers {
allNetUsers[networks[i].NetID] = append(allNetUsers[networks[i].NetID], v)
}
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(allNetUsers)
}
func getNetworkUsers(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["network"]
logger.Log(1, r.Header.Get("user"), "requested fetching network users for network", netname)
_, err := logic.GetNetwork(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
netusers, err := pro.GetNetworkUsers(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(netusers)
}
func getNetworkUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["network"]
logger.Log(1, r.Header.Get("user"), "requested fetching network user", params["networkuser"], "on network", netname)
_, err := logic.GetNetwork(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
netuserToGet := params["networkuser"]
if netuserToGet == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
return
}
netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(netuser)
}
func createNetworkUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["network"]
logger.Log(1, r.Header.Get("user"), "requested creating a network user on network", netname)
network, err := logic.GetNetwork(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
var networkuser promodels.NetworkUser
// we decode our body request params
err = json.NewDecoder(r.Body).Decode(&networkuser)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
err = pro.CreateNetworkUser(&network, &networkuser)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
w.WriteHeader(http.StatusOK)
}
func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["network"]
logger.Log(1, r.Header.Get("user"), "requested updating a network user on network", netname)
network, err := logic.GetNetwork(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
var networkuser promodels.NetworkUser
// we decode our body request params
err = json.NewDecoder(r.Body).Decode(&networkuser)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
return
}
if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user access level provided"), "badrequest"))
return
}
if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("negative user limit provided"), "badrequest"))
return
}
u, err := logic.GetUser(string(networkuser.ID))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
return
}
if !pro.IsUserAllowed(&network, u.UserName, u.Groups) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user must be in allowed groups or users"), "badrequest"))
return
}
if networkuser.AccessLevel == pro.NET_ADMIN {
currentUser, err := logic.GetUser(string(networkuser.ID))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest"))
return
}
if !logic.StringSliceContains(currentUser.Networks, netname) {
// append network name to user model to conform to old model
if err = logic.UpdateUserNetworks(
append(currentUser.Networks, netname),
currentUser.Groups,
currentUser.IsAdmin,
&models.ReturnUser{
Groups: currentUser.Groups,
IsAdmin: currentUser.IsAdmin,
Networks: currentUser.Networks,
UserName: currentUser.UserName,
},
); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest"))
return
}
}
}
err = pro.UpdateNetworkUser(netname, &networkuser)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
w.WriteHeader(http.StatusOK)
}
func deleteNetworkUser(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
netname := params["network"]
logger.Log(1, r.Header.Get("user"), "requested deleting network user", params["networkuser"], "on network", netname)
_, err := logic.GetNetwork(netname)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
netuserToDelete := params["networkuser"]
if netuserToDelete == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return
}
if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
}

View File

@@ -8,10 +8,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -28,10 +29,10 @@ func nodeHandlers(r *mux.Router) {
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST")
r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST")
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE")
r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST")
r.HandleFunc("/api/nodes/{network}", nodeauth(http.HandlerFunc(createNode))).Methods("POST") r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST")
r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET") r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET")
r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods("POST") r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods("POST")
} }
@@ -64,19 +65,19 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = decoderErr.Error() errorResponse.Message = decoderErr.Error()
logger.Log(0, request.Header.Get("user"), "error decoding request body: ", logger.Log(0, request.Header.Get("user"), "error decoding request body: ",
decoderErr.Error()) decoderErr.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
errorResponse.Code = http.StatusBadRequest errorResponse.Code = http.StatusBadRequest
if authRequest.ID == "" { if authRequest.ID == "" {
errorResponse.Message = "W1R3: ID can't be empty" errorResponse.Message = "W1R3: ID can't be empty"
logger.Log(0, request.Header.Get("user"), errorResponse.Message) logger.Log(0, request.Header.Get("user"), errorResponse.Message)
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else if authRequest.Password == "" { } else if authRequest.Password == "" {
errorResponse.Message = "W1R3: Password can't be empty" errorResponse.Message = "W1R3: Password can't be empty"
logger.Log(0, request.Header.Get("user"), errorResponse.Message) logger.Log(0, request.Header.Get("user"), errorResponse.Message)
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
var err error var err error
@@ -87,7 +88,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err)) fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err))
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
@@ -97,7 +98,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error validating user password: ", err.Error()) "error validating user password: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} else { } else {
tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network) tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network)
@@ -107,7 +108,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = "Could not create Token" errorResponse.Message = "Could not create Token"
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
fmt.Sprintf("%s: %v", errorResponse.Message, err)) fmt.Sprintf("%s: %v", errorResponse.Message, err))
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
@@ -126,7 +127,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
errorResponse.Message = err.Error() errorResponse.Message = err.Error()
logger.Log(0, request.Header.Get("user"), logger.Log(0, request.Header.Get("user"),
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
response.WriteHeader(http.StatusOK) response.WriteHeader(http.StatusOK)
@@ -147,7 +148,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
token = tokenSplit[1] token = tokenSplit[1]
@@ -159,7 +160,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusNotFound, Message: "no networks", Code: http.StatusNotFound, Message: "no networks",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
for _, network := range networks { for _, network := range networks {
@@ -175,7 +176,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.", Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.",
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -192,16 +193,16 @@ func nodeauth(next http.Handler) http.HandlerFunc {
func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc { func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg, Code: http.StatusUnauthorized, Message: logic.Unauthorized_Msg,
} }
var params = mux.Vars(r) var params = mux.Vars(r)
networkexists, _ := functions.NetworkExists(params["network"]) networkexists, _ := logic.NetworkExists(params["network"])
//check that the request is for a valid network //check that the request is for a valid network
//if (networkCheck && !networkexists) || err != nil { //if (networkCheck && !networkexists) || err != nil {
if networkCheck && !networkexists { if networkCheck && !networkexists {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -218,7 +219,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
if len(tokenSplit) > 1 { if len(tokenSplit) > 1 {
authToken = tokenSplit[1] authToken = tokenSplit[1]
} else { } else {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
//check if node instead of user //check if node instead of user
@@ -234,9 +235,10 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
var nodeID = "" var nodeID = ""
username, networks, isadmin, errN := logic.VerifyUserToken(authToken) username, networks, isadmin, errN := logic.VerifyUserToken(authToken)
if errN != nil { if errN != nil {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
isnetadmin := isadmin isnetadmin := isadmin
if errN == nil && isadmin { if errN == nil && isadmin {
nodeID = "mastermac" nodeID = "mastermac"
@@ -244,7 +246,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
r.Header.Set("ismasterkey", "yes") r.Header.Set("ismasterkey", "yes")
} }
if !isadmin && params["network"] != "" { if !isadmin && params["network"] != "" {
if logic.StringSliceContains(networks, params["network"]) { if logic.StringSliceContains(networks, params["network"]) && pro.IsUserNetAdmin(params["network"], username) {
isnetadmin = true isnetadmin = true
} }
} }
@@ -266,7 +268,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
} else { } else {
node, err := logic.GetNodeByID(nodeID) node, err := logic.GetNodeByID(nodeID)
if err != nil { if err != nil {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
isAuthorized = (node.Network == params["network"]) isAuthorized = (node.Network == params["network"])
@@ -284,7 +286,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
} }
} }
if !isAuthorized { if !isAuthorized {
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} else { } else {
//If authorized, this function passes along it's request and output to the appropriate route function. //If authorized, this function passes along it's request and output to the appropriate route function.
@@ -321,7 +323,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err)) fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -355,7 +357,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
if err != nil && r.Header.Get("ismasterkey") != "yes" { if err != nil && r.Header.Get("ismasterkey") != "yes" {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error fetching user info: ", err.Error()) "error fetching user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var nodes []models.Node var nodes []models.Node
@@ -363,7 +365,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
nodes, err = logic.GetAllNodes() nodes, err = logic.GetAllNodes()
if err != nil { if err != nil {
logger.Log(0, "error fetching all nodes info: ", err.Error()) logger.Log(0, "error fetching all nodes info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} else { } else {
@@ -371,7 +373,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
"error fetching nodes: ", err.Error()) "error fetching nodes: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
} }
@@ -415,7 +417,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -423,7 +425,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err)) fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -435,6 +437,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
Node: node, Node: node,
Peers: peerUpdate.Peers, Peers: peerUpdate.Peers,
ServerConfig: servercfg.GetServerInfo(), ServerConfig: servercfg.GetServerInfo(),
PeerIDs: peerUpdate.PeerIDs,
} }
logger.Log(2, r.Header.Get("user"), "fetched node", params["nodeid"]) logger.Log(2, r.Header.Get("user"), "fetched node", params["nodeid"])
@@ -466,7 +469,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching network [%s] info: %v", networkName, err)) fmt.Sprintf("error fetching network [%s] info: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "called last modified") logger.Log(2, r.Header.Get("user"), "called last modified")
@@ -494,12 +497,12 @@ func createNode(w http.ResponseWriter, r *http.Request) {
Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.",
} }
networkName := params["network"] networkName := params["network"]
networkexists, err := functions.NetworkExists(networkName) networkexists, err := logic.NetworkExists(networkName)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err)) fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} else if !networkexists { } else if !networkexists {
errorResponse = models.ErrorResponse{ errorResponse = models.ErrorResponse{
@@ -507,7 +510,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
} }
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("network [%s] does not exist", networkName)) fmt.Sprintf("network [%s] does not exist", networkName))
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
@@ -517,7 +520,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
err = json.NewDecoder(r.Body).Decode(&node) err = json.NewDecoder(r.Body).Decode(&node)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
@@ -527,17 +530,17 @@ func createNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err)) fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
node.NetworkSettings, err = logic.GetNetworkSettings(node.Network) node.NetworkSettings, err = logic.GetNetworkSettings(node.Network)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err)) fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
validKey := logic.IsKeyValid(networkName, node.AccessKey) keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey)
if !validKey { if !validKey {
// Check to see if network will allow manual sign up // Check to see if network will allow manual sign up
// may want to switch this up with the valid key check and avoid a DB call that way. // may want to switch this up with the valid key check and avoid a DB call that way.
@@ -550,24 +553,32 @@ func createNode(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create node on network [%s]: %s", fmt.Sprintf("failed to create node on network [%s]: %s",
node.Network, errorResponse.Message)) node.Network, errorResponse.Message))
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
} }
user, err := pro.GetNetworkUser(networkName, promodels.NetworkUserID(keyName))
if err == nil {
if user.ID != "" {
logger.Log(1, "associating new node with user", keyName)
node.OwnerID = string(user.ID)
}
}
key, keyErr := logic.RetrievePublicTrafficKey() key, keyErr := logic.RetrievePublicTrafficKey()
if keyErr != nil { if keyErr != nil {
logger.Log(0, "error retrieving key: ", keyErr.Error()) logger.Log(0, "error retrieving key: ", keyErr.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if key == nil { if key == nil {
logger.Log(0, "error: server traffic key is nil") logger.Log(0, "error: server traffic key is nil")
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if node.TrafficKeys.Mine == nil { if node.TrafficKeys.Mine == nil {
logger.Log(0, "error: node traffic key is nil") logger.Log(0, "error: node traffic key is nil")
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
node.TrafficKeys = models.TrafficKeys{ node.TrafficKeys = models.TrafficKeys{
@@ -580,15 +591,33 @@ func createNode(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create node on network [%s]: %s", fmt.Sprintf("failed to create node on network [%s]: %s",
node.Network, err)) node.Network, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
// check if key belongs to a user
// if so add to their netuser data
// if it fails remove the node and fail request
if user != nil {
var updatedUserNode bool
user.Nodes = append(user.Nodes, node.ID) // add new node to user
if err = pro.UpdateNetworkUser(networkName, user); err == nil {
logger.Log(1, "added node", node.ID, node.Name, "to user", string(user.ID))
updatedUserNode = true
}
if !updatedUserNode { // user was found but not updated, so delete node
logger.Log(0, "failed to add node to user", keyName)
logic.DeleteNodeByID(&node, true)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
}
peerUpdate, err := logic.GetPeerUpdate(&node) peerUpdate, err := logic.GetPeerUpdate(&node)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err)) fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -596,6 +625,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
Node: node, Node: node,
Peers: peerUpdate.Peers, Peers: peerUpdate.Peers,
ServerConfig: servercfg.GetServerInfo(), ServerConfig: servercfg.GetServerInfo(),
PeerIDs: peerUpdate.PeerIDs,
} }
logger.Log(1, r.Header.Get("user"), "created new node", node.Name, "on network", node.Network) logger.Log(1, r.Header.Get("user"), "created new node", node.Name, "on network", node.Network)
@@ -625,7 +655,7 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err)) fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name) logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name)
@@ -655,7 +685,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&gateway) err := json.NewDecoder(r.Body).Decode(&gateway)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
gateway.NetID = params["network"] gateway.NetID = params["network"]
@@ -665,7 +695,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v",
gateway.NodeID, gateway.NetID, err)) gateway.NodeID, gateway.NetID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -697,7 +727,7 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -731,7 +761,7 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -763,7 +793,7 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v", fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v",
nodeid, netid, err)) nodeid, netid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -797,7 +827,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -806,7 +836,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
err = json.NewDecoder(r.Body).Decode(&newNode) err = json.NewDecoder(r.Body).Decode(&newNode)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
relayupdate := false relayupdate := false
@@ -854,7 +884,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err)) fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if relayupdate { if relayupdate {
@@ -901,25 +931,33 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err)) fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if isServer(&node) { if isServer(&node) {
err := fmt.Errorf("cannot delete server node") err := fmt.Errorf("cannot delete server node")
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err)) fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err))
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if r.Header.Get("ismaster") != "yes" {
username := r.Header.Get("user")
if username != "" && !doesUserOwnNode(username, params["network"], nodeid) {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "badrequest"))
return
}
}
//send update to node to be deleted before deleting on server otherwise message cannot be sent //send update to node to be deleted before deleting on server otherwise message cannot be sent
node.Action = models.NODE_DELETE node.Action = models.NODE_DELETE
err = logic.DeleteNodeByID(&node, false) err = logic.DeleteNodeByID(&node, false)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
returnSuccessResponse(w, r, nodeid+" deleted.")
logic.ReturnSuccessResponse(w, r, nodeid+" deleted.")
logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"]) logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"])
runUpdates(&node, false) runUpdates(&node, false)
@@ -1006,3 +1044,24 @@ func updateRelay(oldnode, newnode *models.Node) {
} }
logic.UpdateNode(relay, newrelay) logic.UpdateNode(relay, newrelay)
} }
func doesUserOwnNode(username, network, nodeID string) bool {
u, err := logic.GetUser(username)
if err != nil {
return false
}
if u.IsAdmin {
return true
}
netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(u.UserName))
if err != nil {
return false
}
if netUser.AccessLevel == pro.NET_ADMIN {
return true
}
return logic.StringSliceContains(netUser.Nodes, nodeID)
}

View File

@@ -30,7 +30,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&relay) err := json.NewDecoder(r.Body).Decode(&relay)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
relay.NetID = params["network"] relay.NetID = params["network"]
@@ -39,7 +39,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err)) fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err))
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID) logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID)
@@ -73,7 +73,7 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
updatenodes, node, err := logic.DeleteRelay(netid, nodeid) updatenodes, node, err := logic.DeleteRelay(netid, nodeid)
if err != nil { if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid) logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid)

View File

@@ -7,12 +7,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestFormatError(t *testing.T) { func TestFormatError(t *testing.T) {
response := formatError(errors.New("this is a sample error"), "badrequest") response := logic.FormatError(errors.New("this is a sample error"), "badrequest")
assert.Equal(t, http.StatusBadRequest, response.Code) assert.Equal(t, http.StatusBadRequest, response.Code)
assert.Equal(t, "this is a sample error", response.Message) assert.Equal(t, "this is a sample error", response.Message)
} }
@@ -20,7 +21,7 @@ func TestFormatError(t *testing.T) {
func TestReturnSuccessResponse(t *testing.T) { func TestReturnSuccessResponse(t *testing.T) {
var response models.SuccessResponse var response models.SuccessResponse
handler := func(rw http.ResponseWriter, r *http.Request) { handler := func(rw http.ResponseWriter, r *http.Request) {
returnSuccessResponse(rw, r, "This is a test message") logic.ReturnSuccessResponse(rw, r, "This is a test message")
} }
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -42,7 +43,7 @@ func TestReturnErrorResponse(t *testing.T) {
errMessage.Code = http.StatusUnauthorized errMessage.Code = http.StatusUnauthorized
errMessage.Message = "You are not authorized to access this endpoint" errMessage.Message = "You are not authorized to access this endpoint"
handler := func(rw http.ResponseWriter, r *http.Request) { handler := func(rw http.ResponseWriter, r *http.Request) {
returnErrorResponse(rw, r, errMessage) logic.ReturnErrorResponse(rw, r, errMessage)
} }
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -1,129 +0,0 @@
package controller
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
const (
master_uname = "masteradministrator"
unauthorized_msg = "unauthorized"
unauthorized_err = models.Error(unauthorized_msg)
)
func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg,
}
var params = mux.Vars(r)
bearerToken := r.Header.Get("Authorization")
// to have a custom DNS service adding entries
// we should refactor this, but is for the special case of an external service to query the DNS api
if strings.Contains(r.RequestURI, "/dns") && strings.ToUpper(r.Method) == "GET" && authenticateDNSToken(bearerToken) {
// do dns stuff
r.Header.Set("user", "nameserver")
networks, _ := json.Marshal([]string{ALL_NETWORK_ACCESS})
r.Header.Set("networks", string(networks))
next.ServeHTTP(w, r)
return
}
var networkName = params["networkname"]
if len(networkName) == 0 {
networkName = params["network"]
}
networks, username, err := SecurityCheck(reqAdmin, networkName, bearerToken)
if err != nil {
returnErrorResponse(w, r, errorResponse)
return
}
networksJson, err := json.Marshal(&networks)
if err != nil {
returnErrorResponse(w, r, errorResponse)
return
}
r.Header.Set("user", username)
r.Header.Set("networks", string(networksJson))
next.ServeHTTP(w, r)
}
}
// SecurityCheck - checks token stuff
func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) {
var tokenSplit = strings.Split(token, " ")
var authToken = ""
userNetworks := []string{}
if len(tokenSplit) < 2 {
return userNetworks, "", unauthorized_err
} else {
authToken = tokenSplit[1]
}
//all endpoints here require master so not as complicated
if authenticateMaster(authToken) {
return []string{ALL_NETWORK_ACCESS}, master_uname, nil
}
username, networks, isadmin, err := logic.VerifyUserToken(authToken)
if err != nil {
return nil, username, unauthorized_err
}
if !isadmin && reqAdmin {
return nil, username, unauthorized_err
}
userNetworks = networks
if isadmin {
return []string{ALL_NETWORK_ACCESS}, username, nil
}
// check network admin access
if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) {
return nil, username, unauthorized_err
}
return userNetworks, username, nil
}
// Consider a more secure way of setting master key
func authenticateMaster(tokenString string) bool {
return tokenString == servercfg.GetMasterKey() && servercfg.GetMasterKey() != ""
}
func authenticateNetworkUser(network string, userNetworks []string) bool {
networkexists, err := functions.NetworkExists(network)
if (err != nil && !database.IsEmptyRecord(err)) || !networkexists {
return false
}
return logic.StringSliceContains(userNetworks, network)
}
//Consider a more secure way of setting master key
func authenticateDNSToken(tokenString string) bool {
tokens := strings.Split(tokenString, " ")
if len(tokens) < 2 {
return false
}
return tokens[1] == servercfg.GetDNSKey()
}
func continueIfUserMatch(next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: unauthorized_msg,
}
var params = mux.Vars(r)
var requestedUser = params["username"]
if requestedUser != r.Header.Get("user") {
returnErrorResponse(w, r, errorResponse)
return
}
next.ServeHTTP(w, r)
}
}

View File

@@ -21,46 +21,29 @@ import (
func serverHandlers(r *mux.Router) { func serverHandlers(r *mux.Router) {
// r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST") // r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST")
r.HandleFunc("/api/server/getconfig", securityCheckServer(false, http.HandlerFunc(getConfig))).Methods("GET") r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))).Methods("GET")
r.HandleFunc("/api/server/removenetwork/{network}", securityCheckServer(true, http.HandlerFunc(removeNetwork))).Methods("DELETE")
r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST") r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST")
r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET") r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET")
} }
//Security check is middleware for every function and just checks to make sure that its the master calling // allowUsers - allow all authenticated (valid) users - only used by getConfig, may be able to remove during refactor
//Only admin should have access to all these network-level actions func allowUsers(next http.Handler) http.HandlerFunc {
//or maybe some Users once implemented
func securityCheckServer(adminonly bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{ var errorResponse = models.ErrorResponse{
Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", Code: http.StatusInternalServerError, Message: logic.Unauthorized_Msg,
} }
bearerToken := r.Header.Get("Authorization") bearerToken := r.Header.Get("Authorization")
var tokenSplit = strings.Split(bearerToken, " ") var tokenSplit = strings.Split(bearerToken, " ")
var authToken = "" var authToken = ""
if len(tokenSplit) < 2 { if len(tokenSplit) < 2 {
errorResponse = models.ErrorResponse{ logic.ReturnErrorResponse(w, r, errorResponse)
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
}
returnErrorResponse(w, r, errorResponse)
return return
} else { } else {
authToken = tokenSplit[1] authToken = tokenSplit[1]
} }
//all endpoints here require master so not as complicated user, _, _, err := logic.VerifyUserToken(authToken)
//still might not be a good way of doing this if err != nil || user == "" {
user, _, isadmin, err := logic.VerifyUserToken(authToken) logic.ReturnErrorResponse(w, r, errorResponse)
errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
}
if !adminonly && (err != nil || user == "") {
returnErrorResponse(w, r, errorResponse)
return
}
if adminonly && !isadmin && !authenticateMaster(authToken) {
returnErrorResponse(w, r, errorResponse)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -136,6 +119,10 @@ func getConfig(w http.ResponseWriter, r *http.Request) {
// get params // get params
scfg := servercfg.GetServerConfig() scfg := servercfg.GetServerConfig()
scfg.IsEE = "no"
if logic.Is_EE {
scfg.IsEE = "yes"
}
json.NewEncoder(w).Encode(scfg) json.NewEncoder(w).Encode(scfg)
//w.WriteHeader(http.StatusOK) //w.WriteHeader(http.StatusOK)
} }
@@ -161,7 +148,7 @@ func register(w http.ResponseWriter, r *http.Request) {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusBadRequest, Message: err.Error(), Code: http.StatusBadRequest, Message: err.Error(),
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
cert, ca, err := genCerts(&request.Key, &request.CommonName) cert, ca, err := genCerts(&request.Key, &request.CommonName)
@@ -170,7 +157,7 @@ func register(w http.ResponseWriter, r *http.Request) {
errorResponse := models.ErrorResponse{ errorResponse := models.ErrorResponse{
Code: http.StatusNotFound, Message: err.Error(), Code: http.StatusNotFound, Message: err.Error(),
} }
returnErrorResponse(w, r, errorResponse) logic.ReturnErrorResponse(w, r, errorResponse)
return return
} }
//x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte //x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte

View File

@@ -7,11 +7,17 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/auth" "github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
var (
upgrader = websocket.Upgrader{}
) )
func userHandlers(r *mux.Router) { func userHandlers(r *mux.Router) {
@@ -19,15 +25,17 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET") r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET")
r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST") r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST")
r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST") r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST")
r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT")
r.HandleFunc("/api/users/networks/{username}", securityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT") r.HandleFunc("/api/users/networks/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT")
r.HandleFunc("/api/users/{username}/adm", securityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT") r.HandleFunc("/api/users/{username}/adm", logic.SecurityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT")
r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(createUser))).Methods("POST") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST")
r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE")
r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET") r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET")
r.HandleFunc("/api/users", securityCheck(true, http.HandlerFunc(getUsers))).Methods("GET") r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods("GET")
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
r.HandleFunc("/api/oauth/node-handler", socketHandler)
r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterNodeSSO).Methods("GET")
} }
// swagger:route POST /api/users/adm/authenticate user authenticateUser // swagger:route POST /api/users/adm/authenticate user authenticateUser
@@ -50,13 +58,18 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.",
} }
if !servercfg.IsBasicAuthEnabled() {
logic.ReturnErrorResponse(response, request, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
return
}
decoder := json.NewDecoder(request.Body) decoder := json.NewDecoder(request.Body)
decoderErr := decoder.Decode(&authRequest) decoderErr := decoder.Decode(&authRequest)
defer request.Body.Close() defer request.Body.Close()
if decoderErr != nil { if decoderErr != nil {
logger.Log(0, "error decoding request body: ", logger.Log(0, "error decoding request body: ",
decoderErr.Error()) decoderErr.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
username := authRequest.UserName username := authRequest.UserName
@@ -64,14 +77,14 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "user validation failed: ", logger.Log(0, username, "user validation failed: ",
err.Error()) err.Error())
returnErrorResponse(response, request, formatError(err, "badrequest")) logic.ReturnErrorResponse(response, request, logic.FormatError(err, "badrequest"))
return return
} }
if jwt == "" { if jwt == "" {
// very unlikely that err is !nil and no jwt returned, but handle it anyways. // very unlikely that err is !nil and no jwt returned, but handle it anyways.
logger.Log(0, username, "jwt token is empty") logger.Log(0, username, "jwt token is empty")
returnErrorResponse(response, request, formatError(errors.New("no token returned"), "internal")) logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("no token returned"), "internal"))
return return
} }
@@ -89,7 +102,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
if jsonError != nil { if jsonError != nil {
logger.Log(0, username, logger.Log(0, username,
"error marshalling resp: ", err.Error()) "error marshalling resp: ", err.Error())
returnErrorResponse(response, request, errorResponse) logic.ReturnErrorResponse(response, request, errorResponse)
return return
} }
logger.Log(2, username, "was authenticated") logger.Log(2, username, "was authenticated")
@@ -115,7 +128,7 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
hasadmin, err := logic.HasAdmin() hasadmin, err := logic.HasAdmin()
if err != nil { if err != nil {
logger.Log(0, "failed to check for admin: ", err.Error()) logger.Log(0, "failed to check for admin: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -158,7 +171,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error()) logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched) logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched)
@@ -184,7 +197,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, "failed to fetch users: ", err.Error()) logger.Log(0, "failed to fetch users: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
@@ -213,17 +226,23 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
logger.Log(0, admin.UserName, "error decoding request body: ", logger.Log(0, admin.UserName, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
admin, err = logic.CreateAdmin(admin)
if !servercfg.IsBasicAuthEnabled() {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
return
}
admin, err = logic.CreateAdmin(admin)
if err != nil { if err != nil {
logger.Log(0, admin.UserName, "failed to create admin: ", logger.Log(0, admin.UserName, "failed to create admin: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, admin.UserName, "was made a new admin") logger.Log(1, admin.UserName, "was made a new admin")
json.NewEncoder(w).Encode(admin) json.NewEncoder(w).Encode(admin)
} }
@@ -247,14 +266,15 @@ func createUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, user.UserName, "error decoding request body: ", logger.Log(0, user.UserName, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
user, err = logic.CreateUser(user) user, err = logic.CreateUser(user)
if err != nil { if err != nil {
logger.Log(0, user.UserName, "error creating new user: ", logger.Log(0, user.UserName, "error creating new user: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, user.UserName, "was created") logger.Log(1, user.UserName, "was created")
@@ -282,7 +302,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user networks: ", err.Error()) "failed to update user networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
var userchange models.User var userchange models.User
@@ -291,14 +311,20 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
err = logic.UpdateUserNetworks(userchange.Networks, userchange.IsAdmin, &user) err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{
Groups: user.Groups,
IsAdmin: user.IsAdmin,
Networks: user.Networks,
UserName: user.UserName,
})
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user networks: ", err.Error()) "failed to update user networks: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "status was updated") logger.Log(1, username, "status was updated")
@@ -326,13 +352,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user info: ", err.Error()) "failed to update user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if auth.IsOauthUser(&user) == nil { if auth.IsOauthUser(&user) == nil {
err := fmt.Errorf("cannot update user info for oauth user %s", username) err := fmt.Errorf("cannot update user info for oauth user %s", username)
logger.Log(0, err.Error()) logger.Log(0, err.Error())
returnErrorResponse(w, r, formatError(err, "forbidden")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
return return
} }
var userchange models.User var userchange models.User
@@ -341,7 +367,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
userchange.Networks = nil userchange.Networks = nil
@@ -349,7 +375,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user info: ", err.Error()) "failed to update user info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "was updated") logger.Log(1, username, "was updated")
@@ -375,13 +401,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
username := params["username"] username := params["username"]
user, err := GetUserInternal(username) user, err := GetUserInternal(username)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} }
if auth.IsOauthUser(&user) != nil { if auth.IsOauthUser(&user) != nil {
err := fmt.Errorf("cannot update user info for oauth user %s", username) err := fmt.Errorf("cannot update user info for oauth user %s", username)
logger.Log(0, err.Error()) logger.Log(0, err.Error())
returnErrorResponse(w, r, formatError(err, "forbidden")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
return return
} }
var userchange models.User var userchange models.User
@@ -390,18 +416,18 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, "error decoding request body: ", logger.Log(0, username, "error decoding request body: ",
err.Error()) err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if !user.IsAdmin { if !user.IsAdmin {
logger.Log(0, username, "not an admin user") logger.Log(0, username, "not an admin user")
returnErrorResponse(w, r, formatError(errors.New("not a admin user"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not a admin user"), "badrequest"))
} }
user, err = logic.UpdateUser(userchange, user) user, err = logic.UpdateUser(userchange, user)
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to update user (admin) info: ", err.Error()) "failed to update user (admin) info: ", err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "was updated (admin)") logger.Log(1, username, "was updated (admin)")
@@ -432,15 +458,31 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
logger.Log(0, username, logger.Log(0, username,
"failed to delete user: ", err.Error()) "failed to delete user: ", err.Error())
returnErrorResponse(w, r, formatError(err, "internal")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return return
} else if !success { } else if !success {
err := errors.New("delete unsuccessful") err := errors.New("delete unsuccessful")
logger.Log(0, username, err.Error()) logger.Log(0, username, err.Error())
returnErrorResponse(w, r, formatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logger.Log(1, username, "was deleted") logger.Log(1, username, "was deleted")
json.NewEncoder(w).Encode(params["username"] + " deleted.") json.NewEncoder(w).Encode(params["username"] + " deleted.")
} }
// Called when vpn client dials in to start the auth flow and first stage is to get register URL itself
func socketHandler(w http.ResponseWriter, r *http.Request) {
// Upgrade our raw HTTP connection to a websocket based one
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Log(0, "error during connection upgrade for node sign-in:", err.Error())
return
}
if conn == nil {
logger.Log(0, "failed to establish web-socket connection during node sign-in")
return
}
// Start handling the session
go auth.SessionHandler(conn)
}

View File

@@ -31,7 +31,7 @@ func TestHasAdmin(t *testing.T) {
assert.False(t, found) assert.False(t, found)
}) })
t.Run("No admin user", func(t *testing.T) { t.Run("No admin user", func(t *testing.T) {
var user = models.User{"noadmin", "password", nil, false} var user = models.User{"noadmin", "password", nil, false, nil}
_, err := logic.CreateUser(user) _, err := logic.CreateUser(user)
assert.Nil(t, err) assert.Nil(t, err)
found, err := logic.HasAdmin() found, err := logic.HasAdmin()
@@ -39,7 +39,7 @@ func TestHasAdmin(t *testing.T) {
assert.False(t, found) assert.False(t, found)
}) })
t.Run("admin user", func(t *testing.T) { t.Run("admin user", func(t *testing.T) {
var user = models.User{"admin", "password", nil, true} var user = models.User{"admin", "password", nil, true, nil}
_, err := logic.CreateUser(user) _, err := logic.CreateUser(user)
assert.Nil(t, err) assert.Nil(t, err)
found, err := logic.HasAdmin() found, err := logic.HasAdmin()
@@ -47,7 +47,7 @@ func TestHasAdmin(t *testing.T) {
assert.True(t, found) assert.True(t, found)
}) })
t.Run("multiple admins", func(t *testing.T) { t.Run("multiple admins", func(t *testing.T) {
var user = models.User{"admin1", "password", nil, true} var user = models.User{"admin1", "password", nil, true, nil}
_, err := logic.CreateUser(user) _, err := logic.CreateUser(user)
assert.Nil(t, err) assert.Nil(t, err)
found, err := logic.HasAdmin() found, err := logic.HasAdmin()
@@ -59,7 +59,7 @@ func TestHasAdmin(t *testing.T) {
func TestCreateUser(t *testing.T) { func TestCreateUser(t *testing.T) {
database.InitializeDatabase() database.InitializeDatabase()
deleteAllUsers() deleteAllUsers()
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
t.Run("NoUser", func(t *testing.T) { t.Run("NoUser", func(t *testing.T) {
admin, err := logic.CreateUser(user) admin, err := logic.CreateUser(user)
assert.Nil(t, err) assert.Nil(t, err)
@@ -101,7 +101,7 @@ func TestDeleteUser(t *testing.T) {
assert.False(t, deleted) assert.False(t, deleted)
}) })
t.Run("Existing User", func(t *testing.T) { t.Run("Existing User", func(t *testing.T) {
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
logic.CreateUser(user) logic.CreateUser(user)
deleted, err := logic.DeleteUser("admin") deleted, err := logic.DeleteUser("admin")
assert.Nil(t, err) assert.Nil(t, err)
@@ -166,7 +166,7 @@ func TestGetUser(t *testing.T) {
assert.Equal(t, "", admin.UserName) assert.Equal(t, "", admin.UserName)
}) })
t.Run("UserExisits", func(t *testing.T) { t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
logic.CreateUser(user) logic.CreateUser(user)
admin, err := logic.GetUser("admin") admin, err := logic.GetUser("admin")
assert.Nil(t, err) assert.Nil(t, err)
@@ -183,7 +183,7 @@ func TestGetUserInternal(t *testing.T) {
assert.Equal(t, "", admin.UserName) assert.Equal(t, "", admin.UserName)
}) })
t.Run("UserExisits", func(t *testing.T) { t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
logic.CreateUser(user) logic.CreateUser(user)
admin, err := GetUserInternal("admin") admin, err := GetUserInternal("admin")
assert.Nil(t, err) assert.Nil(t, err)
@@ -200,14 +200,14 @@ func TestGetUsers(t *testing.T) {
assert.Equal(t, []models.ReturnUser(nil), admin) assert.Equal(t, []models.ReturnUser(nil), admin)
}) })
t.Run("UserExisits", func(t *testing.T) { t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
logic.CreateUser(user) logic.CreateUser(user)
admins, err := logic.GetUsers() admins, err := logic.GetUsers()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, user.UserName, admins[0].UserName) assert.Equal(t, user.UserName, admins[0].UserName)
}) })
t.Run("MulipleUsers", func(t *testing.T) { t.Run("MulipleUsers", func(t *testing.T) {
user := models.User{"user", "password", nil, true} user := models.User{"user", "password", nil, true, nil}
logic.CreateUser(user) logic.CreateUser(user)
admins, err := logic.GetUsers() admins, err := logic.GetUsers()
assert.Nil(t, err) assert.Nil(t, err)
@@ -225,8 +225,8 @@ func TestGetUsers(t *testing.T) {
func TestUpdateUser(t *testing.T) { func TestUpdateUser(t *testing.T) {
database.InitializeDatabase() database.InitializeDatabase()
deleteAllUsers() deleteAllUsers()
user := models.User{"admin", "password", nil, true} user := models.User{"admin", "password", nil, true, nil}
newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true} newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true, []string{}}
t.Run("NonExistantUser", func(t *testing.T) { t.Run("NonExistantUser", func(t *testing.T) {
admin, err := logic.UpdateUser(newuser, user) admin, err := logic.UpdateUser(newuser, user)
assert.EqualError(t, err, "could not find any records") assert.EqualError(t, err, "could not find any records")
@@ -288,10 +288,10 @@ func TestVerifyAuthRequest(t *testing.T) {
authRequest.Password = "password" authRequest.Password = "password"
jwt, err := logic.VerifyAuthRequest(authRequest) jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Equal(t, "", jwt) assert.Equal(t, "", jwt)
assert.EqualError(t, err, "incorrect credentials") assert.EqualError(t, err, "error retrieving user from db: could not find any records")
}) })
t.Run("Non-Admin", func(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) {
user := models.User{"nonadmin", "somepass", nil, false} user := models.User{"nonadmin", "somepass", nil, false, []string{}}
logic.CreateUser(user) logic.CreateUser(user)
authRequest := models.UserAuthParams{"nonadmin", "somepass"} authRequest := models.UserAuthParams{"nonadmin", "somepass"}
jwt, err := logic.VerifyAuthRequest(authRequest) jwt, err := logic.VerifyAuthRequest(authRequest)
@@ -299,7 +299,7 @@ func TestVerifyAuthRequest(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("WrongPassword", func(t *testing.T) { t.Run("WrongPassword", func(t *testing.T) {
user := models.User{"admin", "password", nil, false} user := models.User{"admin", "password", nil, false, []string{}}
logic.CreateUser(user) logic.CreateUser(user)
authRequest := models.UserAuthParams{"admin", "badpass"} authRequest := models.UserAuthParams{"admin", "badpass"}
jwt, err := logic.VerifyAuthRequest(authRequest) jwt, err := logic.VerifyAuthRequest(authRequest)

73
controllers/usergroups.go Normal file
View File

@@ -0,0 +1,73 @@
package controller
import (
"encoding/json"
"errors"
"net/http"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models/promodels"
)
func userGroupsHandlers(r *mux.Router) {
r.HandleFunc("/api/usergroups", logic.SecurityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET")
r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST")
r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE")
}
func getUserGroups(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
logger.Log(1, r.Header.Get("user"), "requested fetching user groups")
userGroups, err := pro.GetUserGroups()
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
// Returns all the groups in JSON format
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(userGroups)
}
func createUserGroup(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
newGroup := params["usergroup"]
logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup)
if newGroup == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return
}
err := pro.InsertUserGroup(promodels.UserGroupName(newGroup))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
w.WriteHeader(http.StatusOK)
}
func deleteUserGroup(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
groupToDelete := params["usergroup"]
logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete)
if groupToDelete == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
return
}
if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
w.WriteHeader(http.StatusOK)
}

View File

@@ -59,6 +59,18 @@ const NODE_ACLS_TABLE_NAME = "nodeacls"
// SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins // SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins
const SSO_STATE_CACHE = "ssostatecache" const SSO_STATE_CACHE = "ssostatecache"
// METRICS_TABLE_NAME - stores network metrics
const METRICS_TABLE_NAME = "metrics"
// NETWORK_USER_TABLE_NAME - network user table tracks stats for a network user per network
const NETWORK_USER_TABLE_NAME = "networkusers"
// USER_GROUPS_TABLE_NAME - table for storing usergroups
const USER_GROUPS_TABLE_NAME = "usergroups"
// CACHE_TABLE_NAME - caching table
const CACHE_TABLE_NAME = "cache"
// == ERROR CONSTS == // == ERROR CONSTS ==
// NO_RECORD - no singular result found // NO_RECORD - no singular result found
@@ -139,6 +151,10 @@ func createTables() {
createTable(GENERATED_TABLE_NAME) createTable(GENERATED_TABLE_NAME)
createTable(NODE_ACLS_TABLE_NAME) createTable(NODE_ACLS_TABLE_NAME)
createTable(SSO_STATE_CACHE) createTable(SSO_STATE_CACHE)
createTable(METRICS_TABLE_NAME)
createTable(NETWORK_USER_TABLE_NAME)
createTable(USER_GROUPS_TABLE_NAME)
createTable(CACHE_TABLE_NAME)
} }
func createTable(tableName string) error { func createTable(tableName string) error {

View File

@@ -10,3 +10,7 @@ keyfile /mosquitto/certs/server.key
listener 1883 listener 1883
allow_anonymous true allow_anonymous true
listener 1884
allow_anonymous false
password_file /etc/mosquitto.passwords

View File

@@ -0,0 +1 @@
netmaker-exporter:$7$101$9kcXwXP+nUMh06gm$MND2YjtRSvcZTXjMn7xYKoqUFQxG6NOgqWmXIcxxxZksM9cA8732URQWOsPHqpGEvVF9mSVagM1MBEMIKwZm2A==

10
ee/LICENSE Normal file
View File

@@ -0,0 +1,10 @@
The Netmaker Enterprise license (the “Enterprise License”)
Copyright (c) 2022 Netmaker, Inc.
With regard to the Netmaker Software:
This software and associated documentation files (the "Software") may only be used in production, if you (and any entity that you represent) have agreed to, and are in compliance with, the Netmaker Subscription Terms of Service, available at https://netmaker.io/terms (the “Enterprise Terms”), or other agreement governing the use of the Software, as agreed by you and Netmaker, and otherwise have a valid Netmaker Enterprise license for the correct number of users, networks, nodes, servers, and external clients. Subject to the foregoing sentence, you are free to modify this Software and publish patches to the Software. You agree that Netmaker and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications and/or patches, and all such modifications and/or patches may only be used, copied, modified, displayed, distributed, or otherwise exploited with a valid Netmaker Enterprise license for the correct number of users, networks, nodes, servers, and external clients as allocated by the license. Notwithstanding the foregoing, you may copy and modify the Software for development and testing purposes, without requiring a subscription. You agree that Netmaker and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications. You are not granted any other rights beyond what is expressly stated herein. Subject to the foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, and/or sell the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
For all third party components incorporated into the Netmaker Software, those components are licensed under the original license provided by the owner of the applicable component.

View File

@@ -0,0 +1,103 @@
package ee_controllers
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// MetricHandlers - How we handle EE Metrics
func MetricHandlers(r *mux.Router) {
r.HandleFunc("/api/metrics/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET")
r.HandleFunc("/api/metrics/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET")
r.HandleFunc("/api/metrics", logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET")
}
// get the metrics of a given node
func getNodeMetrics(w http.ResponseWriter, r *http.Request) {
// set header.
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
nodeID := params["nodeid"]
logger.Log(1, r.Header.Get("user"), "requested fetching metrics for node", nodeID, "on network", params["network"])
metrics, err := logic.GetMetrics(nodeID)
if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
logger.Log(1, r.Header.Get("user"), "fetched metrics for node", params["nodeid"])
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(metrics)
}
// get the metrics of all nodes in given network
func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) {
// set header.
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
network := params["network"]
logger.Log(1, r.Header.Get("user"), "requested fetching network node metrics on network", network)
networkNodes, err := logic.GetNetworkNodes(network)
if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
networkMetrics := models.NetworkMetrics{}
networkMetrics.Nodes = make(models.MetricsMap)
for i := range networkNodes {
id := networkNodes[i].ID
metrics, err := logic.GetMetrics(id)
if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to append metrics of node", id, "during network metrics fetch", err.Error())
continue
}
networkMetrics.Nodes[id] = *metrics
}
logger.Log(1, r.Header.Get("user"), "fetched metrics for network", network)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(networkMetrics)
}
// get Metrics of all nodes on server, lots of data
func getAllMetrics(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
logger.Log(1, r.Header.Get("user"), "requested fetching all metrics")
allNodes, err := logic.GetAllNodes()
if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
networkMetrics := models.NetworkMetrics{}
networkMetrics.Nodes = make(models.MetricsMap)
for i := range allNodes {
id := allNodes[i].ID
metrics, err := logic.GetMetrics(id)
if err != nil {
logger.Log(1, r.Header.Get("user"), "failed to append metrics of node", id, "during all nodes metrics fetch", err.Error())
continue
}
networkMetrics.Nodes[id] = *metrics
}
logger.Log(1, r.Header.Get("user"), "fetched metrics for all nodes on server")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(networkMetrics)
}

54
ee/initialize.go Normal file
View File

@@ -0,0 +1,54 @@
//go:build ee
// +build ee
package ee
import (
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/ee/ee_controllers"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// InitEE - Initialize EE Logic
func InitEE() {
setIsEnterprise()
models.SetLogo(retrieveEELogo())
controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers)
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling ==
ValidateLicense()
if Limits.FreeTier {
logger.Log(0, "proceeding with Free Tier license")
} else {
logger.Log(0, "proceeding with Paid Tier license")
}
// == End License Handling ==
AddLicenseHooks()
})
}
func setControllerLimits() {
logic.Node_Limit = Limits.Nodes
logic.Users_Limit = Limits.Users
logic.Clients_Limit = Limits.Clients
logic.Free_Tier = Limits.FreeTier
logic.Is_EE = true
}
func retrieveEELogo() string {
return `
__ __ ______ ______ __ __ ______ __ __ ______ ______
/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \
\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __<
\ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\
\/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/
___ ___ ____
____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____
/___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/
/___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/
`
}

251
ee/license.go Normal file
View File

@@ -0,0 +1,251 @@
//go:build ee
// +build ee
package ee
import (
"bytes"
"crypto/rand"
"encoding/json"
"fmt"
"io/ioutil"
"math"
"net/http"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/nacl/box"
)
const (
db_license_key = "netmaker-id-key-pair"
)
type apiServerConf struct {
PrivateKey []byte `json:"private_key" binding:"required"`
PublicKey []byte `json:"public_key" binding:"required"`
}
// AddLicenseHooks - adds the validation and cache clear hooks
func AddLicenseHooks() {
logic.AddHook(ValidateLicense)
logic.AddHook(ClearLicenseCache)
}
// ValidateLicense - the initial license check for netmaker server
// checks if a license is valid + limits are not exceeded
// if license is free_tier and limits exceeds, then server should terminate
// if license is not valid, server should terminate
func ValidateLicense() error {
licenseKeyValue := servercfg.GetLicenseKey()
netmakerAccountID := servercfg.GetNetmakerAccountID()
logger.Log(0, "proceeding with Netmaker license validation...")
if len(licenseKeyValue) == 0 || len(netmakerAccountID) == 0 {
logger.FatalLog(errValidation.Error())
}
apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
if err != nil {
logger.FatalLog(errValidation.Error())
}
tempPubKey, tempPrivKey, err := FetchApiServerKeys()
if err != nil {
logger.FatalLog(errValidation.Error())
}
licenseSecret := LicenseSecret{
UserID: netmakerAccountID,
Limits: getCurrentServerLimit(),
}
secretData, err := json.Marshal(&licenseSecret)
if err != nil {
logger.FatalLog(errValidation.Error())
}
encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
if err != nil {
logger.FatalLog(errValidation.Error())
}
validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
if err != nil || len(validationResponse) == 0 {
logger.FatalLog(errValidation.Error())
}
var licenseResponse ValidatedLicense
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
logger.FatalLog(errValidation.Error())
}
respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
if err != nil {
logger.FatalLog(errValidation.Error())
}
license := LicenseKey{}
if err = json.Unmarshal(respData, &license); err != nil {
logger.FatalLog(errValidation.Error())
}
Limits.Networks = math.MaxInt
Limits.FreeTier = license.FreeTier == "yes"
Limits.Clients = license.LimitClients
Limits.Nodes = license.LimitNodes
Limits.Servers = license.LimitServers
Limits.Users = license.LimitUsers
if Limits.FreeTier {
Limits.Networks = 3
}
setControllerLimits()
logger.Log(0, "License validation succeeded!")
return nil
}
// FetchApiServerKeys - fetches netmaker license keys for identification
// as well as secure communication with API
// if none present, it generates a new pair
func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
var returnData = apiServerConf{}
currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key)
if err != nil && !database.IsEmptyRecord(err) {
return nil, nil, err
} else if database.IsEmptyRecord(err) { // need to generate a new identifier pair
pub, priv, err = box.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
pubBytes, err := ncutils.ConvertKeyToBytes(pub)
if err != nil {
return nil, nil, err
}
privBytes, err := ncutils.ConvertKeyToBytes(priv)
if err != nil {
return nil, nil, err
}
returnData.PrivateKey = privBytes
returnData.PublicKey = pubBytes
record, err := json.Marshal(&returnData)
if err != nil {
return nil, nil, err
}
if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil {
return nil, nil, err
}
} else {
if err = json.Unmarshal([]byte(currentData), &returnData); err != nil {
return nil, nil, err
}
priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey)
if err != nil {
return nil, nil, err
}
pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey)
if err != nil {
return nil, nil, err
}
}
return pub, priv, nil
}
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
decodedPubKey := base64decode(licensePubKeyEncoded)
return ncutils.ConvertBytesToKey(decodedPubKey)
}
func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, error) {
publicKeyBytes, err := ncutils.ConvertKeyToBytes(publicKey)
if err != nil {
return nil, err
}
msg := ValidateLicenseRequest{
NmServerPubKey: base64encode(publicKeyBytes),
EncryptedPart: base64encode(encryptedData),
}
requestBody, err := json.Marshal(msg)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, api_endpoint, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
reqParams := req.URL.Query()
reqParams.Add("licensevalue", servercfg.GetLicenseKey())
req.URL.RawQuery = reqParams.Encode()
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
client := &http.Client{}
var body []byte
validateResponse, err := client.Do(req)
if err != nil { // check cache
body, err = getCachedResponse()
if err != nil {
return nil, err
}
logger.Log(3, "proceeding with cached response, Netmaker API may be down")
} else {
defer validateResponse.Body.Close()
if validateResponse.StatusCode != 200 {
return nil, fmt.Errorf("could not validate license")
} // if you received a 200 cache the response locally
body, err = ioutil.ReadAll(validateResponse.Body)
if err != nil {
return nil, err
}
cacheResponse(body)
}
return body, err
}
func cacheResponse(response []byte) error {
var lrc = licenseResponseCache{
Body: response,
}
record, err := json.Marshal(&lrc)
if err != nil {
return err
}
return database.Insert(license_cache_key, string(record), database.CACHE_TABLE_NAME)
}
func getCachedResponse() ([]byte, error) {
var lrc licenseResponseCache
record, err := database.FetchRecord(database.CACHE_TABLE_NAME, license_cache_key)
if err != nil {
return nil, err
}
if err = json.Unmarshal([]byte(record), &lrc); err != nil {
return nil, err
}
return lrc.Body, nil
}
// ClearLicenseCache - clears the cached validate response
func ClearLicenseCache() error {
return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key)
}
func getServerCount() int {
if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil {
currentServerIDs := serverIDs{}
if err = json.Unmarshal([]byte(record), &currentServerIDs); err == nil {
return len(currentServerIDs.ServerIDs)
}
}
return 1
}

87
ee/types.go Normal file
View File

@@ -0,0 +1,87 @@
package ee
import "fmt"
const (
api_endpoint = "https://api.controller.netmaker.io/api/v1/license/validate"
license_cache_key = "license_response_cache"
license_validation_err_msg = "invalid license"
server_id_key = "nm-server-id"
)
var errValidation = fmt.Errorf(license_validation_err_msg)
// Limits - limits to be referenced throughout server
var Limits = GlobalLimits{
Servers: 0,
Users: 0,
Nodes: 0,
Clients: 0,
FreeTier: false,
}
// GlobalLimits - struct for holding global limits on this netmaker server in memory
type GlobalLimits struct {
Servers int
Users int
Nodes int
Clients int
FreeTier bool
Networks int
}
// LicenseKey - the license key struct representation with associated data
type LicenseKey struct {
LicenseValue string `json:"license_value"` // actual (public) key and the unique value for the key
Expiration int64 `json:"expiration"`
LimitServers int `json:"limit_servers"`
LimitUsers int `json:"limit_users"`
LimitNodes int `json:"limit_nodes"`
LimitClients int `json:"limit_clients"`
Metadata string `json:"metadata"`
SubscriptionID string `json:"subscription_id"` // for a paid subscription (non-free-tier license)
FreeTier string `json:"free_tier"` // yes if free tier
IsActive string `json:"is_active"` // yes if active
}
// ValidatedLicense - the validated license struct
type ValidatedLicense struct {
LicenseValue string `json:"license_value" binding:"required"` // license that validation is being requested for
EncryptedLicense string `json:"encrypted_license" binding:"required"` // to be decrypted by Netmaker using Netmaker server's private key
}
// LicenseSecret - the encrypted struct for sending user-id
type LicenseSecret struct {
UserID string `json:"user_id" binding:"required"` // UUID for user foreign key to User table
Limits LicenseLimits `json:"limits" binding:"required"`
}
// LicenseLimits - struct license limits
type LicenseLimits struct {
Servers int `json:"servers" binding:"required"`
Users int `json:"users" binding:"required"`
Nodes int `json:"nodes" binding:"required"`
Clients int `json:"clients" binding:"required"`
}
// LicenseLimits.SetDefaults - sets the default values for limits
func (l *LicenseLimits) SetDefaults() {
l.Clients = 0
l.Servers = 1
l.Nodes = 0
l.Users = 1
}
// ValidateLicenseRequest - used for request to validate license endpoint
type ValidateLicenseRequest struct {
NmServerPubKey string `json:"nm_server_pub_key" binding:"required"` // Netmaker server public key used to send data back to Netmaker for the Netmaker server to decrypt (eg output from validating license)
EncryptedPart string `json:"secret" binding:"required"`
}
type licenseResponseCache struct {
Body []byte `json:"body" binding:"required"`
}
type serverIDs struct {
ServerIDs []string `json:"server_ids" binding:"required"`
}

54
ee/util.go Normal file
View File

@@ -0,0 +1,54 @@
package ee
import (
"encoding/base64"
"github.com/gravitl/netmaker/logic"
)
var isEnterprise bool
// IsEnterprise - checks if enterprise binary or not
func IsEnterprise() bool {
return isEnterprise
}
// setIsEnterprise - sets server to use enterprise features
func setIsEnterprise() {
isEnterprise = true
}
// base64encode - base64 encode helper function
func base64encode(input []byte) string {
return base64.StdEncoding.EncodeToString(input)
}
// base64decode - base64 decode helper function
func base64decode(input string) []byte {
bytes, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return nil
}
return bytes
}
func getCurrentServerLimit() (limits LicenseLimits) {
limits.SetDefaults()
nodes, err := logic.GetAllNodes()
if err == nil {
limits.Nodes = len(nodes)
}
clients, err := logic.GetAllExtClients()
if err == nil {
limits.Clients = len(clients)
}
users, err := logic.GetUsers()
if err == nil {
limits.Users = len(users)
}
limits.Servers = logic.GetServerCount()
return
}

View File

@@ -8,17 +8,6 @@ import (
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
// NetworkExists - check if network exists
func NetworkExists(name string) (bool, error) {
var network string
var err error
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err
}
return len(network) > 0, nil
}
// NameInDNSCharSet - name in dns char set // NameInDNSCharSet - name in dns char set
func NameInDNSCharSet(name string) bool { func NameInDNSCharSet(name string) bool {

View File

@@ -26,7 +26,7 @@ func TestNetworkExists(t *testing.T) {
} }
database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID) database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
defer database.CloseDB() defer database.CloseDB()
exists, err := NetworkExists(testNetwork.NetID) exists, err := logic.NetworkExists(testNetwork.NetID)
if err == nil { if err == nil {
t.Fatalf("expected error, received nil") t.Fatalf("expected error, received nil")
} }
@@ -38,7 +38,7 @@ func TestNetworkExists(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to save test network in databse: %s", err) t.Fatalf("failed to save test network in databse: %s", err)
} }
exists, err = NetworkExists(testNetwork.NetID) exists, err = logic.NetworkExists(testNetwork.NetID)
if err != nil { if err != nil {
t.Fatalf("expected nil, received err: %s", err) t.Fatalf("expected nil, received err: %s", err)
} }

10
go.mod
View File

@@ -17,7 +17,10 @@ require (
github.com/txn2/txeh v1.3.0 github.com/txn2/txeh v1.3.0
github.com/urfave/cli/v2 v2.16.3 github.com/urfave/cli/v2 v2.16.3
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094 golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/text v0.3.7 // indirect
golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 // indirect golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 // indirect
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220324164955-056925b7df31 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220324164955-056925b7df31
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect
@@ -30,14 +33,17 @@ require (
fyne.io/fyne/v2 v2.2.3 fyne.io/fyne/v2 v2.2.3
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cloverstd/tcping v0.1.1 github.com/cloverstd/tcping v0.1.1
github.com/go-ping/ping v1.1.0
github.com/guumaster/hostctl v1.1.3 github.com/guumaster/hostctl v1.1.3
github.com/kr/pretty v0.3.0 github.com/kr/pretty v0.3.0
github.com/posthog/posthog-go v0.0.0-20211028072449-93c17c49e2b0 github.com/posthog/posthog-go v0.0.0-20211028072449-93c17c49e2b0
) )
require ( require (
github.com/gorilla/websocket v1.4.2
github.com/coreos/go-oidc/v3 v3.4.0 github.com/coreos/go-oidc/v3 v3.4.0
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035
) )
require ( require (
@@ -67,7 +73,6 @@ require (
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.8 // indirect github.com/google/go-cmp v0.5.8 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/josharian/native v1.0.0 // indirect github.com/josharian/native v1.0.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
@@ -91,10 +96,7 @@ require (
github.com/yuin/goldmark v1.4.13 // indirect github.com/yuin/goldmark v1.4.13 // indirect
golang.org/x/image v0.0.0-20220601225756-64ec528b34cd // indirect golang.org/x/image v0.0.0-20220601225756-64ec528b34cd // indirect
golang.org/x/mobile v0.0.0-20211207041440-4e6c2922fdee // indirect golang.org/x/mobile v0.0.0-20211207041440-4e6c2922fdee // indirect
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect

6
go.sum
View File

@@ -163,6 +163,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec h1:3FLiRYO6PlQFDpUU7OEFlWgjGD1jnBIVSJ5SYRWk+9c= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec h1:3FLiRYO6PlQFDpUU7OEFlWgjGD1jnBIVSJ5SYRWk+9c=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ping/ping v1.1.0 h1:3MCGhVX4fyEUuhsfwPrsEdQw6xspHkv5zHsiSoDFZYw=
github.com/go-ping/ping v1.1.0/go.mod h1:xIFjORFzTxqIV/tDVGO4eDy/bLuSyawEeojSm3GfRGk=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
@@ -251,6 +253,7 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8= github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
@@ -664,6 +667,7 @@ golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -700,6 +704,8 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdp
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -161,11 +161,11 @@ func DecrimentKey(networkName string, keyvalue string) {
} }
// IsKeyValid - check if key is valid // IsKeyValid - check if key is valid
func IsKeyValid(networkname string, keyvalue string) bool { func IsKeyValid(networkname string, keyvalue string) (string, bool) {
network, err := GetParentNetwork(networkname) network, err := GetParentNetwork(networkname)
if err != nil { if err != nil {
return false return "", false
} }
accesskeys := network.AccessKeys accesskeys := network.AccessKeys
@@ -185,7 +185,7 @@ func IsKeyValid(networkname string, keyvalue string) bool {
isvalid = true isvalid = true
} }
} }
return isvalid return key.Name, isvalid
} }
// RemoveKeySensitiveInfo - remove sensitive key info // RemoveKeySensitiveInfo - remove sensitive key info

View File

@@ -9,7 +9,9 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -95,21 +97,61 @@ func CreateUser(user models.User) (models.User, error) {
// set password to encrypted password // set password to encrypted password
user.Password = string(hash) user.Password = string(hash)
tokenString, _ := CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin)
if tokenString == "" { if tokenString == "" {
// returnErrorResponse(w, r, errorResponse) // logic.ReturnErrorResponse(w, r, errorResponse)
return user, err return user, err
} }
SetUserDefaults(&user)
// connect db // connect db
data, err := json.Marshal(&user) data, err := json.Marshal(&user)
if err != nil { if err != nil {
return user, err return user, err
} }
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
if err != nil {
return user, err return user, err
}
// == PRO == Add user to every network as network user ==
currentNets, err := GetNetworks()
if err != nil {
currentNets = []models.Network{}
}
for i := range currentNets {
newUser := promodels.NetworkUser{
ID: promodels.NetworkUserID(user.UserName),
Clients: []string{},
Nodes: []string{},
}
pro.AddProNetDefaults(&currentNets[i])
if pro.IsUserAllowed(&currentNets[i], user.UserName, user.Groups) {
newUser.AccessLevel = currentNets[i].ProSettings.DefaultAccessLevel
newUser.ClientLimit = currentNets[i].ProSettings.DefaultUserClientLimit
newUser.NodeLimit = currentNets[i].ProSettings.DefaultUserNodeLimit
} else {
newUser.AccessLevel = pro.NO_ACCESS
newUser.ClientLimit = 0
newUser.NodeLimit = 0
}
// legacy
if StringSliceContains(user.Networks, currentNets[i].NetID) {
if !Is_EE {
newUser.AccessLevel = pro.NET_ADMIN
}
}
userErr := pro.CreateNetworkUser(&currentNets[i], &newUser)
if userErr != nil {
logger.Log(0, "failed to add network user data on network", currentNets[i].NetID, "for user", user.UserName)
}
}
// == END PRO ==
return user, nil
} }
// CreateAdmin - creates an admin user // CreateAdmin - creates an admin user
@@ -136,10 +178,10 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
//Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved). //Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName) record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
if err != nil { if err != nil {
return "", errors.New("incorrect credentials") return "", errors.New("error retrieving user from db: " + err.Error())
} }
if err = json.Unmarshal([]byte(record), &result); err != nil { if err = json.Unmarshal([]byte(record), &result); err != nil {
return "", errors.New("incorrect credentials") return "", errors.New("error unmarshalling user json: " + err.Error())
} }
// compare password from request to stored password in database // compare password from request to stored password in database
@@ -150,14 +192,15 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
} }
//Create a new JWT for the node //Create a new JWT for the node
tokenString, _ := CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin) tokenString, _ := CreateProUserJWT(authRequest.UserName, result.Networks, result.Groups, result.IsAdmin)
return tokenString, nil return tokenString, nil
} }
// UpdateUserNetworks - updates the networks of a given user // UpdateUserNetworks - updates the networks of a given user
func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.User) error { func UpdateUserNetworks(newNetworks, newGroups []string, isadmin bool, currentUser *models.ReturnUser) error {
// check if user exists // check if user exists
if returnedUser, err := GetUser(currentUser.UserName); err != nil { returnedUser, err := GetUser(currentUser.UserName)
if err != nil {
return err return err
} else if returnedUser.IsAdmin { } else if returnedUser.IsAdmin {
return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName) return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName)
@@ -166,18 +209,46 @@ func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.
currentUser.IsAdmin = true currentUser.IsAdmin = true
currentUser.Networks = nil currentUser.Networks = nil
} else { } else {
// == PRO ==
currentUser.Groups = newGroups
for _, n := range newNetworks {
if !StringSliceContains(currentUser.Networks, n) {
// make net admin of any network not previously assigned
pro.MakeNetAdmin(n, currentUser.UserName)
}
}
// Compare networks, find networks not in previous
for _, n := range currentUser.Networks {
if !StringSliceContains(newNetworks, n) {
// if user was removed from a network, re-assign access to net default level
if network, err := GetNetwork(n); err == nil {
if network.ProSettings != nil {
ok := pro.AssignAccessLvl(n, currentUser.UserName, network.ProSettings.DefaultAccessLevel)
if ok {
logger.Log(0, "changed", currentUser.UserName, "access level on network", network.NetID, "to", fmt.Sprintf("%d", network.ProSettings.DefaultAccessLevel))
}
}
}
}
}
if err := AdjustGroupPermissions(currentUser); err != nil {
logger.Log(0, "failed to update user", currentUser.UserName, "after group update", err.Error())
}
// == END PRO ==
currentUser.Networks = newNetworks currentUser.Networks = newNetworks
} }
data, err := json.Marshal(currentUser) _, err = UpdateUser(models.User{
if err != nil { UserName: currentUser.UserName,
return err Networks: currentUser.Networks,
} IsAdmin: currentUser.IsAdmin,
if err = database.Insert(currentUser.UserName, string(data), database.USERS_TABLE_NAME); err != nil { Password: "",
return err Groups: currentUser.Groups,
} }, returnedUser)
return nil return err
} }
// UpdateUser - updates a given user // UpdateUser - updates a given user
@@ -187,11 +258,6 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) {
return models.User{}, err return models.User{}, err
} }
err := ValidateUser(userchange)
if err != nil {
return models.User{}, err
}
queryUser := user.UserName queryUser := user.UserName
if userchange.UserName != "" { if userchange.UserName != "" {
@@ -200,6 +266,9 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) {
if len(userchange.Networks) > 0 { if len(userchange.Networks) > 0 {
user.Networks = userchange.Networks user.Networks = userchange.Networks
} }
if len(userchange.Groups) > 0 {
user.Groups = userchange.Groups
}
if userchange.Password != "" { if userchange.Password != "" {
// encrypt that password so we never see it again // encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
@@ -212,6 +281,12 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) {
user.Password = userchange.Password user.Password = userchange.Password
} }
err := ValidateUser(user)
if err != nil {
return models.User{}, err
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return models.User{}, err return models.User{}, err
} }
@@ -256,6 +331,20 @@ func DeleteUser(user string) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
// == pro - remove user from all network user instances ==
currentNets, err := GetNetworks()
if err != nil {
return true, err
}
for i := range currentNets {
netID := currentNets[i].NetID
if err = pro.DeleteNetworkUser(netID, user); err != nil {
logger.Log(0, "failed to remove", user, "as network user from network", netID, err.Error())
}
}
return true, nil return true, nil
} }
@@ -313,6 +402,9 @@ func IsStateValid(state string) (string, bool) {
if s.Value != "" { if s.Value != "" {
delState(state) delState(state)
} }
if err != nil {
logger.Log(2, "error retrieving oauth state:", err.Error())
}
return s.Value, err == nil return s.Value, err == nil
} }
@@ -320,3 +412,51 @@ func IsStateValid(state string) (string, bool) {
func delState(state string) error { func delState(state string) error {
return database.DeleteRecord(database.SSO_STATE_CACHE, state) return database.DeleteRecord(database.SSO_STATE_CACHE, state)
} }
// PRO
// AdjustGroupPermissions - adjusts a given user's network access based on group changes
func AdjustGroupPermissions(user *models.ReturnUser) error {
networks, err := GetNetworks()
if err != nil {
return err
}
// UPDATE
// go through all networks and see if new group is in
// if access level of current user is greater (value) than network's default
// assign network's default
// DELETE
// if user not allowed on network a
for i := range networks {
AdjustNetworkUserPermissions(user, &networks[i])
}
return nil
}
// AdjustGroupPermissions - adjusts a given user's network access based on group changes
func AdjustNetworkUserPermissions(user *models.ReturnUser, network *models.Network) error {
networkUser, err := pro.GetNetworkUser(
network.NetID,
promodels.NetworkUserID(user.UserName),
)
if err == nil && network.ProSettings != nil {
if pro.IsUserAllowed(network, user.UserName, user.Groups) {
if networkUser.AccessLevel > network.ProSettings.DefaultAccessLevel {
networkUser.AccessLevel = network.ProSettings.DefaultAccessLevel
}
if networkUser.NodeLimit < network.ProSettings.DefaultUserNodeLimit {
networkUser.NodeLimit = network.ProSettings.DefaultUserNodeLimit
}
if networkUser.ClientLimit < network.ProSettings.DefaultUserClientLimit {
networkUser.ClientLimit = network.ProSettings.DefaultUserClientLimit
}
} else {
networkUser.AccessLevel = pro.NO_ACCESS
networkUser.NodeLimit = 0
networkUser.ClientLimit = 0
}
pro.UpdateNetworkUser(network.NetID, networkUser)
}
return err
}

View File

@@ -1,4 +1,4 @@
package controller package logic
import ( import (
"encoding/json" "encoding/json"
@@ -8,7 +8,8 @@ import (
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
func formatError(err error, errType string) models.ErrorResponse { // FormatError - takes ErrorResponse and uses correct code
func FormatError(err error, errType string) models.ErrorResponse {
var status = http.StatusInternalServerError var status = http.StatusInternalServerError
switch errType { switch errType {
@@ -33,7 +34,8 @@ func formatError(err error, errType string) models.ErrorResponse {
return response return response
} }
func returnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) { // ReturnSuccessResponse - processes message and adds header
func ReturnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) {
var httpResponse models.SuccessResponse var httpResponse models.SuccessResponse
httpResponse.Code = http.StatusOK httpResponse.Code = http.StatusOK
httpResponse.Message = message httpResponse.Message = message
@@ -42,7 +44,8 @@ func returnSuccessResponse(response http.ResponseWriter, request *http.Request,
json.NewEncoder(response).Encode(httpResponse) json.NewEncoder(response).Encode(httpResponse)
} }
func returnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) { // ReturnErrorResponse - processes error and adds header
func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) {
httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message} httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message}
jsonResponse, err := json.Marshal(httpResponse) jsonResponse, err := json.Marshal(httpResponse)
if err != nil { if err != nil {

View File

@@ -183,3 +183,40 @@ func UpdateExtClient(newclientid string, network string, enabled bool, client *m
CreateExtClient(client) CreateExtClient(client)
return client, err return client, err
} }
// GetExtClientsByID - gets the clients of attached gateway
func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) {
var result []models.ExtClient
currentClients, err := GetNetworkExtClients(network)
if err != nil {
return result, err
}
for i := range currentClients {
if currentClients[i].IngressGatewayID == nodeid {
result = append(result, currentClients[i])
}
}
return result, nil
}
// GetAllExtClients - gets all ext clients from DB
func GetAllExtClients() ([]models.ExtClient, error) {
var clients = []models.ExtClient{}
currentNetworks, err := GetNetworks()
if err != nil && database.IsEmptyRecord(err) {
return clients, nil
} else if err != nil {
return clients, err
}
for i := range currentNetworks {
netName := currentNetworks[i].NetID
netClients, err := GetNetworkExtClients(netName)
if err != nil {
continue
}
clients = append(clients, netClients...)
}
return clients, nil
}

View File

@@ -53,6 +53,30 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
return "", err return "", err
} }
// CreateProUserJWT - creates a user jwt token
func CreateProUserJWT(username string, networks, groups []string, isadmin bool) (response string, err error) {
expirationTime := time.Now().Add(60 * 12 * time.Minute)
claims := &models.UserClaims{
UserName: username,
Networks: networks,
IsAdmin: isadmin,
Groups: groups,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Netmaker",
Subject: fmt.Sprintf("user|%s", username),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(expirationTime),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtSecretKey)
if err == nil {
return tokenString, nil
}
return "", err
}
// CreateUserJWT - creates a user jwt token // CreateUserJWT - creates a user jwt token
func CreateUserJWT(username string, networks []string, isadmin bool) (response string, err error) { func CreateUserJWT(username string, networks []string, isadmin bool) (response string, err error) {
expirationTime := time.Now().Add(60 * 12 * time.Minute) expirationTime := time.Now().Add(60 * 12 * time.Minute)

65
logic/metrics.go Normal file
View File

@@ -0,0 +1,65 @@
package logic
import (
"encoding/json"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
)
// GetMetrics - gets the metrics
func GetMetrics(nodeid string) (*models.Metrics, error) {
var metrics models.Metrics
record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid)
if err != nil {
if database.IsEmptyRecord(err) {
return &metrics, nil
}
return &metrics, err
}
err = json.Unmarshal([]byte(record), &metrics)
if err != nil {
return &metrics, err
}
return &metrics, nil
}
// UpdateMetrics - updates the metrics of a given client
func UpdateMetrics(nodeid string, metrics *models.Metrics) error {
data, err := json.Marshal(metrics)
if err != nil {
return err
}
return database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME)
}
// DeleteMetrics - deletes metrics of a given node
func DeleteMetrics(nodeid string) error {
return database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid)
}
// CollectServerMetrics - collects metrics for given server node
func CollectServerMetrics(serverID string, networkNodes []models.Node) *models.Metrics {
newServerMetrics := models.Metrics{}
newServerMetrics.Connectivity = make(map[string]models.Metric)
for i := range networkNodes {
currNodeID := networkNodes[i].ID
if currNodeID == serverID {
continue
}
if currMetrics, err := GetMetrics(currNodeID); err == nil {
if currMetrics.Connectivity != nil && currMetrics.Connectivity[serverID].Connected {
metrics := currMetrics.Connectivity[serverID]
metrics.NodeName = networkNodes[i].Name
metrics.IsServer = "no"
newServerMetrics.Connectivity[currNodeID] = metrics
}
} else {
newServerMetrics.Connectivity[currNodeID] = models.Metric{
Connected: false,
Latency: 999,
}
}
}
return &newServerMetrics
}

121
logic/metrics/metrics.go Normal file
View File

@@ -0,0 +1,121 @@
package metrics
import (
"github.com/go-ping/ping"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"golang.zx2c4.com/wireguard/wgctrl"
)
// Collect - collects metrics
func Collect(iface string, peerMap models.PeerMap) (*models.Metrics, error) {
var metrics models.Metrics
metrics.Connectivity = make(map[string]models.Metric)
var wgclient, err = wgctrl.New()
if err != nil {
fillUnconnectedData(&metrics, peerMap)
return &metrics, err
}
defer wgclient.Close()
device, err := wgclient.Device(iface)
if err != nil {
fillUnconnectedData(&metrics, peerMap)
return &metrics, err
}
// TODO handle freebsd??
for i := range device.Peers {
currPeer := device.Peers[i]
id := peerMap[currPeer.PublicKey.String()].ID
address := peerMap[currPeer.PublicKey.String()].Address
if id == "" || address == "" {
logger.Log(0, "attempted to parse metrics for invalid peer from server", id, address)
continue
}
var newMetric = models.Metric{
NodeName: peerMap[currPeer.PublicKey.String()].Name,
IsServer: peerMap[currPeer.PublicKey.String()].IsServer,
}
logger.Log(2, "collecting metrics for peer", address)
newMetric.TotalReceived = currPeer.ReceiveBytes
newMetric.TotalSent = currPeer.TransmitBytes
// get latency
pinger, err := ping.NewPinger(address)
if err != nil {
logger.Log(0, "could not initiliaze ping for metrics on peer address", address, err.Error())
newMetric.Connected = false
newMetric.Latency = 999
} else {
pinger.Count = 1
err = pinger.Run()
if err != nil {
logger.Log(0, "failed ping for metrics on peer address", address, err.Error())
newMetric.Connected = false
newMetric.Latency = 999
} else {
pingStats := pinger.Statistics()
newMetric.Uptime = 1
newMetric.Connected = true
newMetric.Latency = pingStats.AvgRtt.Milliseconds()
}
}
newMetric.TotalTime = 1
metrics.Connectivity[id] = newMetric
}
fillUnconnectedData(&metrics, peerMap)
return &metrics, nil
}
// GetExchangedBytesForNode - get exchanged bytes for current node peers
func GetExchangedBytesForNode(node *models.Node, metrics *models.Metrics) error {
peers, err := logic.GetPeerUpdate(node)
if err != nil {
logger.Log(0, "Failed to get peers: ", err.Error())
return err
}
wgclient, err := wgctrl.New()
if err != nil {
return err
}
defer wgclient.Close()
device, err := wgclient.Device(node.Interface)
if err != nil {
return err
}
for _, currPeer := range device.Peers {
id := peers.PeerIDs[currPeer.PublicKey.String()].ID
address := peers.PeerIDs[currPeer.PublicKey.String()].Address
if id == "" || address == "" {
logger.Log(0, "attempted to parse metrics for invalid peer from server", id, address)
continue
}
logger.Log(2, "collecting exchanged bytes info for peer: ", address)
peerMetric := metrics.Connectivity[id]
peerMetric.TotalReceived = currPeer.ReceiveBytes
peerMetric.TotalSent = currPeer.TransmitBytes
metrics.Connectivity[id] = peerMetric
}
return nil
}
// == used to fill zero value data for non connected peers ==
func fillUnconnectedData(metrics *models.Metrics, peerMap models.PeerMap) {
for r := range peerMap {
id := peerMap[r].ID
if !metrics.Connectivity[id].Connected {
newMetric := models.Metric{
NodeName: peerMap[r].Name,
IsServer: peerMap[r].IsServer,
Uptime: 0,
TotalTime: 1,
Connected: false,
Latency: 999,
PercentUp: 0,
}
metrics.Connectivity[id] = newMetric
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/validation" "github.com/gravitl/netmaker/validation"
@@ -62,6 +63,9 @@ func DeleteNetwork(network string) error {
} else { } else {
logger.Log(1, "could not remove servers before deleting network", network) logger.Log(1, "could not remove servers before deleting network", network)
} }
if err = pro.RemoveAllNetworkUsers(network); err != nil {
logger.Log(0, "failed to remove network users on network delete for network", network, err.Error())
}
return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network) return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
} }
return errors.New("node check failed. All nodes must be deleted before deleting network") return errors.New("node check failed. All nodes must be deleted before deleting network")
@@ -84,13 +88,24 @@ func CreateNetwork(network models.Network) (models.Network, error) {
} }
network.AddressRange6 = normalizedRange network.AddressRange6 = normalizedRange
} }
network.SetDefaults() network.SetDefaults()
network.SetNodesLastModified() network.SetNodesLastModified()
network.SetNetworkLastModified() network.SetNetworkLastModified()
pro.AddProNetDefaults(&network)
if len(network.ProSettings.AllowedGroups) == 0 {
network.ProSettings.AllowedGroups = []string{pro.DEFAULT_ALLOWED_GROUPS}
}
err := ValidateNetwork(&network, false) err := ValidateNetwork(&network, false)
if err != nil { if err != nil {
//returnErrorResponse(w, r, formatError(err, "badrequest")) //logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return models.Network{}, err
}
if err = pro.InitializeNetworkUsers(network.NetID); err != nil {
return models.Network{}, err return models.Network{}, err
} }
@@ -98,10 +113,16 @@ func CreateNetwork(network models.Network) (models.Network, error) {
if err != nil { if err != nil {
return models.Network{}, err return models.Network{}, err
} }
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return models.Network{}, err return models.Network{}, err
} }
// == add all current users to network as network users ==
if err = InitializeNetUsers(&network); err != nil {
return network, err
}
return network, nil return network, nil
} }
@@ -526,25 +547,29 @@ func IsNetworkNameUnique(network *models.Network) (bool, error) {
} }
// UpdateNetwork - updates a network with another network's fields // UpdateNetwork - updates a network with another network's fields
func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, bool, error) { func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, bool, []string, []string, error) {
if err := ValidateNetwork(newNetwork, true); err != nil { if err := ValidateNetwork(newNetwork, true); err != nil {
return false, false, false, false, err return false, false, false, false, nil, nil, err
} }
if newNetwork.NetID == currentNetwork.NetID { if newNetwork.NetID == currentNetwork.NetID {
hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange
hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6 hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6
localrangeupdate := newNetwork.LocalRange != currentNetwork.LocalRange localrangeupdate := newNetwork.LocalRange != currentNetwork.LocalRange
hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch
groupDelta := append(StringDifference(newNetwork.ProSettings.AllowedGroups, currentNetwork.ProSettings.AllowedGroups),
StringDifference(currentNetwork.ProSettings.AllowedGroups, newNetwork.ProSettings.AllowedGroups)...)
userDelta := append(StringDifference(newNetwork.ProSettings.AllowedUsers, currentNetwork.ProSettings.AllowedUsers),
StringDifference(currentNetwork.ProSettings.AllowedUsers, newNetwork.ProSettings.AllowedUsers)...)
data, err := json.Marshal(newNetwork) data, err := json.Marshal(newNetwork)
if err != nil { if err != nil {
return false, false, false, false, err return false, false, false, false, nil, nil, err
} }
newNetwork.SetNetworkLastModified() newNetwork.SetNetworkLastModified()
err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
return hasrangeupdate4, hasrangeupdate6, localrangeupdate, hasholepunchupdate, err return hasrangeupdate4, hasrangeupdate6, localrangeupdate, hasholepunchupdate, groupDelta, userDelta, err
} }
// copy values // copy values
return false, false, false, false, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.") return false, false, false, false, nil, nil, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.")
} }
// GetNetwork - gets a network from database // GetNetwork - gets a network from database
@@ -596,6 +621,15 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error {
} }
} }
if network.ProSettings != nil {
if network.ProSettings.DefaultAccessLevel < pro.NET_ADMIN || network.ProSettings.DefaultAccessLevel > pro.NO_ACCESS {
return fmt.Errorf("invalid access level")
}
if network.ProSettings.DefaultUserClientLimit < 0 || network.ProSettings.DefaultUserNodeLimit < 0 {
return fmt.Errorf("invalid node/client limit provided")
}
}
return err return err
} }
@@ -627,6 +661,17 @@ func SaveNetwork(network *models.Network) error {
return nil return nil
} }
// NetworkExists - check if network exists
func NetworkExists(name string) (bool, error) {
var network string
var err error
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err
}
return len(network) > 0, nil
}
// == Private == // == Private ==
func networkNodesUpdateAction(networkName string, action string) error { func networkNodesUpdateAction(networkName string, action string) error {

View File

@@ -13,6 +13,8 @@ import (
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/logic/pro/proacls"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
@@ -128,6 +130,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
} }
} }
} }
nodeACLDelta := currentNode.DefaultACL != newNode.DefaultACL
newNode.Fill(currentNode) newNode.Fill(currentNode)
if currentNode.IsServer == "yes" && !validateServer(currentNode, newNode) { if currentNode.IsServer == "yes" && !validateServer(currentNode, newNode) {
@@ -137,7 +140,15 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
if err := ValidateNode(newNode, true); err != nil { if err := ValidateNode(newNode, true); err != nil {
return err return err
} }
if newNode.ID == currentNode.ID { if newNode.ID == currentNode.ID {
if nodeACLDelta {
if err := updateProNodeACLS(newNode); err != nil {
logger.Log(1, "failed to apply node level ACLs during creation of node", newNode.ID, "-", err.Error())
return err
}
}
newNode.SetLastModified() newNode.SetLastModified()
if data, err := json.Marshal(newNode); err != nil { if data, err := json.Marshal(newNode); err != nil {
return err return err
@@ -145,6 +156,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
return database.Insert(newNode.ID, string(data), database.NODES_TABLE_NAME) return database.Insert(newNode.ID, string(data), database.NODES_TABLE_NAME)
} }
} }
return fmt.Errorf("failed to update node " + currentNode.ID + ", cannot change ID.") return fmt.Errorf("failed to update node " + currentNode.ID + ", cannot change ID.")
} }
@@ -176,9 +188,16 @@ func DeleteNodeByID(node *models.Node, exterminate bool) error {
if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil { if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil {
return err return err
} }
if servercfg.IsDNSMode() { if servercfg.IsDNSMode() {
SetDNS() SetDNS()
} }
if node.OwnerID != "" {
err = pro.DissociateNetworkUserNode(node.OwnerID, node.Network, node.ID)
if err != nil {
logger.Log(0, "failed to dissasociate", node.OwnerID, "from node", node.ID, ":", err.Error())
}
}
_, err = nodeacls.RemoveNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID)) _, err = nodeacls.RemoveNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID))
if err != nil { if err != nil {
@@ -186,6 +205,10 @@ func DeleteNodeByID(node *models.Node, exterminate bool) error {
logger.Log(2, "attempted to remove node ACL for node", node.Name, node.ID) logger.Log(2, "attempted to remove node ACL for node", node.Name, node.ID)
} }
// removeZombie <- node.ID // removeZombie <- node.ID
if err = DeleteMetrics(node.ID); err != nil {
logger.Log(1, "unable to remove metrics from DB for node", node.ID, err.Error())
}
if node.IsServer == "yes" { if node.IsServer == "yes" {
return removeLocalServer(node) return removeLocalServer(node)
} }
@@ -219,6 +242,9 @@ func ValidateNode(node *models.Node, isUpdate bool) error {
_ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool { _ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool {
return validation.CheckYesOrNo(fl) return validation.CheckYesOrNo(fl)
}) })
_ = v.RegisterValidation("checkyesornoorunset", func(fl validator.FieldLevel) bool {
return validation.CheckYesOrNoOrUnset(fl)
})
err := v.Struct(node) err := v.Struct(node)
return err return err
@@ -255,6 +281,10 @@ func CreateNode(node *models.Node) error {
} }
} }
if node.DefaultACL == "" {
node.DefaultACL = "unset"
}
reverse := node.IsServer == "yes" reverse := node.IsServer == "yes"
if node.Address == "" { if node.Address == "" {
if parentNetwork.IsIPv4 == "yes" { if parentNetwork.IsIPv4 == "yes" {
@@ -281,7 +311,7 @@ func CreateNode(node *models.Node) error {
//Create a JWT for the node //Create a JWT for the node
tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network) tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network)
if tokenString == "" { if tokenString == "" {
//returnErrorResponse(w, r, errorResponse) //logic.ReturnErrorResponse(w, r, errorResponse)
return err return err
} }
err = ValidateNode(node, false) err = ValidateNode(node, false)
@@ -305,9 +335,19 @@ func CreateNode(node *models.Node) error {
return err return err
} }
if err = updateProNodeACLS(node); err != nil {
logger.Log(1, "failed to apply node level ACLs during creation of node", node.ID, "-", err.Error())
return err
}
if node.IsPending != "yes" { if node.IsPending != "yes" {
DecrimentKey(node.Network, node.AccessKey) DecrimentKey(node.Network, node.AccessKey)
} }
if err = UpdateMetrics(node.ID, &models.Metrics{Connectivity: make(map[string]models.Metric)}); err != nil {
logger.Log(1, "failed to initialize metrics for node", node.Name, node.ID, err.Error())
}
SetNetworkNodesLastModified(node.Network) SetNetworkNodesLastModified(node.Network)
if servercfg.IsDNSMode() { if servercfg.IsDNSMode() {
err = SetDNS() err = SetDNS()
@@ -435,6 +475,7 @@ func SetNodeDefaults(node *models.Node) {
node.SetDefaultIsK8S() node.SetDefaultIsK8S()
node.SetDefaultIsHub() node.SetDefaultIsHub()
node.SetDefaultConnected() node.SetDefaultConnected()
node.SetDefaultACL()
} }
// GetRecordKey - get record key // GetRecordKey - get record key
@@ -677,3 +718,19 @@ func findNode(ip string) (*models.Node, error) {
} }
return nil, errors.New("node not found") return nil, errors.New("node not found")
} }
// == PRO ==
func updateProNodeACLS(node *models.Node) error {
// == PRO node ACLs ==
networkNodes, err := GetNetworkNodes(node.Network)
if err != nil {
return err
}
if err = proacls.AdjustNodeAcls(node, networkNodes[:]); err != nil {
return err
}
return nil
}
// == END PRO ==

View File

@@ -31,6 +31,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
} else if network.IsPointToSite == "yes" && node.IsHub != "yes" { } else if network.IsPointToSite == "yes" && node.IsHub != "yes" {
isP2S = true isP2S = true
} }
var peerMap = make(models.PeerMap)
// udppeers = the peers parsed from the local interface // udppeers = the peers parsed from the local interface
// gives us correct port to reach // gives us correct port to reach
@@ -149,14 +150,24 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
} }
peers = append(peers, peerData) peers = append(peers, peerData)
peerMap[peer.PublicKey] = models.IDandAddr{
Name: peer.Name,
ID: peer.ID,
Address: peer.PrimaryAddress(),
IsServer: peer.IsServer,
}
if peer.IsServer == "yes" { if peer.IsServer == "yes" {
serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address}) serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address})
} }
} }
if node.IsIngressGateway == "yes" { if node.IsIngressGateway == "yes" {
extPeers, err := getExtPeers(node) extPeers, idsAndAddr, err := getExtPeers(node)
if err == nil { if err == nil {
peers = append(peers, extPeers...) peers = append(peers, extPeers...)
for i := range idsAndAddr {
peerMap[idsAndAddr[i].ID] = idsAndAddr[i]
}
} else if !database.IsEmptyRecord(err) { } else if !database.IsEmptyRecord(err) {
logger.Log(1, "error retrieving external clients:", err.Error()) logger.Log(1, "error retrieving external clients:", err.Error())
} }
@@ -167,14 +178,16 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
peerUpdate.Peers = peers peerUpdate.Peers = peers
peerUpdate.ServerAddrs = serverNodeAddresses peerUpdate.ServerAddrs = serverNodeAddresses
peerUpdate.DNS = getPeerDNS(node.Network) peerUpdate.DNS = getPeerDNS(node.Network)
peerUpdate.PeerIDs = peerMap
return peerUpdate, nil return peerUpdate, nil
} }
func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, error) { func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, []models.IDandAddr, error) {
var peers []wgtypes.PeerConfig var peers []wgtypes.PeerConfig
var idsAndAddr []models.IDandAddr
extPeers, err := GetExtPeersList(node) extPeers, err := GetExtPeersList(node)
if err != nil { if err != nil {
return peers, err return peers, idsAndAddr, err
} }
for _, extPeer := range extPeers { for _, extPeer := range extPeers {
pubkey, err := wgtypes.ParseKey(extPeer.PublicKey) pubkey, err := wgtypes.ParseKey(extPeer.PublicKey)
@@ -208,14 +221,24 @@ func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, error) {
allowedips = append(allowedips, addr6) allowedips = append(allowedips, addr6)
} }
} }
primaryAddr := extPeer.Address
if primaryAddr == "" {
primaryAddr = extPeer.Address6
}
peer = wgtypes.PeerConfig{ peer = wgtypes.PeerConfig{
PublicKey: pubkey, PublicKey: pubkey,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: true,
AllowedIPs: allowedips, AllowedIPs: allowedips,
} }
peers = append(peers, peer) peers = append(peers, peer)
idsAndAddr = append(idsAndAddr, models.IDandAddr{
ID: peer.PublicKey.String(),
Address: primaryAddr,
})
} }
return peers, nil return peers, idsAndAddr, nil
} }
@@ -281,7 +304,7 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
// handle ingress gateway peers // handle ingress gateway peers
if peer.IsIngressGateway == "yes" { if peer.IsIngressGateway == "yes" {
extPeers, err := getExtPeers(peer) extPeers, _, err := getExtPeers(peer)
if err != nil { if err != nil {
logger.Log(2, "could not retrieve ext peers for ", peer.Name, err.Error()) logger.Log(2, "could not retrieve ext peers for ", peer.Name, err.Error())
} }
@@ -333,7 +356,7 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
allowedips = append(allowedips, extAllowedIPs...) allowedips = append(allowedips, extAllowedIPs...)
} }
if relayedNode.IsIngressGateway == "yes" { if relayedNode.IsIngressGateway == "yes" {
extPeers, err := getExtPeers(relayedNode) extPeers, _, err := getExtPeers(relayedNode)
if err == nil { if err == nil {
for _, extPeer := range extPeers { for _, extPeer := range extPeers {
allowedips = append(allowedips, extPeer.AllowedIPs...) allowedips = append(allowedips, extPeer.AllowedIPs...)
@@ -486,7 +509,7 @@ func GetPeerUpdateForRelayedNode(node *models.Node, udppeers map[string]string)
} }
//if ingress add extclients //if ingress add extclients
if node.IsIngressGateway == "yes" { if node.IsIngressGateway == "yes" {
extPeers, err := getExtPeers(node) extPeers, _, err := getExtPeers(node)
if err == nil { if err == nil {
peers = append(peers, extPeers...) peers = append(peers, extPeers...)
} else { } else {

View File

@@ -0,0 +1,57 @@
package netcache
import (
"encoding/json"
"fmt"
"time"
"github.com/gravitl/netmaker/database"
)
const (
expirationTime = time.Minute * 5
)
// CValue - the cache object for a network
type CValue struct {
Network string `json:"network"`
Value string `json:"value"`
Pass string `json:"pass"`
User string `json:"user"`
Expiration time.Time `json:"expiration"`
}
var errExpired = fmt.Errorf("expired")
// Set - sets a value to a key in db
func Set(k string, newValue *CValue) error {
newValue.Expiration = time.Now().Add(expirationTime)
newData, err := json.Marshal(newValue)
if err != nil {
return err
}
return database.Insert(k, string(newData), database.CACHE_TABLE_NAME)
}
// Get - gets a value from db, if expired, return err
func Get(k string) (*CValue, error) {
record, err := database.FetchRecord(database.CACHE_TABLE_NAME, k)
if err != nil {
return nil, err
}
var entry CValue
if err := json.Unmarshal([]byte(record), &entry); err != nil {
return nil, err
}
if time.Now().After(entry.Expiration) {
return nil, errExpired
}
return &entry, nil
}
// Del - deletes a value from db
func Del(k string) error {
return database.DeleteRecord(database.CACHE_TABLE_NAME, k)
}

68
logic/pro/networks.go Normal file
View File

@@ -0,0 +1,68 @@
package pro
import (
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
)
// AddProNetDefaults - adds default values to a network model
func AddProNetDefaults(network *models.Network) {
if network.ProSettings == nil {
newProSettings := promodels.ProNetwork{
DefaultAccessLevel: NO_ACCESS,
DefaultUserNodeLimit: 0,
DefaultUserClientLimit: 0,
AllowedUsers: []string{},
AllowedGroups: []string{DEFAULT_ALLOWED_GROUPS},
}
network.ProSettings = &newProSettings
}
if network.ProSettings.AllowedUsers == nil {
network.ProSettings.AllowedUsers = []string{}
}
if network.ProSettings.AllowedGroups == nil {
network.ProSettings.AllowedGroups = []string{DEFAULT_ALLOWED_GROUPS}
}
}
// isUserGroupAllowed - checks if a user group is allowed on a network
func isUserGroupAllowed(network *models.Network, groupName string) bool {
if network.ProSettings != nil {
if len(network.ProSettings.AllowedGroups) > 0 {
for i := range network.ProSettings.AllowedGroups {
currentGroup := network.ProSettings.AllowedGroups[i]
if currentGroup == DEFAULT_ALLOWED_GROUPS || currentGroup == groupName {
return true
}
}
}
}
return false
}
func isUserInAllowedUsers(network *models.Network, userName string) bool {
if network.ProSettings != nil {
if len(network.ProSettings.AllowedUsers) > 0 {
for i := range network.ProSettings.AllowedUsers {
currentUser := network.ProSettings.AllowedUsers[i]
if currentUser == DEFAULT_ALLOWED_USERS || currentUser == userName {
return true
}
}
}
}
return false
}
// IsUserAllowed - checks if given username + groups if a user is allowed on network
func IsUserAllowed(network *models.Network, userName string, groups []string) bool {
isGroupAllowed := false
for _, g := range groups {
if isUserGroupAllowed(network, g) {
isGroupAllowed = true
break
}
}
return isUserInAllowedUsers(network, userName) || isGroupAllowed
}

View File

@@ -0,0 +1,64 @@
package pro
import (
"testing"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/stretchr/testify/assert"
)
func TestNetworkProSettings(t *testing.T) {
t.Run("Uninitialized with pro", func(t *testing.T) {
network := models.Network{
NetID: "helloworld",
}
assert.Nil(t, network.ProSettings)
})
t.Run("Initialized with pro", func(t *testing.T) {
network := models.Network{
NetID: "helloworld",
}
AddProNetDefaults(&network)
assert.NotNil(t, network.ProSettings)
})
t.Run("Net Zero Defaults set correctly with Pro", func(t *testing.T) {
network := models.Network{
NetID: "helloworld",
}
AddProNetDefaults(&network)
assert.NotNil(t, network.ProSettings)
assert.Equal(t, NO_ACCESS, network.ProSettings.DefaultAccessLevel)
assert.Equal(t, 0, network.ProSettings.DefaultUserClientLimit)
assert.Equal(t, 0, network.ProSettings.DefaultUserNodeLimit)
})
t.Run("Net Defaults set correctly with Pro", func(t *testing.T) {
network := models.Network{
NetID: "helloworld",
ProSettings: &promodels.ProNetwork{
DefaultAccessLevel: NET_ADMIN,
DefaultUserNodeLimit: 10,
DefaultUserClientLimit: 25,
},
}
AddProNetDefaults(&network)
assert.NotNil(t, network.ProSettings)
assert.Equal(t, NET_ADMIN, network.ProSettings.DefaultAccessLevel)
assert.Equal(t, 25, network.ProSettings.DefaultUserClientLimit)
assert.Equal(t, 10, network.ProSettings.DefaultUserNodeLimit)
})
t.Run("Net Defaults set to allow all groups/users", func(t *testing.T) {
network := models.Network{
NetID: "helloworld",
ProSettings: &promodels.ProNetwork{
DefaultAccessLevel: NET_ADMIN,
DefaultUserNodeLimit: 10,
DefaultUserClientLimit: 25,
},
}
AddProNetDefaults(&network)
assert.NotNil(t, network.ProSettings)
assert.Equal(t, len(network.ProSettings.AllowedGroups), 1)
assert.Equal(t, len(network.ProSettings.AllowedUsers), 0)
})
}

252
logic/pro/networkuser.go Normal file
View File

@@ -0,0 +1,252 @@
package pro
import (
"encoding/json"
"fmt"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
)
// InitializeNetworkUsers - intializes network users for a given network
func InitializeNetworkUsers(network string) error {
_, err := database.FetchRecord(database.NETWORK_USER_TABLE_NAME, network)
if err != nil && database.IsEmptyRecord(err) {
newNetUserMap := make(promodels.NetworkUserMap)
netUserData, err := json.Marshal(newNetUserMap)
if err != nil {
return err
}
return database.Insert(network, string(netUserData), database.NETWORK_USER_TABLE_NAME)
}
return err
}
// GetNetworkUsers - gets the network users table
func GetNetworkUsers(network string) (promodels.NetworkUserMap, error) {
currentUsers, err := database.FetchRecord(database.NETWORK_USER_TABLE_NAME, network)
if err != nil {
return nil, err
}
var userMap promodels.NetworkUserMap
if err = json.Unmarshal([]byte(currentUsers), &userMap); err != nil {
return nil, err
}
return userMap, nil
}
// CreateNetworkUser - adds a network user to db
func CreateNetworkUser(network *models.Network, user *promodels.NetworkUser) error {
if DoesNetworkUserExist(network.NetID, user.ID) {
return nil
}
currentUsers, err := GetNetworkUsers(network.NetID)
if err != nil {
return err
}
user.SetDefaults()
currentUsers.Add(user)
data, err := json.Marshal(currentUsers)
if err != nil {
return err
}
return database.Insert(network.NetID, string(data), database.NETWORK_USER_TABLE_NAME)
}
// DeleteNetworkUser - deletes a network user and removes from all networks
func DeleteNetworkUser(network, userid string) error {
currentUsers, err := GetNetworkUsers(network)
if err != nil {
return err
}
currentUsers.Delete(promodels.NetworkUserID(userid))
data, err := json.Marshal(currentUsers)
if err != nil {
return err
}
return database.Insert(network, string(data), database.NETWORK_USER_TABLE_NAME)
}
// DissociateNetworkUserNode - removes a node from a given user's node list
func DissociateNetworkUserNode(userid, networkid, nodeid string) error {
nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid))
if err != nil {
return err
}
for i, n := range nuser.Nodes {
if n == nodeid {
nuser.Nodes = removeStringIndex(nuser.Nodes, i)
break
}
}
return UpdateNetworkUser(networkid, nuser)
}
// DissociateNetworkUserClient - removes a client from a given user's client list
func DissociateNetworkUserClient(userid, networkid, clientid string) error {
nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid))
if err != nil {
return err
}
for i, n := range nuser.Clients {
if n == clientid {
nuser.Clients = removeStringIndex(nuser.Clients, i)
break
}
}
return UpdateNetworkUser(networkid, nuser)
}
// AssociateNetworkUserClient - removes a client from a given user's client list
func AssociateNetworkUserClient(userid, networkid, clientid string) error {
nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid))
if err != nil {
return err
}
var found bool
for _, n := range nuser.Clients {
if n == clientid {
found = true
break
}
}
if found {
return nil
} else {
nuser.Clients = append(nuser.Clients, clientid)
}
return UpdateNetworkUser(networkid, nuser)
}
func removeStringIndex(s []string, index int) []string {
ret := make([]string, 0)
ret = append(ret, s[:index]...)
return append(ret, s[index+1:]...)
}
// GetNetworkUser - fetches a network user from a given network
func GetNetworkUser(network string, userID promodels.NetworkUserID) (*promodels.NetworkUser, error) {
currentUsers, err := GetNetworkUsers(network)
if err != nil {
return nil, err
}
if currentUsers[userID].ID == "" {
return nil, fmt.Errorf("user %s does not exist", userID)
}
currentNetUser := currentUsers[userID]
return &currentNetUser, nil
}
// DoesNetworkUserExist - check if networkuser exists
func DoesNetworkUserExist(network string, userID promodels.NetworkUserID) bool {
_, err := GetNetworkUser(network, userID)
return err == nil
}
// UpdateNetworkUser - gets a network user from given network
func UpdateNetworkUser(network string, newUser *promodels.NetworkUser) error {
currentUsers, err := GetNetworkUsers(network)
if err != nil {
return err
}
currentUsers[newUser.ID] = *newUser
newUsersData, err := json.Marshal(&currentUsers)
if err != nil {
return err
}
return database.Insert(network, string(newUsersData), database.NETWORK_USER_TABLE_NAME)
}
// RemoveAllNetworkUsers - removes all network users from given network
func RemoveAllNetworkUsers(network string) error {
return database.DeleteRecord(database.NETWORK_USER_TABLE_NAME, network)
}
// IsUserNodeAllowed - given a list of nodes, determine if the user's node is allowed based on ID
// Checks if node is in given nodes list as well as being in user's list
func IsUserNodeAllowed(nodes []models.Node, network, userID, nodeID string) bool {
netUser, err := GetNetworkUser(network, promodels.NetworkUserID(userID))
if err != nil {
return false
}
for i := range nodes {
if nodes[i].ID == nodeID {
for j := range netUser.Nodes {
if netUser.Nodes[j] == nodeID {
return true
}
}
}
}
return false
}
// IsUserClientAllowed - given a list of clients, determine if the user's client is allowed based on ID
// Checks if client is in given ext client list as well as being in user's list
func IsUserClientAllowed(clients []models.ExtClient, network, userID, clientID string) bool {
netUser, err := GetNetworkUser(network, promodels.NetworkUserID(userID))
if err != nil {
return false
}
for i := range clients {
if clients[i].ClientID == clientID {
for j := range netUser.Clients {
if netUser.Clients[j] == clientID {
return true
}
}
}
}
return false
}
// IsUserNetAdmin - checks if a user is a net admin or not
func IsUserNetAdmin(network, userID string) bool {
var isAdmin bool
user, err := GetNetworkUser(network, promodels.NetworkUserID(userID))
if err != nil {
return isAdmin
}
return user.AccessLevel == NET_ADMIN
}
// MakeNetAdmin - makes a given user a network admin on given network
func MakeNetAdmin(network, userID string) (ok bool) {
user, err := GetNetworkUser(network, promodels.NetworkUserID(userID))
if err != nil {
return ok
}
user.AccessLevel = NET_ADMIN
if err = UpdateNetworkUser(network, user); err != nil {
return ok
}
return true
}
// AssignAccessLvl - gives a user a specified access level
func AssignAccessLvl(network, userID string, accesslvl int) (ok bool) {
user, err := GetNetworkUser(network, promodels.NetworkUserID(userID))
if err != nil {
return ok
}
user.AccessLevel = accesslvl
if err = UpdateNetworkUser(network, user); err != nil {
return ok
}
return true
}

View File

@@ -0,0 +1,98 @@
package pro
import (
"testing"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/stretchr/testify/assert"
)
func TestNetworkUserLogic(t *testing.T) {
database.InitializeDatabase()
networkUser := promodels.NetworkUser{
ID: "helloworld",
}
network := models.Network{
NetID: "skynet",
AddressRange: "192.168.0.0/24",
}
nodes := []models.Node{
models.Node{ID: "coolnode"},
}
clients := []models.ExtClient{
models.ExtClient{
ClientID: "coolclient",
},
}
AddProNetDefaults(&network)
t.Run("Net Users initialized successfully", func(t *testing.T) {
err := InitializeNetworkUsers(network.NetID)
assert.Nil(t, err)
})
t.Run("Error when no network users", func(t *testing.T) {
user, err := GetNetworkUser(network.NetID, networkUser.ID)
assert.Nil(t, user)
assert.NotNil(t, err)
})
t.Run("Successful net user create", func(t *testing.T) {
DeleteNetworkUser(network.NetID, string(networkUser.ID))
err := CreateNetworkUser(&network, &networkUser)
assert.Nil(t, err)
user, err := GetNetworkUser(network.NetID, networkUser.ID)
assert.NotNil(t, user)
assert.Nil(t, err)
assert.Equal(t, 0, user.AccessLevel)
assert.Equal(t, 0, user.ClientLimit)
})
t.Run("Successful net user update", func(t *testing.T) {
networkUser.AccessLevel = 0
networkUser.ClientLimit = 1
err := UpdateNetworkUser(network.NetID, &networkUser)
assert.Nil(t, err)
user, err := GetNetworkUser(network.NetID, networkUser.ID)
assert.NotNil(t, user)
assert.Nil(t, err)
assert.Equal(t, 0, user.AccessLevel)
assert.Equal(t, 1, user.ClientLimit)
})
t.Run("Successful net user node isallowed", func(t *testing.T) {
networkUser.Nodes = append(networkUser.Nodes, "coolnode")
err := UpdateNetworkUser(network.NetID, &networkUser)
assert.Nil(t, err)
isUserNodeAllowed := IsUserNodeAllowed(nodes[:], network.NetID, string(networkUser.ID), "coolnode")
assert.True(t, isUserNodeAllowed)
})
t.Run("Successful net user node not allowed", func(t *testing.T) {
isUserNodeAllowed := IsUserNodeAllowed(nodes[:], network.NetID, string(networkUser.ID), "notanode")
assert.False(t, isUserNodeAllowed)
})
t.Run("Successful net user client isallowed", func(t *testing.T) {
networkUser.Clients = append(networkUser.Clients, "coolclient")
err := UpdateNetworkUser(network.NetID, &networkUser)
assert.Nil(t, err)
isUserClientAllowed := IsUserClientAllowed(clients[:], network.NetID, string(networkUser.ID), "coolclient")
assert.True(t, isUserClientAllowed)
})
t.Run("Successful net user client not allowed", func(t *testing.T) {
isUserClientAllowed := IsUserClientAllowed(clients[:], network.NetID, string(networkUser.ID), "notaclient")
assert.False(t, isUserClientAllowed)
})
t.Run("Successful net user delete", func(t *testing.T) {
err := DeleteNetworkUser(network.NetID, string(networkUser.ID))
assert.Nil(t, err)
user, err := GetNetworkUser(network.NetID, networkUser.ID)
assert.Nil(t, user)
assert.NotNil(t, err)
})
}

View File

@@ -0,0 +1,35 @@
package proacls
import (
"github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/models"
)
// AdjustNodeAcls - adjusts ACLs based on a node's default value
func AdjustNodeAcls(node *models.Node, networkNodes []models.Node) error {
networkID := nodeacls.NetworkID(node.Network)
nodeID := nodeacls.NodeID(node.ID)
currentACLs, err := nodeacls.FetchAllACLs(networkID)
if err != nil {
return err
}
for i := range networkNodes {
currentNodeID := nodeacls.NodeID(networkNodes[i].ID)
if currentNodeID == nodeID {
continue
}
// 2 cases
// both allow - allow
// either 1 denies - deny
if node.DoesACLDeny() || networkNodes[i].DoesACLDeny() {
currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(currentNodeID), acls.NotAllowed)
} else if node.DoesACLAllow() || networkNodes[i].DoesACLAllow() {
currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(currentNodeID), acls.Allowed)
}
}
_, err = currentACLs.Save(acls.ContainerID(node.Network))
return err
}

20
logic/pro/types.go Normal file
View File

@@ -0,0 +1,20 @@
package pro
const (
// == NET ACCESS END == indicates access for system admin (control of netmaker)
// NET_ADMIN - indicates access for network admin (control of network)
NET_ADMIN = 0
// NODE_ACCESS - indicates access for
NODE_ACCESS = 1
// CLIENT_ACCESS - indicates access for network user (limited to nodes + ext clients)
CLIENT_ACCESS = 2
// NO_ACCESS - indicates user has no access to network
NO_ACCESS = 3
// == NET ACCESS END ==
// DEFAULT_ALLOWED_GROUPS - default user group for all networks
DEFAULT_ALLOWED_GROUPS = "*"
// DEFAULT_ALLOWED_USERS - default allowed users for a network
DEFAULT_ALLOWED_USERS = "*"
// DB_GROUPS_KEY - represents db groups
DB_GROUPS_KEY = "netmaker-groups"
)

80
logic/pro/usergroups.go Normal file
View File

@@ -0,0 +1,80 @@
package pro
import (
"encoding/json"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models/promodels"
)
// InitializeGroups - initialize groups data structure if not present in the DB
func InitializeGroups() error {
if !DoesUserGroupExist(DEFAULT_ALLOWED_GROUPS) {
return InsertUserGroup(DEFAULT_ALLOWED_GROUPS)
}
return nil
}
// InsertUserGroup - inserts a group into the
func InsertUserGroup(groupName promodels.UserGroupName) error {
currentGroups, err := GetUserGroups()
if err != nil {
return err
}
currentGroups[groupName] = promodels.Void{}
newData, err := json.Marshal(&currentGroups)
if err != nil {
return err
}
return database.Insert(DB_GROUPS_KEY, string(newData), database.USER_GROUPS_TABLE_NAME)
}
// DeleteUserGroup - deletes a group from database
func DeleteUserGroup(groupName promodels.UserGroupName) error {
var newGroups promodels.UserGroups
currentGroupRecords, err := database.FetchRecord(database.USER_GROUPS_TABLE_NAME, DB_GROUPS_KEY)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
if err = json.Unmarshal([]byte(currentGroupRecords), &newGroups); err != nil {
return err
}
delete(newGroups, groupName)
newData, err := json.Marshal(&newGroups)
if err != nil {
return err
}
return database.Insert(DB_GROUPS_KEY, string(newData), database.USER_GROUPS_TABLE_NAME)
}
// GetUserGroups - get groups of users
func GetUserGroups() (promodels.UserGroups, error) {
var returnGroups promodels.UserGroups
groupsRecord, err := database.FetchRecord(database.USER_GROUPS_TABLE_NAME, DB_GROUPS_KEY)
if err != nil {
if database.IsEmptyRecord(err) {
return make(promodels.UserGroups, 1), nil
}
return returnGroups, err
}
if err = json.Unmarshal([]byte(groupsRecord), &returnGroups); err != nil {
return returnGroups, err
}
return returnGroups, nil
}
// DoesUserGroupExist - checks if a user group exists
func DoesUserGroupExist(group promodels.UserGroupName) bool {
currentGroups, err := GetUserGroups()
if err != nil {
return true
}
for k := range currentGroups {
if k == group {
return true
}
}
return false
}

View File

@@ -0,0 +1,43 @@
package pro
import (
"testing"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models/promodels"
"github.com/stretchr/testify/assert"
)
func TestUserGroupLogic(t *testing.T) {
database.InitializeDatabase()
t.Run("User Groups initialized successfully", func(t *testing.T) {
err := InitializeGroups()
assert.Nil(t, err)
})
t.Run("Check for default group", func(t *testing.T) {
groups, err := GetUserGroups()
assert.Nil(t, err)
var hasdefault bool
for k := range groups {
if string(k) == DEFAULT_ALLOWED_GROUPS {
hasdefault = true
}
}
assert.True(t, hasdefault)
})
t.Run("User Groups created successfully", func(t *testing.T) {
err := InsertUserGroup(promodels.UserGroupName("group1"))
assert.Nil(t, err)
err = InsertUserGroup(promodels.UserGroupName("group2"))
assert.Nil(t, err)
})
t.Run("User Groups deleted successfully", func(t *testing.T) {
err := DeleteUserGroup(promodels.UserGroupName("group1"))
assert.Nil(t, err)
assert.False(t, DoesUserGroupExist(promodels.UserGroupName("group1")))
})
}

206
logic/security.go Normal file
View File

@@ -0,0 +1,206 @@
package logic
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/servercfg"
)
const (
// ALL_NETWORK_ACCESS - represents all networks
ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
master_uname = "masteradministrator"
Unauthorized_Msg = "unauthorized"
Unauthorized_Err = models.Error(Unauthorized_Msg)
)
// SecurityCheck - Check if user has appropriate permissions
func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
}
var params = mux.Vars(r)
bearerToken := r.Header.Get("Authorization")
// to have a custom DNS service adding entries
// we should refactor this, but is for the special case of an external service to query the DNS api
if strings.Contains(r.RequestURI, "/dns") && strings.ToUpper(r.Method) == "GET" && authenticateDNSToken(bearerToken) {
// do dns stuff
r.Header.Set("user", "nameserver")
networks, _ := json.Marshal([]string{ALL_NETWORK_ACCESS})
r.Header.Set("networks", string(networks))
next.ServeHTTP(w, r)
return
}
var networkName = params["networkname"]
if len(networkName) == 0 {
networkName = params["network"]
}
networks, username, err := UserPermissions(reqAdmin, networkName, bearerToken)
if err != nil {
ReturnErrorResponse(w, r, errorResponse)
return
}
networksJson, err := json.Marshal(&networks)
if err != nil {
ReturnErrorResponse(w, r, errorResponse)
return
}
r.Header.Set("user", username)
r.Header.Set("networks", string(networksJson))
next.ServeHTTP(w, r)
}
}
// NetUserSecurityCheck - Check if network user has appropriate permissions
func NetUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: "unauthorized",
}
r.Header.Set("ismaster", "no")
var params = mux.Vars(r)
var netUserName = params["networkuser"]
var network = params["network"]
bearerToken := r.Header.Get("Authorization")
var tokenSplit = strings.Split(bearerToken, " ")
var authToken = ""
if len(tokenSplit) < 2 {
ReturnErrorResponse(w, r, errorResponse)
return
} else {
authToken = tokenSplit[1]
}
isMasterAuthenticated := authenticateMaster(authToken)
if isMasterAuthenticated {
r.Header.Set("user", "master token user")
r.Header.Set("ismaster", "yes")
next.ServeHTTP(w, r)
return
}
userName, _, isadmin, err := VerifyUserToken(authToken)
if err != nil {
ReturnErrorResponse(w, r, errorResponse)
return
}
r.Header.Set("user", userName)
if isadmin {
next.ServeHTTP(w, r)
return
}
if isNodes || isClients {
necessaryAccess := pro.NET_ADMIN
if isClients {
necessaryAccess = pro.CLIENT_ACCESS
}
if isNodes {
necessaryAccess = pro.NODE_ACCESS
}
u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName))
if err != nil {
ReturnErrorResponse(w, r, errorResponse)
return
}
if u.AccessLevel > necessaryAccess {
ReturnErrorResponse(w, r, errorResponse)
return
}
} else if netUserName != userName {
ReturnErrorResponse(w, r, errorResponse)
return
}
next.ServeHTTP(w, r)
}
}
// UserPermissions - checks token stuff
func UserPermissions(reqAdmin bool, netname string, token string) ([]string, string, error) {
var tokenSplit = strings.Split(token, " ")
var authToken = ""
userNetworks := []string{}
if len(tokenSplit) < 2 {
return userNetworks, "", Unauthorized_Err
} else {
authToken = tokenSplit[1]
}
//all endpoints here require master so not as complicated
if authenticateMaster(authToken) {
return []string{ALL_NETWORK_ACCESS}, master_uname, nil
}
username, networks, isadmin, err := VerifyUserToken(authToken)
if err != nil {
return nil, username, Unauthorized_Err
}
if !isadmin && reqAdmin {
return nil, username, Unauthorized_Err
}
userNetworks = networks
if isadmin {
return []string{ALL_NETWORK_ACCESS}, username, nil
}
// check network admin access
if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) {
return nil, username, Unauthorized_Err
}
if !pro.IsUserNetAdmin(netname, username) {
return nil, "", Unauthorized_Err
}
return userNetworks, username, nil
}
// Consider a more secure way of setting master key
func authenticateMaster(tokenString string) bool {
return tokenString == servercfg.GetMasterKey() && servercfg.GetMasterKey() != ""
}
func authenticateNetworkUser(network string, userNetworks []string) bool {
networkexists, err := NetworkExists(network)
if (err != nil && !database.IsEmptyRecord(err)) || !networkexists {
return false
}
return StringSliceContains(userNetworks, network)
}
//Consider a more secure way of setting master key
func authenticateDNSToken(tokenString string) bool {
tokens := strings.Split(tokenString, " ")
if len(tokens) < 2 {
return false
}
return tokens[1] == servercfg.GetDNSKey()
}
func ContinueIfUserMatch(next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
}
var params = mux.Vars(r)
var requestedUser = params["username"]
if requestedUser != r.Header.Get("user") {
ReturnErrorResponse(w, r, errorResponse)
return
}
next.ServeHTTP(w, r)
}
}

View File

@@ -18,6 +18,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var EnterpriseCheckFuncs []interface{}
// == Join, Checkin, and Leave for Server == // == Join, Checkin, and Leave for Server ==
// KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range // KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range
@@ -164,6 +166,13 @@ func ServerJoin(networkSettings *models.Network) (models.Node, error) {
return *node, nil return *node, nil
} }
// EnterpriseCheck - Runs enterprise functions if presented
func EnterpriseCheck() {
for _, check := range EnterpriseCheckFuncs {
check.(func())()
}
}
// ServerUpdate - updates the server // ServerUpdate - updates the server
// replaces legacy Checkin code // replaces legacy Checkin code
func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error { func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error {

View File

@@ -6,6 +6,21 @@ import (
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
) )
var (
// Node_Limit - dummy var for community
Node_Limit = 1000000000
// Networks_Limit - dummy var for community
Networks_Limit = 1000000000
// Users_Limit - dummy var for community
Users_Limit = 1000000000
// Clients_Limit - dummy var for community
Clients_Limit = 1000000000
// Free_Tier - specifies if free tier
Free_Tier = false
// Is_EE - specifies if enterprise
Is_EE = false
)
// constant for database key for storing server ids // constant for database key for storing server ids
const server_id_key = "nm-server-id" const server_id_key = "nm-server-id"

View File

@@ -4,7 +4,10 @@ import (
"encoding/json" "encoding/json"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
) )
// GetUser - gets a user // GetUser - gets a user
@@ -20,3 +23,57 @@ func GetUser(username string) (models.User, error) {
} }
return user, err return user, err
} }
// GetGroupUsers - gets users in a group
func GetGroupUsers(group string) ([]models.ReturnUser, error) {
var returnUsers []models.ReturnUser
users, err := GetUsers()
if err != nil {
return returnUsers, err
}
for _, user := range users {
if StringSliceContains(user.Groups, group) {
users = append(users, user)
}
}
return users, err
}
// == PRO ==
// InitializeNetUsers - intializes network users for all users/networks
func InitializeNetUsers(network *models.Network) error {
// == add all current users to network as network users ==
currentUsers, err := GetUsers()
if err != nil {
return err
}
for i := range currentUsers { // add all users to given network
newUser := promodels.NetworkUser{
ID: promodels.NetworkUserID(currentUsers[i].UserName),
Clients: []string{},
Nodes: []string{},
AccessLevel: pro.NO_ACCESS,
ClientLimit: 0,
NodeLimit: 0,
}
if pro.IsUserAllowed(network, currentUsers[i].UserName, currentUsers[i].Groups) {
newUser.AccessLevel = network.ProSettings.DefaultAccessLevel
newUser.ClientLimit = network.ProSettings.DefaultUserClientLimit
newUser.NodeLimit = network.ProSettings.DefaultUserNodeLimit
}
if err = pro.CreateNetworkUser(network, &newUser); err != nil {
logger.Log(0, "failed to add network user settings to user", string(newUser.ID), "on network", network.NetID)
}
}
return nil
}
// SetUserDefaults - sets the defaults of a user to avoid empty fields
func SetUserDefaults(user *models.User) {
if user.Groups == nil {
user.Groups = []string{pro.DEFAULT_ALLOWED_GROUPS}
}
}

View File

@@ -203,3 +203,18 @@ func getNetworkProtocols(cidrs []string) (bool, bool) {
} }
return ipv4, ipv6 return ipv4, ipv6
} }
// StringDifference - returns the elements in `a` that aren't in `b`.
func StringDifference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
var diff []string
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}

15
main.go
View File

@@ -1,3 +1,4 @@
// -build ee
package main package main
import ( import (
@@ -22,6 +23,7 @@ import (
"github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
@@ -36,10 +38,10 @@ var version = "dev"
func main() { func main() {
absoluteConfigPath := flag.String("c", "", "absolute path to configuration file") absoluteConfigPath := flag.String("c", "", "absolute path to configuration file")
flag.Parse() flag.Parse()
setupConfig(*absoluteConfigPath) setupConfig(*absoluteConfigPath)
servercfg.SetVersion(version) servercfg.SetVersion(version)
fmt.Println(models.RetrieveLogo()) // print the logo fmt.Println(models.RetrieveLogo()) // print the logo
// fmt.Println(models.ProLogo())
initialize() // initial db and acls; gen cert if required initialize() // initial db and acls; gen cert if required
setGarbageCollection() setGarbageCollection()
setVerbosity() setVerbosity()
@@ -73,14 +75,23 @@ func initialize() { // Client Mode Prereq Check
logger.FatalLog("Error connecting to database") logger.FatalLog("Error connecting to database")
} }
logger.Log(0, "database successfully connected") logger.Log(0, "database successfully connected")
logic.SetJWTSecret()
if err = logic.AddServerIDIfNotPresent(); err != nil { if err = logic.AddServerIDIfNotPresent(); err != nil {
logger.Log(1, "failed to save server ID") logger.Log(1, "failed to save server ID")
} }
logic.SetJWTSecret()
if err = pro.InitializeGroups(); err != nil {
logger.Log(0, "could not initialize default user group, \"*\"")
}
err = logic.TimerCheckpoint() err = logic.TimerCheckpoint()
if err != nil { if err != nil {
logger.Log(1, "Timer error occurred: ", err.Error()) logger.Log(1, "Timer error occurred: ", err.Error())
} }
logic.EnterpriseCheck()
var authProvider = auth.InitializeAuthProvider() var authProvider = auth.InitializeAuthProvider()
if authProvider != "" { if authProvider != "" {
logger.Log(0, "OAuth provider,", authProvider+",", "initialized") logger.Log(0, "OAuth provider,", authProvider+",", "initialized")

12
main_ee.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build ee
// +build ee
package main
import (
"github.com/gravitl/netmaker/ee"
)
func init() {
ee.InitEE()
}

View File

@@ -13,4 +13,5 @@ type ExtClient struct {
IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"`
LastModified int64 `json:"lastmodified" bson:"lastmodified"` LastModified int64 `json:"lastmodified" bson:"lastmodified"`
Enabled bool `json:"enabled" bson:"enabled"` Enabled bool `json:"enabled" bson:"enabled"`
OwnerID string `json:"ownerid" bson:"ownerid"`
} }

45
models/metrics.go Normal file
View File

@@ -0,0 +1,45 @@
package models
import "time"
// Metrics - metrics struct
type Metrics struct {
Network string `json:"network" bson:"network" yaml:"network"`
NodeID string `json:"node_id" bson:"node_id" yaml:"node_id"`
NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"`
IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"`
Connectivity map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"`
}
// Metric - holds a metric for data between nodes
type Metric struct {
NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"`
IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"`
Uptime int64 `json:"uptime" bson:"uptime" yaml:"uptime"`
TotalTime int64 `json:"totaltime" bson:"totaltime" yaml:"totaltime"`
Latency int64 `json:"latency" bson:"latency" yaml:"latency"`
TotalReceived int64 `json:"totalreceived" bson:"totalreceived" yaml:"totalreceived"`
TotalSent int64 `json:"totalsent" bson:"totalsent" yaml:"totalsent"`
ActualUptime time.Duration `json:"actualuptime" bson:"actualuptime" yaml:"actualuptime"`
PercentUp float64 `json:"percentup" bson:"percentup" yaml:"percentup"`
Connected bool `json:"connected" bson:"connected" yaml:"connected"`
}
// IDandAddr - struct to hold ID and primary Address
type IDandAddr struct {
ID string `json:"id" bson:"id" yaml:"id"`
Address string `json:"address" bson:"address" yaml:"address"`
Name string `json:"name" bson:"name" yaml:"name"`
IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"`
}
// PeerMap - peer map for ids and addresses in metrics
type PeerMap map[string]IDandAddr
// MetricsMap - map for holding multiple metrics in memory
type MetricsMap map[string]Metrics
// NetworkMetrics - metrics model for all nodes in a network
type NetworkMetrics struct {
Nodes MetricsMap `json:"nodes" bson:"nodes" yaml:"nodes"`
}

View File

@@ -9,6 +9,7 @@ type PeerUpdate struct {
ServerAddrs []ServerAddr `json:"serveraddrs" bson:"serveraddrs" yaml:"serveraddrs"` ServerAddrs []ServerAddr `json:"serveraddrs" bson:"serveraddrs" yaml:"serveraddrs"`
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
DNS string `json:"dns" bson:"dns" yaml:"dns"` DNS string `json:"dns" bson:"dns" yaml:"dns"`
PeerIDs PeerMap `json:"peerids" bson:"peerids" yaml:"peerids"`
} }
// KeyUpdate - key update struct // KeyUpdate - key update struct

View File

@@ -231,6 +231,8 @@ var SMALL_NAMES = []string{
"cold", "cold",
} }
var logoString = retrieveLogo()
// GenerateNodeName - generates a random node name // GenerateNodeName - generates a random node name
func GenerateNodeName() string { func GenerateNodeName() string {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
@@ -239,6 +241,15 @@ func GenerateNodeName() string {
// RetrieveLogo - retrieves the ascii art logo for Netmaker // RetrieveLogo - retrieves the ascii art logo for Netmaker
func RetrieveLogo() string { func RetrieveLogo() string {
return logoString
}
// SetLogo - sets the logo ascii art
func SetLogo(logo string) {
logoString = logo
}
func retrieveLogo() string {
return ` return `
__ __ ______ ______ __ __ ______ __ __ ______ ______ __ __ ______ ______ __ __ ______ __ __ ______ ______
/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \ /\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \

View File

@@ -2,6 +2,8 @@ package models
import ( import (
"time" "time"
"github.com/gravitl/netmaker/models/promodels"
) )
// Network Struct - contains info for a given unique network // Network Struct - contains info for a given unique network
@@ -29,6 +31,7 @@ type Network struct {
DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"`
DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"`
DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"` DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"`
ProSettings *promodels.ProNetwork `json:"prosettings,omitempty" bson:"prosettings,omitempty" yaml:"prosettings,omitempty"`
} }
// SaveData - sensitive fields of a network that should be kept the same // SaveData - sensitive fields of a network that should be kept the same

View File

@@ -101,6 +101,9 @@ type Node struct {
FirewallInUse string `json:"firewallinuse" bson:"firewallinuse" yaml:"firewallinuse"` FirewallInUse string `json:"firewallinuse" bson:"firewallinuse" yaml:"firewallinuse"`
InternetGateway string `json:"internetgateway" bson:"internetgateway" yaml:"internetgateway"` InternetGateway string `json:"internetgateway" bson:"internetgateway" yaml:"internetgateway"`
Connected string `json:"connected" bson:"connected" yaml:"connected" validate:"checkyesorno"` Connected string `json:"connected" bson:"connected" yaml:"connected" validate:"checkyesorno"`
// == PRO ==
DefaultACL string `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"`
OwnerID string `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"`
} }
// NodesArray - used for node sorting // NodesArray - used for node sorting
@@ -139,6 +142,13 @@ func (node *Node) SetDefaultConnected() {
} }
} }
// Node.SetDefaultACL
func (node *Node) SetDefaultACL() {
if node.DefaultACL == "" {
node.DefaultACL = "yes"
}
}
// Node.SetDefaultMTU - sets default MTU of a node // Node.SetDefaultMTU - sets default MTU of a node
func (node *Node) SetDefaultMTU() { func (node *Node) SetDefaultMTU() {
if node.MTU == 0 { if node.MTU == 0 {
@@ -438,6 +448,10 @@ func (newNode *Node) Fill(currentNode *Node) { // TODO add new field for nftable
if newNode.Connected == "" { if newNode.Connected == "" {
newNode.Connected = currentNode.Connected newNode.Connected = currentNode.Connected
} }
if newNode.DefaultACL == "" {
newNode.DefaultACL = currentNode.DefaultACL
}
newNode.TrafficKeys = currentNode.TrafficKeys newNode.TrafficKeys = currentNode.TrafficKeys
} }
@@ -469,3 +483,15 @@ func (node *Node) NameInNodeCharSet() bool {
} }
return true return true
} }
// == PRO ==
// Node.DoesACLAllow - checks if default ACL on node is "yes"
func (node *Node) DoesACLAllow() bool {
return node.DefaultACL == "yes"
}
// Node.DoesACLDeny - checks if default ACL on node is "no"
func (node *Node) DoesACLDeny() bool {
return node.DefaultACL == "no"
}

View File

@@ -0,0 +1,37 @@
package promodels
// NetworkUserID - ID field for a network user
type NetworkUserID string
// NetworkUser - holds fields for a network user
type NetworkUser struct {
AccessLevel int `json:"accesslevel" bson:"accesslevel" yaml:"accesslevel"`
ClientLimit int `json:"clientlimit" bson:"clientlimit" yaml:"clientlimit"`
NodeLimit int `json:"nodelimit" bson:"nodelimit" yaml:"nodelimit"`
ID NetworkUserID `json:"id" bson:"id" yaml:"id"`
Clients []string `json:"clients" bson:"clients" yaml:"clients"`
Nodes []string `json:"nodes" bson:"nodes" yaml:"nodes"`
}
// NetworkUserMap - map of network users
type NetworkUserMap map[NetworkUserID]NetworkUser
// NetworkUserMap.Delete - deletes a network user struct from a given map in memory
func (N NetworkUserMap) Delete(ID NetworkUserID) {
delete(N, ID)
}
// NetworkUserMap.Add - adds a network user struct to given network user map in memory
func (N NetworkUserMap) Add(User *NetworkUser) {
N[User.ID] = *User
}
// SetDefaults - adds the defaults to network user
func (U *NetworkUser) SetDefaults() {
if U.Clients == nil {
U.Clients = []string{}
}
if U.Nodes == nil {
U.Nodes = []string{}
}
}

19
models/promodels/pro.go Normal file
View File

@@ -0,0 +1,19 @@
package promodels
// ProNetwork - struct for all pro Network related fields
type ProNetwork struct {
DefaultAccessLevel int `json:"defaultaccesslevel" bson:"defaultaccesslevel" yaml:"defaultaccesslevel"`
DefaultUserNodeLimit int `json:"defaultusernodelimit" bson:"defaultusernodelimit" yaml:"defaultusernodelimit"`
DefaultUserClientLimit int `json:"defaultuserclientlimit" bson:"defaultuserclientlimit" yaml:"defaultuserclientlimit"`
AllowedUsers []string `json:"allowedusers" bson:"allowedusers" yaml:"allowedusers"`
AllowedGroups []string `json:"allowedgroups" bson:"allowedgroups" yaml:"allowedgroups"`
}
// LoginMsg - login message struct for nodes to join via SSO login
// Need to change mac to public key for tighter verification ?
type LoginMsg struct {
Mac string `json:"mac"`
Network string `json:"network"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
}

View File

@@ -0,0 +1,9 @@
package promodels
type Void struct{}
// UserGroupName - string representing a group name
type UserGroupName string
// UserGroups - groups type, holds group names
type UserGroups map[UserGroupName]Void

View File

@@ -2,6 +2,7 @@ package models
import ( import (
"strings" "strings"
"time"
jwt "github.com/golang-jwt/jwt/v4" jwt "github.com/golang-jwt/jwt/v4"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -28,6 +29,7 @@ type User struct {
Password string `json:"password" bson:"password" validate:"required,min=5"` Password string `json:"password" bson:"password" validate:"required,min=5"`
Networks []string `json:"networks" bson:"networks"` Networks []string `json:"networks" bson:"networks"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` IsAdmin bool `json:"isadmin" bson:"isadmin"`
Groups []string `json:"groups" bson:"groups" yaml:"groups"`
} }
// ReturnUser - return user struct // ReturnUser - return user struct
@@ -35,6 +37,7 @@ type ReturnUser struct {
UserName string `json:"username" bson:"username"` UserName string `json:"username" bson:"username"`
Networks []string `json:"networks" bson:"networks"` Networks []string `json:"networks" bson:"networks"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` IsAdmin bool `json:"isadmin" bson:"isadmin"`
Groups []string `json:"groups" bson:"groups"`
} }
// UserAuthParams - user auth params struct // UserAuthParams - user auth params struct
@@ -48,6 +51,7 @@ type UserClaims struct {
IsAdmin bool IsAdmin bool
UserName string UserName string
Networks []string Networks []string
Groups []string
jwt.RegisteredClaims jwt.RegisteredClaims
} }
@@ -95,10 +99,11 @@ type SuccessResponse struct {
// AccessKey - access key struct // AccessKey - access key struct
type AccessKey struct { type AccessKey struct {
Name string `json:"name" bson:"name" validate:"omitempty,max=20"` Name string `json:"name" bson:"name" validate:"omitempty,max=345"`
Value string `json:"value" bson:"value" validate:"omitempty,alphanum,max=16"` Value string `json:"value" bson:"value" validate:"omitempty,alphanum,max=16"`
AccessString string `json:"accessstring" bson:"accessstring"` AccessString string `json:"accessstring" bson:"accessstring"`
Uses int `json:"uses" bson:"uses" validate:"numeric,min=0"` Uses int `json:"uses" bson:"uses" validate:"numeric,min=0"`
Expiration *time.Time `json:"expiration" bson:"expiration"`
} }
// DisplayKey - what is displayed for key // DisplayKey - what is displayed for key
@@ -200,6 +205,7 @@ type NodeGet struct {
Node Node `json:"node" bson:"node" yaml:"node"` Node Node `json:"node" bson:"node" yaml:"node"`
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"` ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"`
PeerIDs PeerMap `json:"peerids,omitempty" bson:"peerids,omitempty" yaml:"peerids,omitempty"`
} }
// ServerConfig - struct for dealing with the server information for a netclient // ServerConfig - struct for dealing with the server information for a netclient

View File

@@ -2,6 +2,8 @@ package mq
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time"
mqtt "github.com/eclipse/paho.mqtt.golang" mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
@@ -9,6 +11,7 @@ import (
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg"
) )
// DefaultHandler default message queue handler -- NOT USED // DefaultHandler default message queue handler -- NOT USED
@@ -93,6 +96,50 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
}() }()
} }
// UpdateMetrics message Handler -- handles updates from client nodes for metrics
func UpdateMetrics(client mqtt.Client, msg mqtt.Message) {
if logic.Is_EE {
go func() {
id, err := getID(msg.Topic())
if err != nil {
logger.Log(1, "error getting node.ID sent on ", msg.Topic(), err.Error())
return
}
currentNode, err := logic.GetNodeByID(id)
if err != nil {
logger.Log(1, "error getting node ", id, err.Error())
return
}
decrypted, decryptErr := decryptMsg(&currentNode, msg.Payload())
if decryptErr != nil {
logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error())
return
}
var newMetrics models.Metrics
if err := json.Unmarshal(decrypted, &newMetrics); err != nil {
logger.Log(1, "error unmarshaling payload ", err.Error())
return
}
updateNodeMetrics(&currentNode, &newMetrics)
if err = logic.UpdateMetrics(id, &newMetrics); err != nil {
logger.Log(1, "faield to update node metrics", id, currentNode.Name, err.Error())
return
}
if servercfg.IsMetricsExporter() {
if err := pushMetricsToExporter(newMetrics); err != nil {
logger.Log(2, fmt.Sprintf("failed to push node: [%s] metrics to exporter, err: %v",
currentNode.Name, err))
}
}
logger.Log(1, "updated node metrics", id, currentNode.Name)
}()
}
}
// ClientPeerUpdate message handler -- handles updating peers after signal from client nodes // ClientPeerUpdate message handler -- handles updating peers after signal from client nodes
func ClientPeerUpdate(client mqtt.Client, msg mqtt.Message) { func ClientPeerUpdate(client mqtt.Client, msg mqtt.Message) {
go func() { go func() {
@@ -146,3 +193,46 @@ func updateNodePeers(currentNode *models.Node) {
} }
} }
} }
func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
oldMetrics, err := logic.GetMetrics(currentNode.ID)
if err != nil {
logger.Log(1, "error finding old metrics for node", currentNode.ID, currentNode.Name)
return
}
var attachedClients []models.ExtClient
if currentNode.IsIngressGateway == "yes" {
clients, err := logic.GetExtClientsByID(currentNode.ID, currentNode.Network)
if err == nil {
attachedClients = clients
}
}
if len(attachedClients) > 0 {
// associate ext clients with IDs
for i := range attachedClients {
extMetric := newMetrics.Connectivity[attachedClients[i].PublicKey]
delete(newMetrics.Connectivity, attachedClients[i].PublicKey)
if extMetric.Connected { // add ext client metrics
newMetrics.Connectivity[attachedClients[i].ClientID] = extMetric
}
}
}
// run through metrics for each peer
for k := range newMetrics.Connectivity {
currMetric := newMetrics.Connectivity[k]
oldMetric := oldMetrics.Connectivity[k]
currMetric.TotalTime += oldMetric.TotalTime
currMetric.Uptime += oldMetric.Uptime // get the total uptime for this connection
currMetric.PercentUp = 100.0 * (float64(currMetric.Uptime) / float64(currMetric.TotalTime))
totalUpMinutes := currMetric.Uptime * 5
currMetric.ActualUptime = time.Duration(totalUpMinutes) * time.Minute
delete(oldMetrics.Connectivity, k) // remove from old data
newMetrics.Connectivity[k] = currMetric
}
for k := range oldMetrics.Connectivity { // cleanup any left over data, self healing
delete(newMetrics.Connectivity, k)
}
}

View File

@@ -51,6 +51,10 @@ func SetupMQTT() {
client.Disconnect(240) client.Disconnect(240)
logger.Log(0, "node client subscription failed") logger.Log(0, "node client subscription failed")
} }
if token := client.Subscribe("metrics/#", 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
client.Disconnect(240)
logger.Log(0, "node metrics subscription failed")
}
opts.SetOrderMatters(true) opts.SetOrderMatters(true)
opts.SetResumeSubs(true) opts.SetResumeSubs(true)

View File

@@ -2,11 +2,13 @@ package mq
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/metrics"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/serverctl" "github.com/gravitl/netmaker/serverctl"
@@ -105,6 +107,11 @@ func NodeUpdate(node *models.Node) error {
// sendPeers - retrieve networks, send peer ports to all peers // sendPeers - retrieve networks, send peer ports to all peers
func sendPeers() { func sendPeers() {
networks, err := logic.GetNetworks()
if err != nil {
logger.Log(1, "error retrieving networks for keepalive", err.Error())
}
var force bool var force bool
peer_force_send++ peer_force_send++
if peer_force_send == 5 { if peer_force_send == 5 {
@@ -121,10 +128,8 @@ func sendPeers() {
if err != nil { if err != nil {
logger.Log(3, "error occurred on timer,", err.Error()) logger.Log(3, "error occurred on timer,", err.Error())
} }
}
networks, err := logic.GetNetworks() collectServerMetrics(networks[:])
if err != nil && !database.IsEmptyRecord(err) {
logger.Log(1, "error retrieving networks for keepalive", err.Error())
} }
for _, network := range networks { for _, network := range networks {
@@ -176,3 +181,66 @@ func ServerStartNotify() error {
} }
return nil return nil
} }
// function to collect and store metrics for server nodes
func collectServerMetrics(networks []models.Network) {
if !logic.Is_EE {
return
}
if len(networks) > 0 {
for i := range networks {
currentNetworkNodes, err := logic.GetNetworkNodes(networks[i].NetID)
if err != nil {
continue
}
currentServerNodes := logic.GetServerNodes(networks[i].NetID)
if len(currentServerNodes) > 0 {
for i := range currentServerNodes {
if logic.IsLocalServer(&currentServerNodes[i]) {
serverMetrics := logic.CollectServerMetrics(currentServerNodes[i].ID, currentNetworkNodes)
if serverMetrics != nil {
serverMetrics.NodeName = currentServerNodes[i].Name
serverMetrics.NodeID = currentServerNodes[i].ID
serverMetrics.IsServer = "yes"
serverMetrics.Network = currentServerNodes[i].Network
if err = metrics.GetExchangedBytesForNode(&currentServerNodes[i], serverMetrics); err != nil {
logger.Log(1, fmt.Sprintf("failed to update exchanged bytes info for server: %s, err: %v",
currentServerNodes[i].Name, err))
}
updateNodeMetrics(&currentServerNodes[i], serverMetrics)
if err = logic.UpdateMetrics(currentServerNodes[i].ID, serverMetrics); err != nil {
logger.Log(1, "failed to update metrics for server node", currentServerNodes[i].ID)
}
if servercfg.IsMetricsExporter() {
logger.Log(2, "-------------> SERVER METRICS: ", fmt.Sprintf("%+v", serverMetrics))
if err := pushMetricsToExporter(*serverMetrics); err != nil {
logger.Log(2, "failed to push server metrics to exporter: ", err.Error())
}
}
}
}
}
}
}
}
}
func pushMetricsToExporter(metrics models.Metrics) error {
logger.Log(2, "----> Pushing metrics to exporter")
data, err := json.Marshal(metrics)
if err != nil {
return errors.New("failed to marshal metrics: " + err.Error())
}
if token := mqclient.Publish("metrics_exporter", 0, true, data); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
var err error
if token.Error() == nil {
err = errors.New("connection timeout")
} else {
err = token.Error()
}
return err
}
return nil
}

View File

@@ -133,6 +133,20 @@ func GetFlags(hostname string) []cli.Flag {
Value: "", Value: "",
Usage: "Access Token for signing up machine with Netmaker server during initial 'add'.", Usage: "Access Token for signing up machine with Netmaker server during initial 'add'.",
}, },
&cli.StringFlag{
Name: "login-server",
Aliases: []string{"l"},
EnvVars: []string{"LOGIN_SERVER"},
Value: "",
Usage: "Login server URL, use it for the Single Sign-on along with the network parameter",
},
&cli.StringFlag{
Name: "user",
Aliases: []string{"u"},
EnvVars: []string{"USER_NAME"},
Value: "",
Usage: "User name provided upon joins if joining over basic auth is desired.",
},
&cli.StringFlag{ &cli.StringFlag{
Name: "localrange", Name: "localrange",
EnvVars: []string{"NETCLIENT_LOCALRANGE"}, EnvVars: []string{"NETCLIENT_LOCALRANGE"},

View File

@@ -3,6 +3,7 @@ package command
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"strings" "strings"
@@ -18,6 +19,28 @@ import (
func Join(cfg *config.ClientConfig, privateKey string) error { func Join(cfg *config.ClientConfig, privateKey string) error {
var err error var err error
//join network //join network
if cfg.SsoServer != "" {
// User wants to get access key from the OIDC server
// Do that before the Joining Network flow by performing the end point auth flow
// if performed successfully an access key is obtained from the server and then we
// proceed with the usual flow 'pretending' that user is feeded us with an access token
if len(cfg.Network) == 0 || cfg.Network == "all" {
return fmt.Errorf("no network provided. Specify network with \"-n <net name>\"")
}
logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer)
err = functions.JoinViaSSo(cfg, privateKey)
if err != nil {
logger.Log(0, "Join via OIDC failed: ", err.Error())
return err
}
if cfg.AccessKey == "" {
return errors.New("failed to get access key")
}
logger.Log(1, "Got an access key to ", cfg.Network, " via:", cfg.SsoServer)
}
logger.Log(1, "Joining network: ", cfg.Network)
err = functions.JoinNetwork(cfg, privateKey) err = functions.JoinNetwork(cfg, privateKey)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "ALREADY_INSTALLED") { if !strings.Contains(err.Error(), "ALREADY_INSTALLED") {

View File

@@ -32,6 +32,7 @@ type ClientConfig struct {
OperatingSystem string `yaml:"operatingsystem"` OperatingSystem string `yaml:"operatingsystem"`
AccessKey string `yaml:"accesskey"` AccessKey string `yaml:"accesskey"`
PublicIPService string `yaml:"publicipservice"` PublicIPService string `yaml:"publicipservice"`
SsoServer string `yaml:"sso"`
} }
// RegisterRequest - struct for registation with netmaker server // RegisterRequest - struct for registation with netmaker server
@@ -239,6 +240,11 @@ func GetCLIConfig(c *cli.Context) (ClientConfig, string, error) {
if c.String("apiserver") != "" { if c.String("apiserver") != "" {
cfg.Server.API = c.String("apiserver") cfg.Server.API = c.String("apiserver")
} }
} else if c.String("login-server") != "" {
cfg.SsoServer = c.String("login-server")
cfg.Network = c.String("network")
cfg.Node.Network = c.String("network")
global_settings.User = c.String("user")
} else { } else {
cfg.AccessKey = c.String("key") cfg.AccessKey = c.String("key")
cfg.Network = c.String("network") cfg.Network = c.String("network")

View File

@@ -8,21 +8,175 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"os"
"os/signal"
"runtime" "runtime"
"strings"
"syscall"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/auth"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"
"github.com/gravitl/netmaker/netclient/daemon" "github.com/gravitl/netmaker/netclient/daemon"
"github.com/gravitl/netmaker/netclient/global_settings"
"github.com/gravitl/netmaker/netclient/local" "github.com/gravitl/netmaker/netclient/local"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/netclient/wireguard" "github.com/gravitl/netmaker/netclient/wireguard"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"golang.org/x/term"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// JoinViaSso - Handles the Single Sign-On flow on the end point VPN client side
// Contacts the server provided by the user (and thus specified in cfg.SsoServer)
// get the URL to authenticate with a provider and shows the user the URL.
// Then waits for user to authenticate with the URL.
// Upon user successful auth flow finished - server should return access token to the requested network
// Otherwise the error message is sent which can be displayed to the user
func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
// User must tell us which network he is joining
if cfg.Node.Network == "" {
return errors.New("no network provided")
}
// Prepare a channel for interrupt
// Channel to listen for interrupt signal to terminate gracefully
interrupt := make(chan os.Signal, 1)
// Notify the interrupt channel for SIGINT
signal.Notify(interrupt, os.Interrupt)
// Web Socket is used, construct the URL accordingly ...
socketUrl := fmt.Sprintf("wss://%s/api/oauth/node-handler", cfg.SsoServer)
// Dial the netmaker server controller
conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil)
if err != nil {
logger.Log(0, fmt.Sprintf("error connecting to %s : %s", cfg.Server.API, err.Error()))
return err
}
// Don't forget to close when finished
defer conn.Close()
// Find and set node MacAddress
if cfg.Node.MacAddress == "" {
macs, err := ncutils.GetMacAddr()
if err != nil {
//if macaddress can't be found set to random string
cfg.Node.MacAddress = ncutils.MakeRandomString(18)
} else {
cfg.Node.MacAddress = macs[0]
}
}
var loginMsg promodels.LoginMsg
loginMsg.Mac = cfg.Node.MacAddress
loginMsg.Network = cfg.Node.Network
if global_settings.User != "" {
fmt.Printf("Continuing with user, %s.\nPlease input password:\n", global_settings.User)
pass, err := term.ReadPassword(int(syscall.Stdin))
if err != nil || string(pass) == "" {
logger.FatalLog("no password provided, exiting")
}
loginMsg.User = global_settings.User
loginMsg.Password = string(pass)
}
msgTx, err := json.Marshal(loginMsg)
if err != nil {
logger.Log(0, fmt.Sprintf("failed to marshal message %+v", loginMsg))
return err
}
err = conn.WriteMessage(websocket.TextMessage, []byte(msgTx))
if err != nil {
logger.FatalLog("Error during writing to websocket:", err.Error())
return err
}
// if user provided, server will handle authentication
if loginMsg.User == "" {
// We are going to get instructions on how to authenticate
// Wait to receive something from server
_, msg, err := conn.ReadMessage()
if err != nil {
log.Println("Error in receive:", err)
return err
}
// Print message from the netmaker controller to the user
fmt.Printf("Please visit:\n %s \n to authenticate", string(msg))
}
// Now the user is authenticating and we need to block until received
// An answer from the server.
// Server waits ~5 min - If takes too long timeout will be triggered by the server
done := make(chan struct{})
defer close(done)
// Following code will run in a separate go routine
// it reads a message from the server which either contains 'AccessToken:' string or not
// if not - then it contains an Error to display.
// if yes - then AccessToken is to be used to proceed joining the network
go func() {
for {
msgType, msg, err := conn.ReadMessage()
if err != nil {
// Error reading a message from the server
if !strings.Contains(err.Error(), "normal") {
logger.Log(0, "read:", err.Error())
}
return
}
if msgType == websocket.CloseMessage {
logger.Log(1, "received close message from server")
done <- struct{}{}
return
}
// Get the access token from the response
if strings.Contains(string(msg), "AccessToken: ") {
// Access was granted
rxToken := strings.TrimPrefix(string(msg), "AccessToken: ")
accesstoken, err := config.ParseAccessToken(rxToken)
if err != nil {
logger.Log(0, fmt.Sprintf("failed to parse received access token %s,err=%s\n", accesstoken, err.Error()))
return
}
cfg.Network = accesstoken.ClientConfig.Network
cfg.Node.Network = accesstoken.ClientConfig.Network
cfg.AccessKey = accesstoken.ClientConfig.Key
cfg.Node.LocalRange = accesstoken.ClientConfig.LocalRange
//cfg.Server.Server = accesstoken.ServerConfig.Server
cfg.Server.API = accesstoken.APIConnString
} else {
// Access was not granted. Display a message from the server
logger.Log(0, "Message from server:", string(msg))
cfg.AccessKey = ""
return
}
}
}()
for {
select {
case <-done:
logger.Log(1, "finished")
return nil
case <-interrupt:
logger.Log(0, "interrupt received, closing connection")
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
logger.Log(0, "write close:", err.Error())
return err
}
return nil
}
}
}
// JoinNetwork - helps a client join a network // JoinNetwork - helps a client join a network
func JoinNetwork(cfg *config.ClientConfig, privateKey string) error { func JoinNetwork(cfg *config.ClientConfig, privateKey string) error {
if cfg.Node.Network == "" { if cfg.Node.Network == "" {

View File

@@ -5,7 +5,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"os" "os"
"strconv" "strconv"
"sync" "sync"
@@ -13,6 +15,7 @@ import (
"github.com/cloverstd/tcping/ping" "github.com/cloverstd/tcping/ping"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/metrics"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/auth"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"
@@ -20,13 +23,16 @@ import (
"github.com/gravitl/netmaker/tls" "github.com/gravitl/netmaker/tls"
) )
var metricsCache = new(sync.Map)
// Checkin -- go routine that checks for public or local ip changes, publishes changes // Checkin -- go routine that checks for public or local ip changes, publishes changes
// //
// if there are no updates, simply "pings" the server as a checkin // if there are no updates, simply "pings" the server as a checkin
func Checkin(ctx context.Context, wg *sync.WaitGroup) { func Checkin(ctx context.Context, wg *sync.WaitGroup) {
logger.Log(2, "starting checkin goroutine") logger.Log(2, "starting checkin goroutine")
defer wg.Done() defer wg.Done()
checkin() currentRun := 0
checkin(currentRun)
ticker := time.NewTicker(time.Second * 60) ticker := time.NewTicker(time.Second * 60)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -36,12 +42,16 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup) {
return return
//delay should be configuraable -> use cfg.Node.NetworkSettings.DefaultCheckInInterval ?? //delay should be configuraable -> use cfg.Node.NetworkSettings.DefaultCheckInInterval ??
case <-ticker.C: case <-ticker.C:
checkin() currentRun++
checkin(currentRun)
if currentRun >= 5 {
currentRun = 0
}
} }
} }
} }
func checkin() { func checkin(currentRun int) {
networks, _ := ncutils.GetSystemNetworks() networks, _ := ncutils.GetSystemNetworks()
logger.Log(3, "checkin with server(s) for all networks") logger.Log(3, "checkin with server(s) for all networks")
for _, network := range networks { for _, network := range networks {
@@ -104,6 +114,10 @@ func checkin() {
} }
Hello(&nodeCfg) Hello(&nodeCfg)
checkCertExpiry(&nodeCfg) checkCertExpiry(&nodeCfg)
if currentRun >= 5 {
logger.Log(0, "collecting metrics for node", nodeCfg.Node.Name)
publishMetrics(&nodeCfg)
}
} }
} }
@@ -146,6 +160,78 @@ func Hello(nodeCfg *config.ClientConfig) {
} }
} }
// publishMetrics - publishes the metrics of a given nodecfg
func publishMetrics(nodeCfg *config.ClientConfig) {
token, err := Authenticate(nodeCfg)
if err != nil {
logger.Log(1, "failed to authenticate when publishing metrics", err.Error())
return
}
url := fmt.Sprintf("https://%s/api/nodes/%s/%s", nodeCfg.Server.API, nodeCfg.Network, nodeCfg.Node.ID)
response, err := API("", http.MethodGet, url, token)
if err != nil {
logger.Log(1, "failed to read from server during metrics publish", err.Error())
return
}
if response.StatusCode != http.StatusOK {
bytes, err := io.ReadAll(response.Body)
if err != nil {
fmt.Println(err)
}
logger.Log(0, fmt.Sprintf("%s %s", string(bytes), err.Error()))
return
}
defer response.Body.Close()
var nodeGET models.NodeGet
if err := json.NewDecoder(response.Body).Decode(&nodeGET); err != nil {
logger.Log(0, "failed to decode node when running metrics update", err.Error())
return
}
metrics, err := metrics.Collect(nodeCfg.Node.Interface, nodeGET.PeerIDs)
if err != nil {
logger.Log(0, "failed metric collection for node", nodeCfg.Node.Name, err.Error())
}
metrics.Network = nodeCfg.Node.Network
metrics.NodeName = nodeCfg.Node.Name
metrics.NodeID = nodeCfg.Node.ID
metrics.IsServer = "no"
data, err := json.Marshal(metrics)
if err != nil {
logger.Log(0, "something went wrong when marshalling metrics data for node", nodeCfg.Node.Name, err.Error())
}
if err = publish(nodeCfg, fmt.Sprintf("metrics/%s", nodeCfg.Node.ID), data, 1); err != nil {
logger.Log(0, "error occurred during publishing of metrics on node", nodeCfg.Node.Name, err.Error())
logger.Log(0, "aggregating metrics locally until broker connection re-established")
val, ok := metricsCache.Load(nodeCfg.Node.ID)
if !ok {
metricsCache.Store(nodeCfg.Node.ID, data)
} else {
var oldMetrics models.Metrics
err = json.Unmarshal(val.([]byte), &oldMetrics)
if err == nil {
for k := range oldMetrics.Connectivity {
currentMetric := metrics.Connectivity[k]
if currentMetric.Latency == 0 {
currentMetric.Latency = oldMetrics.Connectivity[k].Latency
}
currentMetric.Uptime += oldMetrics.Connectivity[k].Uptime
currentMetric.TotalTime += oldMetrics.Connectivity[k].TotalTime
metrics.Connectivity[k] = currentMetric
}
newData, err := json.Marshal(metrics)
if err == nil {
metricsCache.Store(nodeCfg.Node.ID, newData)
}
}
}
} else {
metricsCache.Delete(nodeCfg.Node.ID)
logger.Log(0, "published metrics for node", nodeCfg.Node.Name)
}
}
// node cfg is required in order to fetch the traffic keys of that node for encryption // node cfg is required in order to fetch the traffic keys of that node for encryption
func publish(nodeCfg *config.ClientConfig, dest string, msg []byte, qos byte) error { func publish(nodeCfg *config.ClientConfig, dest string, msg []byte, qos byte) error {
// setup the keys // setup the keys

View File

@@ -4,3 +4,6 @@ package global_settings
// PublicIPServices - the list of user-specified IP services to use to obtain the node's public IP // PublicIPServices - the list of user-specified IP services to use to obtain the node's public IP
var PublicIPServices map[string]string = make(map[string]string) var PublicIPServices map[string]string = make(map[string]string)
// User - holds a user string for joins when using basic auth
var User string

View File

@@ -187,6 +187,8 @@ EOF
echo "visit https://dashboard.$NETMAKER_BASE_DOMAIN to log in" echo "visit https://dashboard.$NETMAKER_BASE_DOMAIN to log in"
echo "visit https://grafana.$NETMAKER_BASE_DOMAIN to view metrics on grafana dashboard"
echo "visit https://prometheus.$NETMAKER_BASE_DOMAIN to view metrics on prometheus"
sleep 7 sleep 7
setup_mesh() {( set -e setup_mesh() {( set -e

View File

@@ -281,6 +281,21 @@ func IsRestBackend() bool {
return isrest return isrest
} }
// IsMetricsExporter - checks if metrics exporter is on or off
func IsMetricsExporter() bool {
export := false
if os.Getenv("METRICS_EXPORTER") != "" {
if os.Getenv("METRICS_EXPORTER") == "on" {
export = true
}
} else if config.Config.Server.MetricsExporter != "" {
if config.Config.Server.MetricsExporter == "on" {
export = true
}
}
return export
}
// IsAgentBackend - checks if agent backed is on or off // IsAgentBackend - checks if agent backed is on or off
func IsAgentBackend() bool { func IsAgentBackend() bool {
isagent := true isagent := true
@@ -600,3 +615,32 @@ func GetMQServerPort() string {
} }
return port return port
} }
// IsBasicAuthEnabled - checks if basic auth has been configured to be turned off
func IsBasicAuthEnabled() bool {
var enabled = true //default
if os.Getenv("BASIC_AUTH") != "" {
enabled = os.Getenv("BASIC_AUTH") == "yes"
} else if config.Config.Server.BasicAuth != "" {
enabled = config.Config.Server.BasicAuth == "yes"
}
return enabled
}
// GetLicenseKey - retrieves pro license value from env or conf files
func GetLicenseKey() string {
licenseKeyValue := os.Getenv("LICENSE_KEY")
if licenseKeyValue == "" {
licenseKeyValue = config.Config.Server.LicenseValue
}
return licenseKeyValue
}
// GetNetmakerAccountID - get's the associated, Netmaker, account ID to verify ownership
func GetNetmakerAccountID() string {
netmakerAccountID := os.Getenv("NETMAKER_ACCOUNT_ID")
if netmakerAccountID == "" {
netmakerAccountID = config.Config.Server.LicenseValue
}
return netmakerAccountID
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/logic/pro"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
) )
@@ -45,6 +46,9 @@ func InitServerNetclient() error {
logger.Log(1, "failed pull for network", network.NetID, ", on server node", currentServerNode.ID) logger.Log(1, "failed pull for network", network.NetID, ", on server node", currentServerNode.ID)
} }
} }
if err = logic.InitializeNetUsers(&network); err != nil {
logger.Log(0, "something went wrong syncing usrs on network", network.NetID, "-", err.Error())
}
} }
} }
@@ -86,6 +90,14 @@ func SetDefaults() error {
return err return err
} }
if err := setNetworkDefaults(); err != nil {
return err
}
if err := setUserDefaults(); err != nil {
return err
}
return nil return nil
} }
@@ -108,3 +120,42 @@ func setNodeDefaults() error {
} }
return nil return nil
} }
func setNetworkDefaults() error {
// upgraded systems will not have NetworkUsers's set, which is why we need this function
networks, err := logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, net := range networks {
if err = pro.InitializeNetworkUsers(net.NetID); err != nil {
logger.Log(0, "could not initialize NetworkUsers on network", net.NetID)
}
pro.AddProNetDefaults(&net)
_, _, _, _, _, _, err = logic.UpdateNetwork(&net, &net)
if err != nil {
logger.Log(0, "could not set defaults on network", net.NetID)
}
}
return nil
}
func setUserDefaults() error {
users, err := logic.GetUsers()
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, user := range users {
updateUser, err := logic.GetUser(user.UserName)
if err != nil {
logger.Log(0, "could not update user", updateUser.UserName)
}
logic.SetUserDefaults(&updateUser)
copyUser := updateUser
copyUser.Password = ""
if _, err = logic.UpdateUser(copyUser, updateUser); err != nil {
logger.Log(0, "could not update user", updateUser.UserName)
}
}
return nil
}

View File

@@ -11,6 +11,11 @@ func CheckYesOrNo(fl validator.FieldLevel) bool {
return fl.Field().String() == "yes" || fl.Field().String() == "no" return fl.Field().String() == "yes" || fl.Field().String() == "no"
} }
// CheckYesOrNoOrUnset - checks if a field is yes, no or unset
func CheckYesOrNoOrUnset(fl validator.FieldLevel) bool {
return CheckYesOrNo(fl) || fl.Field().String() == "unset"
}
// CheckRegex - check if a struct's field passes regex test // CheckRegex - check if a struct's field passes regex test
func CheckRegex(fl validator.FieldLevel) bool { func CheckRegex(fl validator.FieldLevel) bool {
re := regexp.MustCompile(fl.Param()) re := regexp.MustCompile(fl.Param())