diff --git a/controllers/node.go b/controllers/node.go index d2e6b80a..d92a08c4 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -433,7 +433,7 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() @@ -465,7 +465,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update "+err.Error()) } }() @@ -491,7 +491,7 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() @@ -516,7 +516,7 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() @@ -538,7 +538,7 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() @@ -617,7 +617,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) { logger.Log(1, "error publishing node update", err.Error()) } if logic.ShouldPeersUpdate(&node, &newNode) { - if err := mq.UpdatePeers(&newNode); err != nil { + if err := mq.PublishPeerUpdate(&newNode); err != nil { logger.Log(1, "error publishing peer update after node update", err.Error()) } } diff --git a/controllers/node_grpc.go b/controllers/node_grpc.go index 8edebf5f..f5406c3c 100644 --- a/controllers/node_grpc.go +++ b/controllers/node_grpc.go @@ -2,6 +2,8 @@ package controller import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "errors" "strings" @@ -11,6 +13,7 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" ) @@ -75,7 +78,21 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) Address: server.Address, } } + // TODO consolidate functionality around files node.NetworkSettings.DefaultServerAddrs = serverAddrs + var rsaPrivKey, keyErr = rsa.GenerateKey(rand.Reader, ncutils.KEY_SIZE) + if keyErr != nil { + return nil, keyErr + } + err = logic.StoreTrafficKey(node.ID, (*rsaPrivKey)) + if err != nil { + return nil, err + } + + node.TrafficKeys = models.TrafficKeys{ + Mine: node.TrafficKeys.Mine, + Server: rsaPrivKey.PublicKey, + } err = logic.CreateNode(&node) if err != nil { @@ -103,7 +120,7 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) logger.Log(0, "new node,", node.Name, ", added on network,"+node.Network) // notify other nodes on network of new peer go func() { - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(0, "failed to inform peers of new node ", err.Error()) } }() @@ -170,7 +187,7 @@ func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.Object) } // notify other nodes on network of deleted peer go func() { - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(0, "failed to inform peers of deleted node ", err.Error()) } }() diff --git a/controllers/relay.go b/controllers/relay.go index 4b9fda8c..d5e784cf 100644 --- a/controllers/relay.go +++ b/controllers/relay.go @@ -34,7 +34,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() @@ -60,7 +60,7 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) { if err := mq.NodeUpdate(&node); err != nil { logger.Log(1, "error publishing node update", err.Error()) } - if err := mq.UpdatePeers(&node); err != nil { + if err := mq.PublishPeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } }() diff --git a/database/database.go b/database/database.go index dece51b8..ab0a54de 100644 --- a/database/database.go +++ b/database/database.go @@ -36,13 +36,16 @@ const INT_CLIENTS_TABLE_NAME = "intclients" // PEERS_TABLE_NAME - peers table const PEERS_TABLE_NAME = "peers" -// SERVERCONF_TABLE_NAME +// SERVERCONF_TABLE_NAME - stores server conf const SERVERCONF_TABLE_NAME = "serverconf" -// SERVER_UUID_TABLE_NAME +// SERVER_UUID_TABLE_NAME - stores const SERVER_UUID_TABLE_NAME = "serveruuid" -// SERVER_UUID_RECORD_KEY +// TRAFFIC_TABLE_NAME - stores stuff to control traffic +const TRAFFIC_TABLE_NAME = "traffic-table" + +// SERVER_UUID_RECORD_KEY - telemetry thing const SERVER_UUID_RECORD_KEY = "serveruuid" // DATABASE_FILENAME - database file name @@ -130,6 +133,7 @@ func createTables() { createTable(SERVERCONF_TABLE_NAME) createTable(SERVER_UUID_TABLE_NAME) createTable(GENERATED_TABLE_NAME) + createTable(TRAFFIC_TABLE_NAME) } func createTable(tableName string) error { diff --git a/database/sqlite.go b/database/sqlite.go index bfea8e06..1b8af989 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -30,7 +30,7 @@ var SQLITE_FUNCTIONS = map[string]interface{}{ func initSqliteDB() error { // == create db file if not present == if _, err := os.Stat("data"); os.IsNotExist(err) { - os.Mkdir("data", 0744) + os.Mkdir("data", 0700) } dbFilePath := filepath.Join("data", dbFilename) if _, err := os.Stat(dbFilePath); os.IsNotExist(err) { diff --git a/logic/traffic.go b/logic/traffic.go new file mode 100644 index 00000000..4df13e8a --- /dev/null +++ b/logic/traffic.go @@ -0,0 +1,36 @@ +package logic + +import ( + "crypto/rsa" + "encoding/json" + + "github.com/gravitl/netmaker/database" +) + +type trafficKey struct { + Key rsa.PrivateKey `json:"key" bson:"key"` +} + +// RetrieveTrafficKey - retrieves key based on node +func RetrieveTrafficKey(nodeid string) (rsa.PrivateKey, error) { + var record, err = database.FetchRecord(database.TRAFFIC_TABLE_NAME, nodeid) + if err != nil { + return rsa.PrivateKey{}, err + } + var result trafficKey + if err = json.Unmarshal([]byte(record), &result); err != nil { + return rsa.PrivateKey{}, err + } + return result.Key, nil +} + +// StoreTrafficKey - stores key based on node +func StoreTrafficKey(nodeid string, key rsa.PrivateKey) error { + var data, err = json.Marshal(trafficKey{ + Key: key, + }) + if err != nil { + return err + } + return database.Insert(nodeid, string(data), database.TRAFFIC_TABLE_NAME) +} diff --git a/models/node.go b/models/node.go index 3cd629b2..77a06884 100644 --- a/models/node.go +++ b/models/node.go @@ -48,34 +48,35 @@ type Node struct { LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin" yaml:"lastcheckin"` MacAddress string `json:"macaddress" bson:"macaddress" yaml:"macaddress"` // checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL - CheckInInterval int32 `json:"checkininterval" bson:"checkininterval" yaml:"checkininterval"` - Password string `json:"password" bson:"password" yaml:"password" validate:"required,min=6"` - Network string `json:"network" bson:"network" yaml:"network" validate:"network_exists"` - IsRelayed string `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"` - IsPending string `json:"ispending" bson:"ispending" yaml:"ispending"` - IsRelay string `json:"isrelay" bson:"isrelay" yaml:"isrelay" validate:"checkyesorno"` - IsDocker string `json:"isdocker" bson:"isdocker" yaml:"isdocker" validate:"checkyesorno"` - IsK8S string `json:"isk8s" bson:"isk8s" yaml:"isk8s" validate:"checkyesorno"` - IsEgressGateway string `json:"isegressgateway" bson:"isegressgateway" yaml:"isegressgateway"` - IsIngressGateway string `json:"isingressgateway" bson:"isingressgateway" yaml:"isingressgateway"` - EgressGatewayRanges []string `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"` - RelayAddrs []string `json:"relayaddrs" bson:"relayaddrs" yaml:"relayaddrs"` - IngressGatewayRange string `json:"ingressgatewayrange" bson:"ingressgatewayrange" yaml:"ingressgatewayrange"` - IsStatic string `json:"isstatic" bson:"isstatic" yaml:"isstatic" validate:"checkyesorno"` - UDPHolePunch string `json:"udpholepunch" bson:"udpholepunch" yaml:"udpholepunch" validate:"checkyesorno"` - PullChanges string `json:"pullchanges" bson:"pullchanges" yaml:"pullchanges" validate:"checkyesorno"` - DNSOn string `json:"dnson" bson:"dnson" yaml:"dnson" validate:"checkyesorno"` - IsDualStack string `json:"isdualstack" bson:"isdualstack" yaml:"isdualstack" validate:"checkyesorno"` - IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` - Action string `json:"action" bson:"action" yaml:"action"` - IsLocal string `json:"islocal" bson:"islocal" yaml:"islocal" validate:"checkyesorno"` - LocalRange string `json:"localrange" bson:"localrange" yaml:"localrange"` - Roaming string `json:"roaming" bson:"roaming" yaml:"roaming" validate:"checkyesorno"` - IPForwarding string `json:"ipforwarding" bson:"ipforwarding" yaml:"ipforwarding" validate:"checkyesorno"` - OS string `json:"os" bson:"os" yaml:"os"` - MTU int32 `json:"mtu" bson:"mtu" yaml:"mtu"` - Version string `json:"version" bson:"version" yaml:"version"` - ExcludedAddrs []string `json:"excludedaddrs" bson:"excludedaddrs" yaml:"excludedaddrs"` + CheckInInterval int32 `json:"checkininterval" bson:"checkininterval" yaml:"checkininterval"` + Password string `json:"password" bson:"password" yaml:"password" validate:"required,min=6"` + Network string `json:"network" bson:"network" yaml:"network" validate:"network_exists"` + IsRelayed string `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"` + IsPending string `json:"ispending" bson:"ispending" yaml:"ispending"` + IsRelay string `json:"isrelay" bson:"isrelay" yaml:"isrelay" validate:"checkyesorno"` + IsDocker string `json:"isdocker" bson:"isdocker" yaml:"isdocker" validate:"checkyesorno"` + IsK8S string `json:"isk8s" bson:"isk8s" yaml:"isk8s" validate:"checkyesorno"` + IsEgressGateway string `json:"isegressgateway" bson:"isegressgateway" yaml:"isegressgateway"` + IsIngressGateway string `json:"isingressgateway" bson:"isingressgateway" yaml:"isingressgateway"` + EgressGatewayRanges []string `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"` + RelayAddrs []string `json:"relayaddrs" bson:"relayaddrs" yaml:"relayaddrs"` + IngressGatewayRange string `json:"ingressgatewayrange" bson:"ingressgatewayrange" yaml:"ingressgatewayrange"` + IsStatic string `json:"isstatic" bson:"isstatic" yaml:"isstatic" validate:"checkyesorno"` + UDPHolePunch string `json:"udpholepunch" bson:"udpholepunch" yaml:"udpholepunch" validate:"checkyesorno"` + PullChanges string `json:"pullchanges" bson:"pullchanges" yaml:"pullchanges" validate:"checkyesorno"` + DNSOn string `json:"dnson" bson:"dnson" yaml:"dnson" validate:"checkyesorno"` + IsDualStack string `json:"isdualstack" bson:"isdualstack" yaml:"isdualstack" validate:"checkyesorno"` + IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` + Action string `json:"action" bson:"action" yaml:"action"` + IsLocal string `json:"islocal" bson:"islocal" yaml:"islocal" validate:"checkyesorno"` + LocalRange string `json:"localrange" bson:"localrange" yaml:"localrange"` + Roaming string `json:"roaming" bson:"roaming" yaml:"roaming" validate:"checkyesorno"` + IPForwarding string `json:"ipforwarding" bson:"ipforwarding" yaml:"ipforwarding" validate:"checkyesorno"` + OS string `json:"os" bson:"os" yaml:"os"` + MTU int32 `json:"mtu" bson:"mtu" yaml:"mtu"` + Version string `json:"version" bson:"version" yaml:"version"` + ExcludedAddrs []string `json:"excludedaddrs" bson:"excludedaddrs" yaml:"excludedaddrs"` + TrafficKeys TrafficKeys `json:"traffickeys" bson:"traffickeys" yaml:"traffickeys"` } // NodesArray - used for node sorting diff --git a/models/structs.go b/models/structs.go index 3020c1bd..edea44bc 100644 --- a/models/structs.go +++ b/models/structs.go @@ -1,6 +1,10 @@ package models -import jwt "github.com/golang-jwt/jwt/v4" +import ( + "crypto/rsa" + + jwt "github.com/golang-jwt/jwt/v4" +) const PLACEHOLDER_KEY_TEXT = "ACCESS_KEY" const PLACEHOLDER_TOKEN_TEXT = "ACCESS_TOKEN" @@ -175,3 +179,9 @@ type ServerAddr struct { IsLeader bool `json:"isleader" bson:"isleader" yaml:"isleader"` Address string `json:"address" bson:"address" yaml:"address"` } + +// TrafficKeys - struct to hold public keys +type TrafficKeys struct { + Mine rsa.PublicKey `json:"mine" bson:"mine" yaml:"mine"` + Server rsa.PublicKey `json:"server" bson:"server" yaml:"server"` +} diff --git a/mq/mq.go b/mq/mq.go index a31d8254..63cf71a2 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -3,6 +3,7 @@ package mq import ( "encoding/json" "errors" + "fmt" "log" "strings" @@ -40,6 +41,11 @@ var Ping mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { logger.Log(0, record) return } + _, decryptErr := decryptMsg(node.ID, msg.Payload()) + if decryptErr != nil { + logger.Log(0, "error updating node ", node.ID, err.Error()) + return + } node.SetLastCheckIn() if err := logic.UpdateNode(&node, &node); err != nil { logger.Log(0, "error updating node ", err.Error()) @@ -58,22 +64,28 @@ var UpdateNode mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) logger.Log(1, "error getting node.ID sent on ", msg.Topic(), err.Error()) return } + currentNode, err := logic.GetNodeByID(id) + if err != nil { + logger.Log(1, "error getting node ", id, err.Error()) + return + } + decrypted, decryptErr := decryptMsg(id, msg.Payload()) + if decryptErr != nil { + logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error()) + return + } logger.Log(1, "Update Node Handler", id) var newNode models.Node - if err := json.Unmarshal(msg.Payload(), &newNode); err != nil { + if err := json.Unmarshal(decrypted, &newNode); err != nil { logger.Log(1, "error unmarshaling payload ", err.Error()) return } - currentNode, err := logic.GetNodeByID(newNode.ID) - if err != nil { - logger.Log(1, "error getting node ", newNode.ID, err.Error()) - return - } + if err := logic.UpdateNode(¤tNode, &newNode); err != nil { logger.Log(1, "error saving node", err.Error()) } if logic.ShouldPeersUpdate(¤tNode, &newNode) { - if err := PublishPeerUpdate(client, &newNode); err != nil { + if err := PublishPeerUpdate(&newNode); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) return } @@ -82,7 +94,10 @@ var UpdateNode mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) } // PublishPeerUpdate --- deterines and publishes a peer update to all the peers of a node -func PublishPeerUpdate(client mqtt.Client, newNode *models.Node) error { +func PublishPeerUpdate(newNode *models.Node) error { + if !servercfg.IsMessageQueueBackend() { + return nil + } networkNodes, err := logic.GetNetworkNodes(newNode.Network) if err != nil { logger.Log(1, "err getting Network Nodes", err.Error()) @@ -96,12 +111,11 @@ func PublishPeerUpdate(client mqtt.Client, newNode *models.Node) error { } data, err := json.Marshal(&peerUpdate) if err != nil { - logger.Log(2, "error marshaling peer update ", err.Error()) - return err + logger.Log(2, "error marshaling peer update for node", node.ID, err.Error()) + continue } - if token := client.Publish("update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil { - logger.Log(2, "error publishing peer update to peer ", node.ID, token.Error().Error()) - return err + if err = publish(node.ID, fmt.Sprintf("peers/%s/%s", node.Network, node.ID), data); err != nil { + logger.Log(1, "failed to publish peer update for node", node.ID) } } return nil @@ -118,28 +132,19 @@ func GetID(topic string) (string, error) { return parts[count-1], nil } -// UpdateNode -- publishes a node update +// NodeUpdate -- publishes a node update func NodeUpdate(node *models.Node) error { + if !servercfg.IsMessageQueueBackend() { + return nil + } logger.Log(3, "publishing node update to "+node.Name) - client := SetupMQTT() - defer client.Disconnect(250) data, err := json.Marshal(node) if err != nil { logger.Log(2, "error marshalling node update ", err.Error()) return err } - if token := client.Publish("update/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil { - logger.Log(2, "error publishing peer update to peer ", node.ID, token.Error().Error()) - return err - } - return nil -} - -// UpdatePeers -- publishes a peer update to all the peers of a node -func UpdatePeers(node *models.Node) error { - client := SetupMQTT() - defer client.Disconnect(250) - if err := PublishPeerUpdate(client, node); err != nil { + if err = publish(node.ID, fmt.Sprintf("update/%s/%s", node.Network, node.ID), data); err != nil { + logger.Log(2, "error publishing node update to peer ", node.ID, err.Error()) return err } return nil diff --git a/mq/util.go b/mq/util.go new file mode 100644 index 00000000..07e1c326 --- /dev/null +++ b/mq/util.go @@ -0,0 +1,39 @@ +package mq + +import ( + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/netclient/ncutils" +) + +func decryptMsg(nodeid string, msg []byte) ([]byte, error) { + trafficKey, trafficErr := logic.RetrieveTrafficKey(nodeid) + if trafficErr != nil { + return nil, trafficErr + } + return ncutils.DecryptWithPrivateKey(msg, &trafficKey), nil +} + +func encrypt(nodeid string, dest string, msg []byte) ([]byte, error) { + var node, err = logic.GetNodeByID(nodeid) + if err != nil { + return nil, err + } + encrypted, encryptErr := ncutils.EncryptWithPublicKey(msg, &node.TrafficKeys.Mine) + if encryptErr != nil { + return nil, encryptErr + } + return encrypted, nil +} + +func publish(nodeid string, dest string, msg []byte) error { + client := SetupMQTT() + defer client.Disconnect(250) + encrypted, encryptErr := encrypt(nodeid, dest, msg) + if encryptErr != nil { + return encryptErr + } + if token := client.Publish(dest, 0, false, encrypted); token.Wait() && token.Error() != nil { + return token.Error() + } + return nil +} diff --git a/netclient/auth/auth.go b/netclient/auth/auth.go index d17b82e6..18b2477f 100644 --- a/netclient/auth/auth.go +++ b/netclient/auth/auth.go @@ -73,7 +73,7 @@ func AutoLogin(client nodepb.NodeServiceClient, network string) error { return err } tokenstring := []byte(res.Data) - err = os.WriteFile(home+"nettoken-"+network, tokenstring, 0644) // TODO: Proper permissions? + err = os.WriteFile(home+"nettoken-"+network, tokenstring, 0600) // TODO: Proper permissions? if err != nil { return err } @@ -83,8 +83,7 @@ func AutoLogin(client nodepb.NodeServiceClient, network string) error { // StoreSecret - stores auth secret locally func StoreSecret(key string, network string) error { d1 := []byte(key) - err := os.WriteFile(ncutils.GetNetclientPathSpecific()+"secret-"+network, d1, 0644) - return err + return os.WriteFile(ncutils.GetNetclientPathSpecific()+"secret-"+network, d1, 0600) } // RetrieveSecret - fetches secret locally @@ -93,6 +92,17 @@ func RetrieveSecret(network string) (string, error) { return string(dat), err } +// StoreTrafficKey - stores traffic key +func StoreTrafficKey(key string, network string) error { + return os.WriteFile(ncutils.GetNetclientPathSpecific()+"traffic-"+network, []byte(key), 0600) +} + +// RetrieveTrafficKey - reads traffic file locally +func RetrieveTrafficKey(network string) (string, error) { + dat, err := os.ReadFile(ncutils.GetNetclientPathSpecific() + "traffic-" + network) + return string(dat), err +} + // Configuraion - struct for mac and pass type Configuration struct { MacAddress string diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index a8d31bb6..b122dbb9 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -2,18 +2,21 @@ package functions import ( "context" + "crypto/rsa" "encoding/json" "fmt" "log" "os" "os/signal" "runtime" + "strings" "sync" "syscall" "time" mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/local" "github.com/gravitl/netmaker/netclient/ncutils" @@ -88,17 +91,17 @@ func MessageQueue(ctx context.Context, network string) { } ncutils.Log("subscribed to all topics for debugging purposes") } - if token := client.Subscribe("update/"+cfg.Node.ID, 0, mqtt.MessageHandler(NodeUpdate)); token.Wait() && token.Error() != nil { + if token := client.Subscribe(fmt.Sprintf("update/%s/%s", cfg.Node.Network, cfg.Node.ID), 0, mqtt.MessageHandler(NodeUpdate)); token.Wait() && token.Error() != nil { log.Fatal(token.Error()) } if cfg.DebugOn { - ncutils.Log("subscribed to node updates for node " + cfg.Node.Name + " update/" + cfg.Node.ID) + ncutils.Log(fmt.Sprintf("subscribed to node updates for node %s update/%s/%s \n", cfg.Node.Name, cfg.Node.Network, cfg.Node.ID)) } - if token := client.Subscribe("update/peers/"+cfg.Node.ID, 0, mqtt.MessageHandler(UpdatePeers)); token.Wait() && token.Error() != nil { + if token := client.Subscribe(fmt.Sprintf("peers/%s/%s", cfg.Node.Network, cfg.Node.ID), 0, mqtt.MessageHandler(UpdatePeers)); token.Wait() && token.Error() != nil { log.Fatal(token.Error()) } if cfg.DebugOn { - ncutils.Log("subscribed to node updates for node " + cfg.Node.Name + " update/peers/" + cfg.Node.ID) + ncutils.Log(fmt.Sprintf("subscribed to peer updates for node %s peers/%s/%s \n", cfg.Node.Name, cfg.Node.Network, cfg.Node.ID)) } defer client.Disconnect(250) go Checkin(ctx, &cfg, network) @@ -119,20 +122,27 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) { go func() { var newNode models.Node var cfg config.ClientConfig - err := json.Unmarshal(msg.Payload(), &newNode) + var network = parseNetworkFromTopic(msg.Topic()) + cfg.Network = network + cfg.ReadConfig() + + data, dataErr := decryptMsg(&cfg, msg.Payload()) + if dataErr != nil { + return + } + err := json.Unmarshal(data, &newNode) if err != nil { ncutils.Log("error unmarshalling node update data" + err.Error()) return } + ncutils.Log("received message to update node " + newNode.Name) // see if cache hit, if so skip var currentMessage = read(newNode.Network, lastNodeUpdate) - if currentMessage == string(msg.Payload()) { + if currentMessage == string(data) { return } - insert(newNode.Network, lastNodeUpdate, string(msg.Payload())) - cfg.Network = newNode.Network - cfg.ReadConfig() + insert(newNode.Network, lastNodeUpdate, string(data)) //check if interface name has changed if so delete. if cfg.Node.Interface != newNode.Interface { if err = wireguard.RemoveConf(cfg.Node.Interface, true); err != nil { @@ -201,21 +211,28 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) { func UpdatePeers(client mqtt.Client, msg mqtt.Message) { go func() { var peerUpdate models.PeerUpdate - err := json.Unmarshal(msg.Payload(), &peerUpdate) + var network = parseNetworkFromTopic(msg.Topic()) + var cfg = config.ClientConfig{} + cfg.Network = network + cfg.ReadConfig() + + data, dataErr := decryptMsg(&cfg, msg.Payload()) + if dataErr != nil { + return + } + err := json.Unmarshal(data, &peerUpdate) if err != nil { ncutils.Log("error unmarshalling peer data") return } // see if cache hit, if so skip var currentMessage = read(peerUpdate.Network, lastPeerUpdate) - if currentMessage == string(msg.Payload()) { + if currentMessage == string(data) { return } - insert(peerUpdate.Network, lastPeerUpdate, string(msg.Payload())) + insert(peerUpdate.Network, lastPeerUpdate, string(data)) ncutils.Log("update peer handler") - var cfg config.ClientConfig - cfg.Network = peerUpdate.Network - cfg.ReadConfig() + var shouldReSub = shouldResub(cfg.Node.NetworkSettings.DefaultServerAddrs, peerUpdate.ServerAddrs) if shouldReSub { Resubscribe(client, &cfg) @@ -335,24 +352,49 @@ func PublishNodeUpdate(cfg *config.ClientConfig) { if err := config.Write(cfg, cfg.Network); err != nil { ncutils.Log("error saving configuration" + err.Error()) } - client := SetupMQTT(cfg) data, err := json.Marshal(cfg.Node) if err != nil { ncutils.Log("error marshling node update " + err.Error()) } - if token := client.Publish("update/"+cfg.Node.ID, 0, false, data); token.Wait() && token.Error() != nil { - ncutils.Log("error publishing endpoint update " + token.Error().Error()) + if err = publish(cfg, fmt.Sprintf("update/%s", cfg.Node.ID), data); err != nil { + ncutils.Log(fmt.Sprintf("error publishing endpoint update, %v \n", err)) } - client.Disconnect(250) } // Hello -- ping the broker to let server know node is alive and doing fine func Hello(cfg *config.ClientConfig, network string) { - client := SetupMQTT(cfg) - if token := client.Publish("ping/"+cfg.Node.ID, 2, false, "hello world!"); token.Wait() && token.Error() != nil { - ncutils.Log("error publishing ping " + token.Error().Error()) + if err := publish(cfg, fmt.Sprintf("ping/%s", cfg.Node.ID), []byte("hello world!")); err != nil { + ncutils.Log(fmt.Sprintf("error publishing ping, %v \n", err)) } - client.Disconnect(250) +} + +func publish(cfg *config.ClientConfig, dest string, msg []byte) error { + client := SetupMQTT(cfg) + defer client.Disconnect(250) + encrypted, encryptErr := ncutils.EncryptWithPublicKey(msg, &cfg.Node.TrafficKeys.Server) + if encryptErr != nil { + return encryptErr + } + if token := client.Publish(dest, 0, false, encrypted); token.Wait() && token.Error() != nil { + return token.Error() + } + return nil +} + +func parseNetworkFromTopic(topic string) string { + return strings.Split(topic, "/")[1] +} + +func decryptMsg(cfg *config.ClientConfig, msg []byte) ([]byte, error) { + diskKey, trafficErr := auth.RetrieveTrafficKey(cfg.Node.Network) + if trafficErr != nil { + return nil, trafficErr + } + var trafficKey rsa.PrivateKey + if err := json.Unmarshal([]byte(diskKey), &trafficKey); err != nil { + return nil, err + } + return ncutils.DecryptWithPrivateKey(msg, &trafficKey), nil } func shouldResub(currentServers, newServers []models.ServerAddr) bool { diff --git a/netclient/functions/join.go b/netclient/functions/join.go index c8451f42..5859bb93 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -2,6 +2,8 @@ package functions import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "errors" "fmt" @@ -30,22 +32,29 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { } var err error - if cfg.Node.IsServer != "yes" { - if local.HasNetwork(cfg.Network) { - err := errors.New("ALREADY_INSTALLED. Netclient appears to already be installed for " + cfg.Network + ". To re-install, please remove by executing 'sudo netclient leave -n " + cfg.Network + "'. Then re-run the install command.") - return err - } - - err = config.Write(&cfg, cfg.Network) - if err != nil { - return err - } - if cfg.Node.Password == "" { - cfg.Node.Password = ncutils.GenPass() - } - auth.StoreSecret(cfg.Node.Password, cfg.Node.Network) + if local.HasNetwork(cfg.Network) { + err := errors.New("ALREADY_INSTALLED. Netclient appears to already be installed for " + cfg.Network + ". To re-install, please remove by executing 'sudo netclient leave -n " + cfg.Network + "'. Then re-run the install command.") + return err } + err = config.Write(&cfg, cfg.Network) + if err != nil { + return err + } + if cfg.Node.Password == "" { + cfg.Node.Password = ncutils.GenPass() + } + var rsaPrivKey, errGen = rsa.GenerateKey(rand.Reader, ncutils.KEY_SIZE) + if errGen != nil { + return errGen + } + auth.StoreSecret(cfg.Node.Password, cfg.Node.Network) + var keyData, errKeyData = json.Marshal(&rsaPrivKey) + if errKeyData != nil { + return errKeyData + } + auth.StoreTrafficKey(string(keyData), cfg.Node.Network) + if cfg.Node.LocalRange != "" && cfg.Node.LocalAddress == "" { log.Println("local vpn, getting local address from range: " + cfg.Node.LocalRange) cfg.Node.LocalAddress = getLocalIP(cfg.Node) @@ -122,6 +131,10 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { Endpoint: cfg.Node.Endpoint, SaveConfig: cfg.Node.SaveConfig, UDPHolePunch: cfg.Node.UDPHolePunch, + TrafficKeys: models.TrafficKeys{ + Mine: rsaPrivKey.PublicKey, + Server: rsa.PublicKey{}, + }, } ncutils.Log("joining " + cfg.Network + " at " + cfg.Server.GRPCAddress) diff --git a/netclient/ncutils/netclientutils.go b/netclient/ncutils/netclientutils.go index 0f158229..339ba646 100644 --- a/netclient/ncutils/netclientutils.go +++ b/netclient/ncutils/netclientutils.go @@ -1,6 +1,9 @@ package ncutils import ( + crand "crypto/rand" + "crypto/rsa" + "crypto/sha512" "crypto/tls" "errors" "fmt" @@ -51,6 +54,9 @@ const NETCLIENT_DEFAULT_PORT = 51821 // DEFAULT_GC_PERCENT - garbage collection percent const DEFAULT_GC_PERCENT = 10 +// KEY_SIZE = ideal length for keys +const KEY_SIZE = 64 + // Log - logs a message func Log(message string) { log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) @@ -543,3 +549,27 @@ func ServerAddrSliceContains(slice []models.ServerAddr, item models.ServerAddr) } return false } + +// EncryptWithPublicKey encrypts data with public key +func EncryptWithPublicKey(msg []byte, pub *rsa.PublicKey) ([]byte, error) { + if pub == nil { + return nil, errors.New("invalid public key when decrypting") + } + log.Printf("pub key size: %d \n", pub.Size()) + hash := sha512.New() + ciphertext, err := rsa.EncryptOAEP(hash, crand.Reader, pub, msg, nil) + if err != nil { + return nil, err + } + return ciphertext, nil +} + +// DecryptWithPrivateKey decrypts data with private key +func DecryptWithPrivateKey(ciphertext []byte, priv *rsa.PrivateKey) []byte { + hash := sha512.New() + plaintext, err := rsa.DecryptOAEP(hash, crand.Reader, priv, ciphertext, nil) + if err != nil { + return nil + } + return plaintext +}