Merge pull request #1266 from gravitl/bugfix_v0.14.5_static_checks

eliminate static check warnings
This commit is contained in:
dcarns
2022-06-30 10:24:09 -04:00
committed by GitHub
7 changed files with 58 additions and 78 deletions

View File

@@ -182,32 +182,32 @@ func TestSecurityCheck(t *testing.T) {
database.InitializeDatabase() database.InitializeDatabase()
os.Setenv("MASTER_KEY", "secretkey") os.Setenv("MASTER_KEY", "secretkey")
t.Run("NoNetwork", func(t *testing.T) { t.Run("NoNetwork", func(t *testing.T) {
err, networks, username := SecurityCheck(false, "", "Bearer secretkey") networks, username, err := SecurityCheck(false, "", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("WithNetwork", func(t *testing.T) { t.Run("WithNetwork", func(t *testing.T) {
err, networks, username := SecurityCheck(false, "skynet", "Bearer secretkey") networks, username, err := SecurityCheck(false, "skynet", "Bearer secretkey")
assert.Nil(t, err) assert.Nil(t, err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadNet", func(t *testing.T) { t.Run("BadNet", func(t *testing.T) {
t.Skip() t.Skip()
err, networks, username := SecurityCheck(false, "badnet", "Bearer secretkey") networks, username, err := SecurityCheck(false, "badnet", "Bearer secretkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)
}) })
t.Run("BadToken", func(t *testing.T) { t.Run("BadToken", func(t *testing.T) {
err, networks, username := SecurityCheck(false, "skynet", "Bearer badkey") networks, username, err := SecurityCheck(false, "skynet", "Bearer badkey")
assert.NotNil(t, err) assert.NotNil(t, err)
t.Log(err) t.Log(err)
t.Log(networks, username) t.Log(networks, username)
}) })
} }
func TestValidateNetworkUpdate(t *testing.T) { func TestValidateNetwork(t *testing.T) {
t.Skip() //t.Skip()
//This functions is not called by anyone //This functions is not called by anyone
//it panics as validation function 'display_name_valid' is not defined //it panics as validation function 'display_name_valid' is not defined
database.InitializeDatabase() database.InitializeDatabase()
@@ -220,23 +220,25 @@ func TestValidateNetworkUpdate(t *testing.T) {
{ {
testname: "InvalidAddress", testname: "InvalidAddress",
network: models.Network{ network: models.Network{
NetID: "skynet",
AddressRange: "10.0.0.256", AddressRange: "10.0.0.256",
}, },
errMessage: "Field validation for 'AddressRange' failed on the 'cidr' tag", errMessage: "Field validation for 'AddressRange' failed on the 'cidr' tag",
}, },
{ //{
testname: "InvalidAddress6", // testname: "InvalidAddress6",
network: models.Network{ // network: models.Network{
AddressRange6: "2607::ag", // NetID: "skynet1",
}, // AddressRange6: "2607::ffff/130",
errMessage: "Field validation for 'AddressRange6' failed on the 'cidr' tag", // },
}, // errMessage: "Field validation for 'AddressRange6' failed on the 'cidr' tag",
//},
{ {
testname: "InvalidNetID", testname: "InvalidNetID",
network: models.Network{ network: models.Network{
NetID: "contains spaces", NetID: "with spaces",
}, },
errMessage: "Field validation for 'NetID' failed on the 'alphanum' tag", errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
}, },
{ {
testname: "NetIDTooLong", testname: "NetIDTooLong",
@@ -248,6 +250,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
{ {
testname: "ListenPortTooLow", testname: "ListenPortTooLow",
network: models.Network{ network: models.Network{
NetID: "skynet",
DefaultListenPort: 1023, DefaultListenPort: 1023,
}, },
errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag", errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
@@ -255,6 +258,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
{ {
testname: "ListenPortTooHigh", testname: "ListenPortTooHigh",
network: models.Network{ network: models.Network{
NetID: "skynet",
DefaultListenPort: 65536, DefaultListenPort: 65536,
}, },
errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag", errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
@@ -262,6 +266,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
{ {
testname: "KeepAliveTooBig", testname: "KeepAliveTooBig",
network: models.Network{ network: models.Network{
NetID: "skynet",
DefaultKeepalive: 1010, DefaultKeepalive: 1010,
}, },
errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag", errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
@@ -269,6 +274,7 @@ func TestValidateNetworkUpdate(t *testing.T) {
{ {
testname: "InvalidLocalRange", testname: "InvalidLocalRange",
network: models.Network{ network: models.Network{
NetID: "skynet",
LocalRange: "192.168.0.1", LocalRange: "192.168.0.1",
}, },
errMessage: "Field validation for 'LocalRange' failed on the 'cidr' tag", errMessage: "Field validation for 'LocalRange' failed on the 'cidr' tag",
@@ -276,8 +282,10 @@ func TestValidateNetworkUpdate(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.testname, func(t *testing.T) { t.Run(tc.testname, func(t *testing.T) {
t.Log(tc.testname)
network := models.Network(tc.network) network := models.Network(tc.network)
err := logic.ValidateNetworkUpdate(network) network.SetDefaults()
err := logic.ValidateNetwork(&network, false)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Contains(t, err.Error(), tc.errMessage) assert.Contains(t, err.Error(), tc.errMessage)
}) })

View File

@@ -31,7 +31,7 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
return return
} }
err, networks, username := SecurityCheck(reqAdmin, params["networkname"], bearerToken) networks, username, err := SecurityCheck(reqAdmin, params["networkname"], bearerToken)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "does not exist") { if strings.Contains(err.Error(), "does not exist") {
errorResponse.Code = http.StatusNotFound errorResponse.Code = http.StatusNotFound
@@ -53,7 +53,7 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
} }
// SecurityCheck - checks token stuff // SecurityCheck - checks token stuff
func SecurityCheck(reqAdmin bool, netname string, token string) (error, []string, string) { func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) {
var hasBearer = true var hasBearer = true
var tokenSplit = strings.Split(token, " ") var tokenSplit = strings.Split(token, " ")
@@ -72,10 +72,10 @@ func SecurityCheck(reqAdmin bool, netname string, token string) (error, []string
userName, networks, isadmin, err := logic.VerifyUserToken(authToken) userName, networks, isadmin, err := logic.VerifyUserToken(authToken)
username = userName username = userName
if err != nil { if err != nil {
return errors.New("error verifying user token"), nil, username return nil, username, errors.New("error verifying user token")
} }
if !isadmin && reqAdmin { if !isadmin && reqAdmin {
return errors.New("you are unauthorized to access this endpoint"), nil, username return nil, username, errors.New("you are unauthorized to access this endpoint")
} }
userNetworks = networks userNetworks = networks
if isadmin { if isadmin {
@@ -83,10 +83,10 @@ func SecurityCheck(reqAdmin bool, netname string, token string) (error, []string
} else { } else {
networkexists, err := functions.NetworkExists(netname) networkexists, err := functions.NetworkExists(netname)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
return err, nil, "" return nil, "", err
} }
if netname != "" && !networkexists { if netname != "" && !networkexists {
return errors.New("this network does not exist"), nil, "" return nil, "", errors.New("this network does not exist")
} }
} }
} else if isMasterAuthenticated { } else if isMasterAuthenticated {
@@ -95,7 +95,7 @@ func SecurityCheck(reqAdmin bool, netname string, token string) (error, []string
if len(userNetworks) == 0 { if len(userNetworks) == 0 {
userNetworks = append(userNetworks, NO_NETWORKS_PRESENT) userNetworks = append(userNetworks, NO_NETWORKS_PRESENT)
} }
return nil, userNetworks, username return userNetworks, username, nil
} }
// Consider a more secure way of setting master key // Consider a more secure way of setting master key

View File

@@ -83,17 +83,6 @@ func rqliteDeleteAllRecords(tableName string) error {
return nil return nil
} }
func rqliteFetchRecord(tableName string, key string) (string, error) {
results, err := FetchRecords(tableName)
if err != nil {
return "", err
}
if results[key] == "" {
return "", errors.New(NO_RECORD)
}
return results[key], nil
}
func rqliteFetchRecords(tableName string) (map[string]string, error) { func rqliteFetchRecords(tableName string) (map[string]string, error) {
row, err := RQliteDatabase.QueryOne("SELECT * FROM " + tableName + " ORDER BY key") row, err := RQliteDatabase.QueryOne("SELECT * FROM " + tableName + " ORDER BY key")
if err != nil { if err != nil {

View File

@@ -229,6 +229,10 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) {
func ValidateUser(user models.User) error { func ValidateUser(user models.User) error {
v := validator.New() v := validator.New()
_ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool {
isgood := user.NameInCharSet()
return isgood
})
err := v.Struct(user) err := v.Struct(user)
if err != nil { if err != nil {

View File

@@ -37,11 +37,11 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
ID: uuid, ID: uuid,
Network: network, Network: network,
MacAddress: macAddress, MacAddress: macAddress,
StandardClaims: jwt.StandardClaims{ RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Netmaker", Issuer: "Netmaker",
Subject: fmt.Sprintf("node|%s", uuid), Subject: fmt.Sprintf("node|%s", uuid),
IssuedAt: time.Now().Unix(), IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: expirationTime.Unix(), ExpiresAt: jwt.NewNumericDate(expirationTime),
}, },
} }
@@ -60,11 +60,11 @@ func CreateUserJWT(username string, networks []string, isadmin bool) (response s
UserName: username, UserName: username,
Networks: networks, Networks: networks,
IsAdmin: isadmin, IsAdmin: isadmin,
StandardClaims: jwt.StandardClaims{ RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Netmaker", Issuer: "Netmaker",
IssuedAt: time.Now().Unix(),
Subject: fmt.Sprintf("user|%s", username), Subject: fmt.Sprintf("user|%s", username),
ExpiresAt: expirationTime.Unix(), IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(expirationTime),
}, },
} }

View File

@@ -622,28 +622,6 @@ func ParseNetwork(value string) (models.Network, error) {
return network, err return network, err
} }
// ValidateNetworkUpdate - checks if network is valid to update
func ValidateNetworkUpdate(network models.Network) error {
v := validator.New()
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
if fl.Field().String() == "" {
return true
}
inCharSet := nameInNetworkCharSet(fl.Field().String())
return inCharSet
})
err := v.Struct(network)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(1, "validator", e.Error())
}
}
return err
}
// KeyUpdate - updates keys on network // KeyUpdate - updates keys on network
func KeyUpdate(netname string) (models.Network, error) { func KeyUpdate(netname string) (models.Network, error) {
err := networkNodesUpdateAction(netname, models.NODE_UPDATE_KEY) err := networkNodesUpdateAction(netname, models.NODE_UPDATE_KEY)
@@ -699,18 +677,6 @@ func networkNodesUpdateAction(networkName string, action string) error {
return nil return nil
} }
func nameInNetworkCharSet(name string) bool {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_."
for _, char := range name {
if !strings.Contains(charset, strings.ToLower(string(char))) {
return false
}
}
return true
}
func deleteInterface(ifacename string, postdown string) error { func deleteInterface(ifacename string, postdown string) error {
var err error var err error
if !ncutils.IsKernel() { if !ncutils.IsKernel() {

View File

@@ -1,6 +1,8 @@
package models package models
import ( import (
"strings"
jwt "github.com/golang-jwt/jwt/v4" jwt "github.com/golang-jwt/jwt/v4"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -17,7 +19,7 @@ type AuthParams struct {
// User struct - struct for Users // User struct - struct for Users
type User struct { type User struct {
UserName string `json:"username" bson:"username" validate:"min=3,max=40,regexp=^(([a-zA-Z,\-,\.]*)|([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,4})){3,40}$"` UserName string `json:"username" bson:"username" validate:"min=3,max=40,in_charset|email"`
Password string `json:"password" bson:"password" validate:"required,min=5"` Password string `json:"password" bson:"password" validate:"required,min=5"`
Networks []string `json:"networks" bson:"networks"` Networks []string `json:"networks" bson:"networks"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` IsAdmin bool `json:"isadmin" bson:"isadmin"`
@@ -25,7 +27,7 @@ type User struct {
// ReturnUser - return user struct // ReturnUser - return user struct
type ReturnUser struct { type ReturnUser struct {
UserName string `json:"username" bson:"username" validate:"min=3,max=40,regexp=^(([a-zA-Z,\-,\.]*)|([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,4})){3,40}$"` UserName string `json:"username" bson:"username"`
Networks []string `json:"networks" bson:"networks"` Networks []string `json:"networks" bson:"networks"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` IsAdmin bool `json:"isadmin" bson:"isadmin"`
} }
@@ -41,7 +43,7 @@ type UserClaims struct {
IsAdmin bool IsAdmin bool
UserName string UserName string
Networks []string Networks []string
jwt.StandardClaims jwt.RegisteredClaims
} }
// SuccessfulUserLoginResponse - successlogin struct // SuccessfulUserLoginResponse - successlogin struct
@@ -56,7 +58,7 @@ type Claims struct {
ID string ID string
MacAddress string MacAddress string
Network string Network string
jwt.StandardClaims jwt.RegisteredClaims
} }
// SuccessfulLoginResponse is struct to send the request response // SuccessfulLoginResponse is struct to send the request response
@@ -206,3 +208,14 @@ type ServerConfig struct {
MQPort string `yaml:"mqport"` MQPort string `yaml:"mqport"`
Server string `yaml:"server"` Server string `yaml:"server"`
} }
// User.NameInCharset - returns if name is in charset below or not
func (user *User) NameInCharSet() bool {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-."
for _, char := range user.UserName {
if !strings.Contains(charset, strings.ToLower(string(char))) {
return false
}
}
return true
}