fixing sso error handling

This commit is contained in:
afeiszli
2022-09-19 15:37:00 -04:00
parent d7517dab1c
commit f63b88db73
10 changed files with 42 additions and 22 deletions

View File

@@ -155,8 +155,11 @@ func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) []
// Listens in /oidc/register/:regKey. // Listens in /oidc/register/:regKey.
func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) { func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) {
logger.Log(1, "RegisterNodeSSO\n") if auth_provider == nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("invalid login attempt"))
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
// machineKeyStr this is not key but state // machineKeyStr this is not key but state
@@ -165,8 +168,7 @@ func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) {
if machineKeyStr == "" { if machineKeyStr == "" {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Wrong params")) w.Write([]byte("invalid login attempt"))
logger.Log(0, "Wrong params ", machineKeyStr)
return return
} }

View File

@@ -58,12 +58,20 @@ func SessionHandler(conn *websocket.Conn) {
defer close(answer) defer close(answer)
defer close(timeout) defer close(timeout)
if _, err = logic.GetNetwork(loginMessage.Network); err != nil {
err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
return
}
if loginMessage.User != "" { // handle basic auth if loginMessage.User != "" { // handle basic auth
// verify that server supports basic auth, then authorize the request with given credentials // verify that server supports basic auth, then authorize the request with given credentials
// check if user is allowed to join via node sso // check if user is allowed to join via node sso
// i.e. user is admin or user has network permissions // i.e. user is admin or user has network permissions
if !servercfg.IsBasicAuthEnabled() { if !servercfg.IsBasicAuthEnabled() {
err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled")) err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
logger.Log(0, "error during message writing:", err.Error()) logger.Log(0, "error during message writing:", err.Error())
} }
@@ -73,7 +81,7 @@ func SessionHandler(conn *websocket.Conn) {
Password: loginMessage.Password, Password: loginMessage.Password,
}) })
if err != nil { if err != nil {
err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User))) err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
logger.Log(0, "error during message writing:", err.Error()) logger.Log(0, "error during message writing:", err.Error())
} }
@@ -81,7 +89,7 @@ func SessionHandler(conn *websocket.Conn) {
} }
user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false) user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
if err != nil { if err != nil {
err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User))) err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
logger.Log(0, "error during message writing:", err.Error()) logger.Log(0, "error during message writing:", err.Error())
} }
@@ -99,6 +107,13 @@ func SessionHandler(conn *websocket.Conn) {
return return
} }
} else { // handle SSO / OAuth } else { // handle SSO / OAuth
if auth_provider == nil {
err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
return
}
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr) redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
err = conn.WriteMessage(messageType, []byte(redirectUrl)) err = conn.WriteMessage(messageType, []byte(redirectUrl))
if err != nil { if err != nil {
@@ -135,7 +150,7 @@ func SessionHandler(conn *websocket.Conn) {
case <-timeout: case <-timeout:
logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network) logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
// the read from req.answerCh has timed out // the read from req.answerCh has timed out
err = conn.WriteMessage(messageType, []byte("Authentication server time out")) err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
logger.Log(0, "Error during message writing:", err.Error()) logger.Log(0, "Error during message writing:", err.Error())
} }

View File

@@ -134,7 +134,7 @@ func Retrieve(filePath string) string {
// FatalLog - exits os after logging // FatalLog - exits os after logging
func FatalLog(message ...string) { func FatalLog(message ...string) {
fmt.Printf("[netmaker] Fatal: %s \n", MakeString(" ", message...)) fmt.Printf("[%s] Fatal: %s \n", program, MakeString(" ", message...))
os.Exit(2) os.Exit(2)
} }

View File

@@ -30,14 +30,13 @@ func Join(cfg *config.ClientConfig, privateKey string) error {
logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer) logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer)
err = functions.JoinViaSSo(cfg, privateKey) err = functions.JoinViaSSo(cfg, privateKey)
if err != nil { if err != nil {
logger.Log(0, "Join via OIDC failed: ", err.Error()) logger.Log(0, "Join failed: ", err.Error())
return err return err
} }
if cfg.AccessKey == "" { if cfg.AccessKey == "" {
return errors.New("failed to get access key") return errors.New("login failed")
} }
logger.Log(1, "Got an access key to ", cfg.Network, " via:", cfg.SsoServer)
} }
logger.Log(1, "Joining network: ", cfg.Network) logger.Log(1, "Joining network: ", cfg.Network)

View File

@@ -28,7 +28,7 @@ func SetupFreebsdDaemon() error {
} }
err = ncutils.Copy(binarypath, EXEC_DIR+"netclient") err = ncutils.Copy(binarypath, EXEC_DIR+"netclient")
if err != nil { if err != nil {
log.Println(err) logger.Log(0, err.Error())
return err return err
} }

View File

@@ -25,7 +25,7 @@ func SetupMacDaemon() error {
} }
err = ncutils.Copy(binarypath, MAC_EXEC_DIR+"netclient") err = ncutils.Copy(binarypath, MAC_EXEC_DIR+"netclient")
if err != nil { if err != nil {
log.Println(err) logger.Log(0, err.Error())
return err return err
} }

View File

@@ -38,7 +38,7 @@ func SetupSystemDDaemon() error {
} }
err = ncutils.Copy(binarypath, EXEC_DIR+"netclient") err = ncutils.Copy(binarypath, EXEC_DIR+"netclient")
if err != nil { if err != nil {
log.Println(err) logger.Log(0, err.Error())
return err return err
} }
@@ -64,7 +64,7 @@ WantedBy=multi-user.target
if !ncutils.FileExists("/etc/systemd/system/netclient.service") { if !ncutils.FileExists("/etc/systemd/system/netclient.service") {
err = os.WriteFile("/etc/systemd/system/netclient.service", servicebytes, 0644) err = os.WriteFile("/etc/systemd/system/netclient.service", servicebytes, 0644)
if err != nil { if err != nil {
log.Println(err) logger.Log(0, err.Error())
return err return err
} }
} }
@@ -106,7 +106,7 @@ func RemoveSystemDServices() error {
var err error var err error
if !ncutils.IsWindows() && isOnlyService() { if !ncutils.IsWindows() && isOnlyService() {
if err != nil { if err != nil {
log.Println(err) logger.Log(0, err.Error())
} }
ncutils.RunCmd("systemctl disable netclient.service", false) ncutils.RunCmd("systemctl disable netclient.service", false)
ncutils.RunCmd("systemctl disable netclient.timer", false) ncutils.RunCmd("systemctl disable netclient.timer", false)

View File

@@ -301,8 +301,7 @@ func WipeLocal(cfg *config.ClientConfig) error {
if cfg.Node.Interface != "" { if cfg.Node.Interface != "" {
if ncutils.FileExists(dir + cfg.Node.Interface + ".conf") { if ncutils.FileExists(dir + cfg.Node.Interface + ".conf") {
if err := os.Remove(dir + cfg.Node.Interface + ".conf"); err != nil { if err := os.Remove(dir + cfg.Node.Interface + ".conf"); err != nil {
log.Println("error removing .conf:") logger.Log(0, err.Error())
log.Println(err.Error())
fail = true fail = true
} }
} }

View File

@@ -82,6 +82,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
} }
loginMsg.User = global_settings.User loginMsg.User = global_settings.User
loginMsg.Password = string(pass) loginMsg.Password = string(pass)
fmt.Println("attempting login...")
} }
msgTx, err := json.Marshal(loginMsg) msgTx, err := json.Marshal(loginMsg)
@@ -101,7 +102,6 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
// Wait to receive something from server // Wait to receive something from server
_, msg, err := conn.ReadMessage() _, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
log.Println("Error in receive:", err)
return err return err
} }
// Print message from the netmaker controller to the user // Print message from the netmaker controller to the user
@@ -121,6 +121,11 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
for { for {
msgType, msg, err := conn.ReadMessage() msgType, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
if msgType < 0 {
logger.Log(1, "received close message from server")
done <- struct{}{}
return
}
// Error reading a message from the server // Error reading a message from the server
if !strings.Contains(err.Error(), "normal") { if !strings.Contains(err.Error(), "normal") {
logger.Log(0, "read:", err.Error()) logger.Log(0, "read:", err.Error())

View File

@@ -4,10 +4,10 @@
package main package main
import ( import (
"log"
"os" "os"
"runtime/debug" "runtime/debug"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/netclient/cli_options" "github.com/gravitl/netmaker/netclient/cli_options"
"github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/config"
"github.com/gravitl/netmaker/netclient/functions" "github.com/gravitl/netmaker/netclient/functions"
@@ -47,7 +47,7 @@ func main() {
} else { } else {
err := app.Run(os.Args) err := app.Run(os.Args)
if err != nil { if err != nil {
log.Fatal(err) logger.FatalLog(err.Error())
} }
} }
} }