NM-44: Device Approvals for Network Join (#3579)

* add pending hosts apis, migration logic for network auto join field

* fix pending hosts logic on join

* delete pending hosts on host delete

* ignore pedning device request if host in the network already

* add peer update on host approval
This commit is contained in:
Abhishek K
2025-08-12 09:16:51 +05:30
committed by GitHub
parent 062552170d
commit a8a0dd066c
11 changed files with 320 additions and 61 deletions

View File

@@ -1,6 +1,7 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"log/slog"
@@ -9,12 +10,14 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
)
@@ -223,7 +226,7 @@ func SessionHandler(conn *websocket.Conn) {
if err = conn.WriteMessage(messageType, reponseData); err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil, []models.TagID{})
go CheckNetRegAndHostUpdate(models.EnrollmentKey{Networks: netsToAdd}, &result.Host, "")
case <-timeout: // the read from req.answerCh has timed out
logger.Log(0, "timeout signal recv,exiting oauth socket conn")
break
@@ -237,35 +240,79 @@ func SessionHandler(conn *websocket.Conn) {
}
// CheckNetRegAndHostUpdate - run through networks and send a host update
func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID, tags []models.TagID) {
func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username string) {
// publish host update through MQ
for i := range networks {
network := networks[i]
if ok, _ := logic.NetworkExists(network); ok {
newNode, err := logic.UpdateHostNetwork(h, network, true)
for _, netID := range key.Networks {
if network, err := logic.GetNetwork(netID); err == nil {
if network.AutoJoin == "false" {
if logic.DoesHostExistinTheNetworkAlready(h, models.NetworkID(netID)) {
continue
}
if err := (&schema.PendingHost{
HostID: h.ID.String(),
Network: netID,
}).CheckIfPendingHostExists(db.WithContext(context.TODO())); err == nil {
continue
}
keyB, _ := json.Marshal(key)
// add host to pending host table
p := schema.PendingHost{
ID: uuid.NewString(),
HostID: h.ID.String(),
Hostname: h.Name,
Network: netID,
PublicKey: h.PublicKey.String(),
OS: h.OS,
Location: h.Location,
Version: h.Version,
EnrollmentKey: keyB,
RequestedAt: time.Now().UTC(),
}
p.Create(db.WithContext(context.TODO()))
continue
}
logic.LogEvent(&models.Event{
Action: models.JoinHostToNet,
Source: models.Subject{
ID: key.Value,
Name: key.Tags[0],
Type: models.EnrollmentKeySub,
},
TriggeredBy: username,
Target: models.Subject{
ID: h.ID.String(),
Name: h.Name,
Type: models.DeviceSub,
},
NetworkID: models.NetworkID(netID),
Origin: models.Dashboard,
})
newNode, err := logic.UpdateHostNetwork(h, netID, true)
if err == nil || strings.Contains(err.Error(), "host already part of network") {
if len(tags) > 0 {
if len(key.Groups) > 0 {
newNode.Tags = make(map[models.TagID]struct{})
for _, tagI := range tags {
for _, tagI := range key.Groups {
newNode.Tags[tagI] = struct{}{}
}
logic.UpsertNode(newNode)
}
if relayNodeId != uuid.Nil && !newNode.IsRelayed {
if key.Relay != uuid.Nil && !newNode.IsRelayed {
// check if relay node exists and acting as relay
relaynode, err := logic.GetNodeByID(relayNodeId.String())
relaynode, err := logic.GetNodeByID(key.Relay.String())
if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network {
slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), netID))
newNode.IsRelayed = true
newNode.RelayedBy = relayNodeId.String()
newNode.RelayedBy = key.Relay.String()
updatedRelayNode := relaynode
updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String())
logic.UpdateRelayed(&relaynode, &updatedRelayNode)
if err := logic.UpsertNode(&updatedRelayNode); err != nil {
slog.Error("failed to update node", "nodeid", relayNodeId.String())
slog.Error("failed to update node", "nodeid", key.Relay.String())
}
if err := logic.UpsertNode(newNode); err != nil {
slog.Error("failed to update node", "nodeid", relayNodeId.String())
slog.Error("failed to update node", "nodeid", key.Relay.String())
}
} else {
slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err)
@@ -275,7 +322,7 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui
continue
}
} else {
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, netID, err.Error())
continue
}
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
@@ -288,10 +335,10 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui
// make host failover
logic.CreateFailOver(*newNode)
// make host remote access gateway
logic.CreateIngressGateway(network, newNode.ID.String(), models.IngressRequest{})
logic.CreateIngressGateway(netID, newNode.ID.String(), models.IngressRequest{})
logic.CreateRelay(models.RelayRequest{
NodeID: newNode.ID.String(),
NetID: network,
NetID: netID,
})
}
}

View File

@@ -414,28 +414,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
ServerConf: server,
RequestedHost: *host,
}
for _, netID := range enrollmentKey.Networks {
logic.LogEvent(&models.Event{
Action: models.JoinHostToNet,
Source: models.Subject{
ID: enrollmentKey.Value,
Name: enrollmentKey.Tags[0],
Type: models.EnrollmentKeySub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: newHost.ID.String(),
Name: newHost.Name,
Type: models.DeviceSub,
},
NetworkID: models.NetworkID(netID),
Origin: models.Dashboard,
})
}
logger.Log(0, host.Name, host.ID.String(), "registered with Netmaker")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&response)
// notify host of changes, peer and node updates
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, host, enrollmentKey.Relay, enrollmentKey.Groups)
go auth.CheckNetRegAndHostUpdate(*enrollmentKey, host, r.Header.Get("user"))
}

View File

@@ -10,10 +10,13 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slog"
@@ -51,6 +54,12 @@ func hostHandlers(r *mux.Router) {
Methods(http.MethodPut)
r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))).
Methods(http.MethodGet)
r.HandleFunc("/api/v1/pending_hosts", logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts))).
Methods(http.MethodGet)
r.HandleFunc("/api/v1/pending_hosts/approve/{id}", logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost))).
Methods(http.MethodPost)
r.HandleFunc("/api/v1/pending_hosts/reject/{id}", logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost))).
Methods(http.MethodPost)
r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))).
Methods(http.MethodDelete)
r.HandleFunc("/api/v1/auth-register/host", socketHandler)
@@ -453,6 +462,10 @@ func deleteHost(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
// delete if any pending reqs
(&schema.PendingHost{
HostID: currHost.ID.String(),
}).DeleteAllPendingHosts(db.WithContext(r.Context()))
logic.LogEvent(&models.Event{
Action: models.Delete,
Source: models.Subject{
@@ -1144,3 +1157,141 @@ func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
}
logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info")
}
// @Summary List pending hosts in a network
// @Router /api/v1/pending_hosts [get]
// @Tags Hosts
// @Security oauth
// @Success 200 {array} schema.PendingHost
// @Failure 500 {object} models.ErrorResponse
func getPendingHosts(w http.ResponseWriter, r *http.Request) {
netID := r.URL.Query().Get("network")
if netID == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network id param is missing"), "badrequest"))
return
}
pendingHosts, err := (&schema.PendingHost{
Network: netID,
}).List(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
logger.Log(2, r.Header.Get("user"), "fetched all hosts")
logic.ReturnSuccessResponseWithJson(w, r, pendingHosts, "returned pending hosts in "+netID)
}
// @Summary approve pending hosts in a network
// @Router /api/v1/pending_hosts/approve/{id} [post]
// @Tags Hosts
// @Security oauth
// @Success 200 {array} models.ApiNode
// @Failure 500 {object} models.ErrorResponse
func approvePendingHost(w http.ResponseWriter, r *http.Request) {
id := mux.Vars(r)["id"]
p := &schema.PendingHost{ID: id}
err := p.Get(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
h, err := logic.GetHost(p.HostID)
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
key := models.EnrollmentKey{}
json.Unmarshal(p.EnrollmentKey, &key)
newNode, err := logic.UpdateHostNetwork(h, p.Network, true)
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
if len(key.Groups) > 0 {
newNode.Tags = make(map[models.TagID]struct{})
for _, tagI := range key.Groups {
newNode.Tags[tagI] = struct{}{}
}
logic.UpsertNode(newNode)
}
if key.Relay != uuid.Nil && !newNode.IsRelayed {
// check if relay node exists and acting as relay
relaynode, err := logic.GetNodeByID(key.Relay.String())
if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network {
slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), p.Network))
newNode.IsRelayed = true
newNode.RelayedBy = key.Relay.String()
updatedRelayNode := relaynode
updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String())
logic.UpdateRelayed(&relaynode, &updatedRelayNode)
if err := logic.UpsertNode(&updatedRelayNode); err != nil {
slog.Error("failed to update node", "nodeid", key.Relay.String())
}
if err := logic.UpsertNode(newNode); err != nil {
slog.Error("failed to update node", "nodeid", key.Relay.String())
}
} else {
slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err)
}
}
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
hostactions.AddAction(models.HostUpdate{
Action: models.JoinHostToNetwork,
Host: *h,
Node: *newNode,
})
if h.IsDefault {
// make host failover
logic.CreateFailOver(*newNode)
// make host remote access gateway
logic.CreateIngressGateway(p.Network, newNode.ID.String(), models.IngressRequest{})
logic.CreateRelay(models.RelayRequest{
NodeID: newNode.ID.String(),
NetID: p.Network,
})
}
p.Delete(db.WithContext(r.Context()))
go mq.PublishPeerUpdate(false)
logic.ReturnSuccessResponseWithJson(w, r, newNode.ConvertToAPINode(), "added pending host to "+p.Network)
}
// @Summary reject pending hosts in a network
// @Router /api/v1/pending_hosts/reject/{id} [post]
// @Tags Hosts
// @Security oauth
// @Success 200 {array} models.ApiNode
// @Failure 500 {object} models.ErrorResponse
func rejectPendingHost(w http.ResponseWriter, r *http.Request) {
id := mux.Vars(r)["id"]
p := &schema.PendingHost{ID: id}
err := p.Get(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
err = p.Delete(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
logic.ReturnSuccessResponseWithJson(w, r, p, "deleted pending host from "+p.Network)
}

View File

@@ -701,10 +701,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
netNew := netOld
netNew.NameServers = payload.NameServers
netNew.DefaultACL = payload.DefaultACL
_, _, _, err = logic.UpdateNetwork(&netOld, &netNew)
err = logic.UpdateNetwork(&netOld, &payload)
if err != nil {
slog.Info("failed to update network", "user", r.Header.Get("user"), "err", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))

View File

@@ -175,6 +175,18 @@ func GetHostsMap() (map[string]models.Host, error) {
return currHostMap, nil
}
func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID) bool {
if len(h.Nodes) > 0 {
for _, nodeID := range h.Nodes {
node, err := GetNodeByID(nodeID)
if err == nil && node.Network == network.String() {
return true
}
}
}
return false
}
// GetHost - gets a host from db given id
func GetHost(hostid string) (*models.Host, error) {
if servercfg.CacheEnabled() {

View File

@@ -629,30 +629,41 @@ func IsNetworkNameUnique(network *models.Network) (bool, error) {
return isunique, nil
}
func UpsertNetwork(network models.Network) error {
netData, err := json.Marshal(network)
if err != nil {
return err
}
err = database.Insert(network.NetID, string(netData), database.NETWORKS_TABLE_NAME)
if err != nil {
return err
}
return nil
}
// UpdateNetwork - updates a network with another network's fields
func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, error) {
func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) error {
if err := ValidateNetwork(newNetwork, true); err != nil {
return false, false, false, err
return err
}
if newNetwork.NetID == currentNetwork.NetID {
hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange
hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6
hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch
data, err := json.Marshal(newNetwork)
if err != nil {
return false, false, false, err
}
newNetwork.SetNetworkLastModified()
err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
if err == nil {
if servercfg.CacheEnabled() {
storeNetworkInCache(newNetwork.NetID, *newNetwork)
}
}
return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err
if newNetwork.NetID != currentNetwork.NetID {
return errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.")
}
// copy values
return false, false, false, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.")
currentNetwork.AutoJoin = newNetwork.AutoJoin
currentNetwork.DefaultACL = newNetwork.DefaultACL
currentNetwork.NameServers = newNetwork.NameServers
data, err := json.Marshal(currentNetwork)
if err != nil {
return err
}
newNetwork.SetNetworkLastModified()
err = database.Insert(currentNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
if err == nil {
if servercfg.CacheEnabled() {
storeNetworkInCache(newNetwork.NetID, *currentNetwork)
}
}
return err
}
// GetNetwork - gets a network from database

View File

@@ -38,9 +38,20 @@ func Run() {
updateNewAcls()
logic.MigrateToGws()
migrateToEgressV1()
updateNetworks()
resync()
}
func updateNetworks() {
nets, _ := logic.GetNetworks()
for _, netI := range nets {
if netI.AutoJoin == "" {
netI.AutoJoin = "true"
logic.UpsertNetwork(netI)
}
}
}
// removes if any stale configurations from previous run.
func resync() {

View File

@@ -25,6 +25,7 @@ type Network struct {
DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"`
DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"`
NameServers []string `json:"dns_nameservers"`
AutoJoin string `json:"auto_join"`
}
// SaveData - sensitive fields of a network that should be kept the same

View File

@@ -7,5 +7,6 @@ func ListModels() []interface{} {
&Egress{},
&UserAccessToken{},
&Event{},
&PendingHost{},
}
}

46
schema/pending_hosts.go Normal file
View File

@@ -0,0 +1,46 @@
package schema
import (
"context"
"time"
"github.com/gravitl/netmaker/db"
"gorm.io/datatypes"
)
type PendingHost struct {
ID string `gorm:"id" json:"id"`
HostID string `gorm:"host_id" json:"host_id"`
Hostname string `gorm:"host_name" json:"host_name"`
Network string `gorm:"network" json:"network"`
PublicKey string `gorm:"public_key" json:"public_key"`
EnrollmentKey datatypes.JSON `gorm:"enrollment_key_id" json:"enrollment_key_id"`
OS string `gorm:"os" json:"os"`
Version string `gorm:"version" json:"version"`
Location string `gorm:"location" json:"location"` // Format: "lat,lon"
RequestedAt time.Time `gorm:"requested_at" json:"requested_at"`
}
func (p *PendingHost) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(&PendingHost{}).First(&p).Where("id = ?", p.ID).Error
}
func (p *PendingHost) Create(ctx context.Context) error {
return db.FromContext(ctx).Model(&PendingHost{}).Create(&p).Error
}
func (p *PendingHost) List(ctx context.Context) (pendingHosts []PendingHost, err error) {
err = db.FromContext(ctx).Model(&PendingHost{}).Find(&pendingHosts).Error
return
}
func (p *PendingHost) Delete(ctx context.Context) error {
return db.FromContext(ctx).Model(&PendingHost{}).Where("id = ?", p.ID).Delete(&p).Error
}
func (p *PendingHost) CheckIfPendingHostExists(ctx context.Context) error {
return db.FromContext(ctx).Model(&PendingHost{}).Where("host_id = ? AND network = ?", p.HostID, p.Network).First(&p).Error
}
func (p *PendingHost) DeleteAllPendingHosts(ctx context.Context) error {
return db.FromContext(ctx).Model(&PendingHost{}).Where("host_id = ?", p.HostID).Delete(&p).Error
}

View File

@@ -34,7 +34,7 @@ EXPORTER_API_PORT=8085
CORS_ALLOWED_ORIGIN=*
# Show keys permanently in UI (until deleted) as opposed to 1-time display.
DISPLAY_KEYS=on
# Database to use - sqlite, postgres, or rqlite
# Database to use - sqlite, postgres
DATABASE=sqlite
# The address of the mq server. If running from docker compose it will be "mq". Otherwise, need to input address.
# If using "host networking", it will find and detect the IP of the mq container.