From a8a0dd066cc29b3560606fc4dd0e3f4f6952d20a Mon Sep 17 00:00:00 2001 From: Abhishek K Date: Tue, 12 Aug 2025 09:16:51 +0530 Subject: [PATCH] NM-44: Device Approvals for Network Join (#3579) * add pending hosts apis, migration logic for network auto join field * fix pending hosts logic on join * delete pending hosts on host delete * ignore pedning device request if host in the network already * add peer update on host approval --- auth/host_session.go | 81 ++++++++++++++---- controllers/enrollmentkeys.go | 20 +---- controllers/hosts.go | 151 ++++++++++++++++++++++++++++++++++ controllers/network.go | 5 +- logic/hosts.go | 12 +++ logic/networks.go | 51 +++++++----- migrate/migrate.go | 11 +++ models/network.go | 1 + schema/models.go | 1 + schema/pending_hosts.go | 46 +++++++++++ scripts/netmaker.default.env | 2 +- 11 files changed, 320 insertions(+), 61 deletions(-) create mode 100644 schema/pending_hosts.go diff --git a/auth/host_session.go b/auth/host_session.go index 8e34cc53..9b175d90 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "fmt" "log/slog" @@ -9,12 +10,14 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/logic/pro/netcache" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) @@ -223,7 +226,7 @@ func SessionHandler(conn *websocket.Conn) { if err = conn.WriteMessage(messageType, reponseData); err != nil { logger.Log(0, "error during message writing:", err.Error()) } - go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil, []models.TagID{}) + go CheckNetRegAndHostUpdate(models.EnrollmentKey{Networks: netsToAdd}, &result.Host, "") case <-timeout: // the read from req.answerCh has timed out logger.Log(0, "timeout signal recv,exiting oauth socket conn") break @@ -237,35 +240,79 @@ func SessionHandler(conn *websocket.Conn) { } // CheckNetRegAndHostUpdate - run through networks and send a host update -func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID, tags []models.TagID) { +func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username string) { // publish host update through MQ - for i := range networks { - network := networks[i] - if ok, _ := logic.NetworkExists(network); ok { - newNode, err := logic.UpdateHostNetwork(h, network, true) + for _, netID := range key.Networks { + if network, err := logic.GetNetwork(netID); err == nil { + if network.AutoJoin == "false" { + if logic.DoesHostExistinTheNetworkAlready(h, models.NetworkID(netID)) { + continue + } + if err := (&schema.PendingHost{ + HostID: h.ID.String(), + Network: netID, + }).CheckIfPendingHostExists(db.WithContext(context.TODO())); err == nil { + continue + } + keyB, _ := json.Marshal(key) + // add host to pending host table + p := schema.PendingHost{ + ID: uuid.NewString(), + HostID: h.ID.String(), + Hostname: h.Name, + Network: netID, + PublicKey: h.PublicKey.String(), + OS: h.OS, + Location: h.Location, + Version: h.Version, + EnrollmentKey: keyB, + RequestedAt: time.Now().UTC(), + } + p.Create(db.WithContext(context.TODO())) + continue + } + + logic.LogEvent(&models.Event{ + Action: models.JoinHostToNet, + Source: models.Subject{ + ID: key.Value, + Name: key.Tags[0], + Type: models.EnrollmentKeySub, + }, + TriggeredBy: username, + Target: models.Subject{ + ID: h.ID.String(), + Name: h.Name, + Type: models.DeviceSub, + }, + NetworkID: models.NetworkID(netID), + Origin: models.Dashboard, + }) + + newNode, err := logic.UpdateHostNetwork(h, netID, true) if err == nil || strings.Contains(err.Error(), "host already part of network") { - if len(tags) > 0 { + if len(key.Groups) > 0 { newNode.Tags = make(map[models.TagID]struct{}) - for _, tagI := range tags { + for _, tagI := range key.Groups { newNode.Tags[tagI] = struct{}{} } logic.UpsertNode(newNode) } - if relayNodeId != uuid.Nil && !newNode.IsRelayed { + if key.Relay != uuid.Nil && !newNode.IsRelayed { // check if relay node exists and acting as relay - relaynode, err := logic.GetNodeByID(relayNodeId.String()) + relaynode, err := logic.GetNodeByID(key.Relay.String()) if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network { - slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network)) + slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), netID)) newNode.IsRelayed = true - newNode.RelayedBy = relayNodeId.String() + newNode.RelayedBy = key.Relay.String() updatedRelayNode := relaynode updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String()) logic.UpdateRelayed(&relaynode, &updatedRelayNode) if err := logic.UpsertNode(&updatedRelayNode); err != nil { - slog.Error("failed to update node", "nodeid", relayNodeId.String()) + slog.Error("failed to update node", "nodeid", key.Relay.String()) } if err := logic.UpsertNode(newNode); err != nil { - slog.Error("failed to update node", "nodeid", relayNodeId.String()) + slog.Error("failed to update node", "nodeid", key.Relay.String()) } } else { slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err) @@ -275,7 +322,7 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui continue } } else { - logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error()) + logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, netID, err.Error()) continue } logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) @@ -288,10 +335,10 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui // make host failover logic.CreateFailOver(*newNode) // make host remote access gateway - logic.CreateIngressGateway(network, newNode.ID.String(), models.IngressRequest{}) + logic.CreateIngressGateway(netID, newNode.ID.String(), models.IngressRequest{}) logic.CreateRelay(models.RelayRequest{ NodeID: newNode.ID.String(), - NetID: network, + NetID: netID, }) } } diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index ece24111..1ec33b86 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -414,28 +414,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { ServerConf: server, RequestedHost: *host, } - for _, netID := range enrollmentKey.Networks { - logic.LogEvent(&models.Event{ - Action: models.JoinHostToNet, - Source: models.Subject{ - ID: enrollmentKey.Value, - Name: enrollmentKey.Tags[0], - Type: models.EnrollmentKeySub, - }, - TriggeredBy: r.Header.Get("user"), - Target: models.Subject{ - ID: newHost.ID.String(), - Name: newHost.Name, - Type: models.DeviceSub, - }, - NetworkID: models.NetworkID(netID), - Origin: models.Dashboard, - }) - } logger.Log(0, host.Name, host.ID.String(), "registered with Netmaker") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&response) // notify host of changes, peer and node updates - go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, host, enrollmentKey.Relay, enrollmentKey.Groups) + go auth.CheckNetRegAndHostUpdate(*enrollmentKey, host, r.Header.Get("user")) } diff --git a/controllers/hosts.go b/controllers/hosts.go index 014840bb..7fb52e60 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -10,10 +10,13 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" @@ -51,6 +54,12 @@ func hostHandlers(r *mux.Router) { Methods(http.MethodPut) r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))). Methods(http.MethodGet) + r.HandleFunc("/api/v1/pending_hosts", logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts))). + Methods(http.MethodGet) + r.HandleFunc("/api/v1/pending_hosts/approve/{id}", logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost))). + Methods(http.MethodPost) + r.HandleFunc("/api/v1/pending_hosts/reject/{id}", logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost))). + Methods(http.MethodPost) r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))). Methods(http.MethodDelete) r.HandleFunc("/api/v1/auth-register/host", socketHandler) @@ -453,6 +462,10 @@ func deleteHost(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + // delete if any pending reqs + (&schema.PendingHost{ + HostID: currHost.ID.String(), + }).DeleteAllPendingHosts(db.WithContext(r.Context())) logic.LogEvent(&models.Event{ Action: models.Delete, Source: models.Subject{ @@ -1144,3 +1157,141 @@ func getHostPeerInfo(w http.ResponseWriter, r *http.Request) { } logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info") } + +// @Summary List pending hosts in a network +// @Router /api/v1/pending_hosts [get] +// @Tags Hosts +// @Security oauth +// @Success 200 {array} schema.PendingHost +// @Failure 500 {object} models.ErrorResponse +func getPendingHosts(w http.ResponseWriter, r *http.Request) { + netID := r.URL.Query().Get("network") + if netID == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network id param is missing"), "badrequest")) + return + } + pendingHosts, err := (&schema.PendingHost{ + Network: netID, + }).List(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + logger.Log(2, r.Header.Get("user"), "fetched all hosts") + logic.ReturnSuccessResponseWithJson(w, r, pendingHosts, "returned pending hosts in "+netID) +} + +// @Summary approve pending hosts in a network +// @Router /api/v1/pending_hosts/approve/{id} [post] +// @Tags Hosts +// @Security oauth +// @Success 200 {array} models.ApiNode +// @Failure 500 {object} models.ErrorResponse +func approvePendingHost(w http.ResponseWriter, r *http.Request) { + id := mux.Vars(r)["id"] + p := &schema.PendingHost{ID: id} + err := p.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + h, err := logic.GetHost(p.HostID) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + key := models.EnrollmentKey{} + json.Unmarshal(p.EnrollmentKey, &key) + newNode, err := logic.UpdateHostNetwork(h, p.Network, true) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + if len(key.Groups) > 0 { + newNode.Tags = make(map[models.TagID]struct{}) + for _, tagI := range key.Groups { + newNode.Tags[tagI] = struct{}{} + } + logic.UpsertNode(newNode) + } + if key.Relay != uuid.Nil && !newNode.IsRelayed { + // check if relay node exists and acting as relay + relaynode, err := logic.GetNodeByID(key.Relay.String()) + if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network { + slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), p.Network)) + newNode.IsRelayed = true + newNode.RelayedBy = key.Relay.String() + updatedRelayNode := relaynode + updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String()) + logic.UpdateRelayed(&relaynode, &updatedRelayNode) + if err := logic.UpsertNode(&updatedRelayNode); err != nil { + slog.Error("failed to update node", "nodeid", key.Relay.String()) + } + if err := logic.UpsertNode(newNode); err != nil { + slog.Error("failed to update node", "nodeid", key.Relay.String()) + } + } else { + slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err) + } + } + + logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) + hostactions.AddAction(models.HostUpdate{ + Action: models.JoinHostToNetwork, + Host: *h, + Node: *newNode, + }) + if h.IsDefault { + // make host failover + logic.CreateFailOver(*newNode) + // make host remote access gateway + logic.CreateIngressGateway(p.Network, newNode.ID.String(), models.IngressRequest{}) + logic.CreateRelay(models.RelayRequest{ + NodeID: newNode.ID.String(), + NetID: p.Network, + }) + } + p.Delete(db.WithContext(r.Context())) + go mq.PublishPeerUpdate(false) + logic.ReturnSuccessResponseWithJson(w, r, newNode.ConvertToAPINode(), "added pending host to "+p.Network) +} + +// @Summary reject pending hosts in a network +// @Router /api/v1/pending_hosts/reject/{id} [post] +// @Tags Hosts +// @Security oauth +// @Success 200 {array} models.ApiNode +// @Failure 500 {object} models.ErrorResponse +func rejectPendingHost(w http.ResponseWriter, r *http.Request) { + id := mux.Vars(r)["id"] + p := &schema.PendingHost{ID: id} + err := p.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + err = p.Delete(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{ + Code: http.StatusBadRequest, + Message: err.Error(), + }) + return + } + logic.ReturnSuccessResponseWithJson(w, r, p, "deleted pending host from "+p.Network) +} diff --git a/controllers/network.go b/controllers/network.go index 9b216a88..de35efdd 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -701,10 +701,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - netNew := netOld - netNew.NameServers = payload.NameServers - netNew.DefaultACL = payload.DefaultACL - _, _, _, err = logic.UpdateNetwork(&netOld, &netNew) + err = logic.UpdateNetwork(&netOld, &payload) if err != nil { slog.Info("failed to update network", "user", r.Header.Get("user"), "err", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) diff --git a/logic/hosts.go b/logic/hosts.go index 04fcd405..6d469057 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -175,6 +175,18 @@ func GetHostsMap() (map[string]models.Host, error) { return currHostMap, nil } +func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID) bool { + if len(h.Nodes) > 0 { + for _, nodeID := range h.Nodes { + node, err := GetNodeByID(nodeID) + if err == nil && node.Network == network.String() { + return true + } + } + } + return false +} + // GetHost - gets a host from db given id func GetHost(hostid string) (*models.Host, error) { if servercfg.CacheEnabled() { diff --git a/logic/networks.go b/logic/networks.go index 5ef45b22..1c1e5c2e 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -629,30 +629,41 @@ func IsNetworkNameUnique(network *models.Network) (bool, error) { return isunique, nil } +func UpsertNetwork(network models.Network) error { + netData, err := json.Marshal(network) + if err != nil { + return err + } + err = database.Insert(network.NetID, string(netData), database.NETWORKS_TABLE_NAME) + if err != nil { + return err + } + return nil +} + // UpdateNetwork - updates a network with another network's fields -func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, error) { +func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) error { if err := ValidateNetwork(newNetwork, true); err != nil { - return false, false, false, err + return err } - if newNetwork.NetID == currentNetwork.NetID { - hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange - hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6 - hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch - data, err := json.Marshal(newNetwork) - if err != nil { - return false, false, false, err - } - newNetwork.SetNetworkLastModified() - err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) - if err == nil { - if servercfg.CacheEnabled() { - storeNetworkInCache(newNetwork.NetID, *newNetwork) - } - } - return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err + if newNetwork.NetID != currentNetwork.NetID { + return errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.") } - // copy values - return false, false, false, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.") + currentNetwork.AutoJoin = newNetwork.AutoJoin + currentNetwork.DefaultACL = newNetwork.DefaultACL + currentNetwork.NameServers = newNetwork.NameServers + data, err := json.Marshal(currentNetwork) + if err != nil { + return err + } + newNetwork.SetNetworkLastModified() + err = database.Insert(currentNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) + if err == nil { + if servercfg.CacheEnabled() { + storeNetworkInCache(newNetwork.NetID, *currentNetwork) + } + } + return err } // GetNetwork - gets a network from database diff --git a/migrate/migrate.go b/migrate/migrate.go index f54742d5..812c1faa 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -38,9 +38,20 @@ func Run() { updateNewAcls() logic.MigrateToGws() migrateToEgressV1() + updateNetworks() resync() } +func updateNetworks() { + nets, _ := logic.GetNetworks() + for _, netI := range nets { + if netI.AutoJoin == "" { + netI.AutoJoin = "true" + logic.UpsertNetwork(netI) + } + } +} + // removes if any stale configurations from previous run. func resync() { diff --git a/models/network.go b/models/network.go index 7ea41cce..a96d9587 100644 --- a/models/network.go +++ b/models/network.go @@ -25,6 +25,7 @@ type Network struct { DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"` NameServers []string `json:"dns_nameservers"` + AutoJoin string `json:"auto_join"` } // SaveData - sensitive fields of a network that should be kept the same diff --git a/schema/models.go b/schema/models.go index 2f30d1f2..3d74b9bb 100644 --- a/schema/models.go +++ b/schema/models.go @@ -7,5 +7,6 @@ func ListModels() []interface{} { &Egress{}, &UserAccessToken{}, &Event{}, + &PendingHost{}, } } diff --git a/schema/pending_hosts.go b/schema/pending_hosts.go new file mode 100644 index 00000000..5c11bba8 --- /dev/null +++ b/schema/pending_hosts.go @@ -0,0 +1,46 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +type PendingHost struct { + ID string `gorm:"id" json:"id"` + HostID string `gorm:"host_id" json:"host_id"` + Hostname string `gorm:"host_name" json:"host_name"` + Network string `gorm:"network" json:"network"` + PublicKey string `gorm:"public_key" json:"public_key"` + EnrollmentKey datatypes.JSON `gorm:"enrollment_key_id" json:"enrollment_key_id"` + OS string `gorm:"os" json:"os"` + Version string `gorm:"version" json:"version"` + Location string `gorm:"location" json:"location"` // Format: "lat,lon" + RequestedAt time.Time `gorm:"requested_at" json:"requested_at"` +} + +func (p *PendingHost) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&PendingHost{}).First(&p).Where("id = ?", p.ID).Error +} + +func (p *PendingHost) Create(ctx context.Context) error { + return db.FromContext(ctx).Model(&PendingHost{}).Create(&p).Error +} + +func (p *PendingHost) List(ctx context.Context) (pendingHosts []PendingHost, err error) { + err = db.FromContext(ctx).Model(&PendingHost{}).Find(&pendingHosts).Error + return +} + +func (p *PendingHost) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&PendingHost{}).Where("id = ?", p.ID).Delete(&p).Error +} +func (p *PendingHost) CheckIfPendingHostExists(ctx context.Context) error { + return db.FromContext(ctx).Model(&PendingHost{}).Where("host_id = ? AND network = ?", p.HostID, p.Network).First(&p).Error +} + +func (p *PendingHost) DeleteAllPendingHosts(ctx context.Context) error { + return db.FromContext(ctx).Model(&PendingHost{}).Where("host_id = ?", p.HostID).Delete(&p).Error +} diff --git a/scripts/netmaker.default.env b/scripts/netmaker.default.env index 14778efe..8fbe5642 100644 --- a/scripts/netmaker.default.env +++ b/scripts/netmaker.default.env @@ -34,7 +34,7 @@ EXPORTER_API_PORT=8085 CORS_ALLOWED_ORIGIN=* # Show keys permanently in UI (until deleted) as opposed to 1-time display. DISPLAY_KEYS=on -# Database to use - sqlite, postgres, or rqlite +# Database to use - sqlite, postgres DATABASE=sqlite # The address of the mq server. If running from docker compose it will be "mq". Otherwise, need to input address. # If using "host networking", it will find and detect the IP of the mq container.