[NET-404] Run in limited mode when ee checks fail (#2474)

* Add limited http handlers functionality to rest handler

* Export ee.errValidation (ee.ErrValidation)

* Export a fatal error handled by the hook manager

* Export a new status variable for unlicensed server

* Mark server as unlicensed when ee checks fail

* Handle license validation failures with a (re)boot in a limited state

* Revert "Export a fatal error handled by the hook manager"

This reverts commit 069c21974a8d36e889c73ad78023448d787d62a5.

* Revert "Export ee.errValidation (ee.ErrValidation)"

This reverts commit 59dbab8c79773ca5d879f28cbaf53f3dd4297b9b.

* Revert "Add limited http handlers functionality to rest handler"

This reverts commit e2f1f28facaca54713db76a588839cd2733cf673.

* Revert "Handle license validation failures with a (re)boot in a limited state"

This reverts commit 58cfbbaf522a1345aac1fa67964ebff0a6d60cd8.

* Revert "Mark server as unlicensed when ee checks fail"

This reverts commit 77c6dbdd3c9cfa6e7d6becedef6251e8617ae367.

* Handle license validation failures with a middleware

* Forbid responses if unlicensed ee and not in status api

* Remove unused func
This commit is contained in:
Gabriel de Souza Seibel
2023-08-03 03:46:58 -03:00
committed by GitHub
parent a021e2659e
commit 922e7dbf2c
7 changed files with 90 additions and 43 deletions

View File

@@ -14,6 +14,9 @@ import (
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
) )
// HttpMiddlewares - middleware functions for REST interactions
var HttpMiddlewares []mux.MiddlewareFunc
// HttpHandlers - handler functions for REST interactions // HttpHandlers - handler functions for REST interactions
var HttpHandlers = []interface{}{ var HttpHandlers = []interface{}{
nodeHandlers, nodeHandlers,
@@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ",")) originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete}) methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
for _, middleware := range HttpMiddlewares {
r.Use(middleware)
}
for _, handler := range HttpHandlers { for _, handler := range HttpHandlers {
handler.(func(*mux.Router))(r) handler.(func(*mux.Router))(r)
} }

View File

@@ -68,22 +68,21 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
// Responses: // Responses:
// 200: serverConfigResponse // 200: serverConfigResponse
func getStatus(w http.ResponseWriter, r *http.Request) { func getStatus(w http.ResponseWriter, r *http.Request) {
// TODO
// - check health of broker
type status struct { type status struct {
DB bool `json:"db_connected"` DB bool `json:"db_connected"`
Broker bool `json:"broker_connected"` Broker bool `json:"broker_connected"`
Usage struct { LicenseError string `json:"license_error"`
Hosts int `json:"hosts"` }
Clients int `json:"clients"`
Networks int `json:"networks"` licenseErr := ""
Users int `json:"users"` if servercfg.ErrLicenseValidation != nil {
} `json:"usage"` licenseErr = servercfg.ErrLicenseValidation.Error()
} }
currentServerStatus := status{ currentServerStatus := status{
DB: database.IsConnected(), DB: database.IsConnected(),
Broker: mq.IsConnected(), Broker: mq.IsConnected(),
LicenseError: licenseErr,
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View File

@@ -0,0 +1,17 @@
package ee_controllers
import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg"
"net/http"
)
func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if servercfg.ErrLicenseValidation != nil && request.URL.Path != "/api/server/status" {
logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "forbidden"))
return
}
handler.ServeHTTP(writer, request)
})
}

View File

@@ -7,10 +7,10 @@ import (
controller "github.com/gravitl/netmaker/controllers" controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/ee/ee_controllers" "github.com/gravitl/netmaker/ee/ee_controllers"
eelogic "github.com/gravitl/netmaker/ee/logic" eelogic "github.com/gravitl/netmaker/ee/logic"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
) )
// InitEE - Initialize EE Logic // InitEE - Initialize EE Logic
@@ -18,6 +18,10 @@ func InitEE() {
setIsEnterprise() setIsEnterprise()
servercfg.Is_EE = true servercfg.Is_EE = true
models.SetLogo(retrieveEELogo()) models.SetLogo(retrieveEELogo())
controller.HttpMiddlewares = append(
controller.HttpMiddlewares,
ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
)
controller.HttpHandlers = append( controller.HttpHandlers = append(
controller.HttpHandlers, controller.HttpHandlers,
ee_controllers.MetricHandlers, ee_controllers.MetricHandlers,
@@ -27,8 +31,11 @@ func InitEE() {
) )
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling == // == License Handling ==
ValidateLicense() if err := ValidateLicense(); err != nil {
logger.Log(0, "proceeding with Paid Tier license") slog.Error(err.Error())
return
}
slog.Info("proceeding with Paid Tier license")
logic.SetFreeTierForTelemetry(false) logic.SetFreeTierForTelemetry(false)
// == End License Handling == // == End License Handling ==
AddLicenseHooks() AddLicenseHooks()
@@ -48,7 +55,7 @@ func resetFailover() {
for _, net := range nets { for _, net := range nets {
err = eelogic.ResetFailover(net.NetID) err = eelogic.ResetFailover(net.NetID)
if err != nil { if err != nil {
logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error()) slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
} }
} }
} }

View File

@@ -12,7 +12,6 @@ import (
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
"io" "io"
"net/http" "net/http"
"os"
"time" "time"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
@@ -44,29 +43,40 @@ func AddLicenseHooks() {
} }
} }
// ValidateLicense - the initial license check for netmaker server // ValidateLicense - the initial and periodic license check for netmaker server
// checks if a license is valid + limits are not exceeded // checks if a license is valid + limits are not exceeded
// if license is free_tier and limits exceeds, then server should terminate // if license is free_tier and limits exceeds, then function should error
// if license is not valid, server should terminate // if license is not valid, function should error
func ValidateLicense() error { func ValidateLicense() (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("%w: %s", errValidation, err.Error())
servercfg.ErrLicenseValidation = err
}
}()
licenseKeyValue := servercfg.GetLicenseKey() licenseKeyValue := servercfg.GetLicenseKey()
netmakerTenantID := servercfg.GetNetmakerTenantID() netmakerTenantID := servercfg.GetNetmakerTenantID()
slog.Info("proceeding with Netmaker license validation...") slog.Info("proceeding with Netmaker license validation...")
if len(licenseKeyValue) == 0 { if len(licenseKeyValue) == 0 {
failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)")) err = errors.New("empty license-key (LICENSE_KEY environment variable)")
return err
} }
if len(netmakerTenantID) == 0 { if len(netmakerTenantID) == 0 {
failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")) err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
return err
} }
apiPublicKey, err := getLicensePublicKey(licenseKeyValue) apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to get license public key: %w", err)) err = fmt.Errorf("failed to get license public key: %w", err)
return err
} }
tempPubKey, tempPrivKey, err := FetchApiServerKeys() tempPubKey, tempPrivKey, err := FetchApiServerKeys()
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to fetch api server keys: %w", err)) err = fmt.Errorf("failed to fetch api server keys: %w", err)
return err
} }
licenseSecret := LicenseSecret{ licenseSecret := LicenseSecret{
@@ -76,35 +86,42 @@ func ValidateLicense() error {
secretData, err := json.Marshal(&licenseSecret) secretData, err := json.Marshal(&licenseSecret)
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to marshal license secret: %w", err)) err = fmt.Errorf("failed to marshal license secret: %w", err)
return err
} }
encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey) encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err)) err = fmt.Errorf("failed to encrypt license secret data: %w", err)
return err
} }
validationResponse, err := validateLicenseKey(encryptedData, tempPubKey) validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to validate license key: %w", err)) err = fmt.Errorf("failed to validate license key: %w", err)
return err
} }
if len(validationResponse) == 0 { if len(validationResponse) == 0 {
failValidation(errors.New("empty validation response")) err = errors.New("empty validation response")
return err
} }
var licenseResponse ValidatedLicense var licenseResponse ValidatedLicense
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil { if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err)) err = fmt.Errorf("failed to unmarshal validation response: %w", err)
return err
} }
respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey) respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
if err != nil { if err != nil {
failValidation(fmt.Errorf("failed to decrypt license: %w", err)) err = fmt.Errorf("failed to decrypt license: %w", err)
return err
} }
license := LicenseKey{} license := LicenseKey{}
if err = json.Unmarshal(respData, &license); err != nil { if err = json.Unmarshal(respData, &license); err != nil {
failValidation(fmt.Errorf("failed to unmarshal license key: %w", err)) err = fmt.Errorf("failed to unmarshal license key: %w", err)
return err
} }
slog.Info("License validation succeeded!") slog.Info("License validation succeeded!")
@@ -158,11 +175,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
return pub, priv, nil return pub, priv, nil
} }
func failValidation(err error) {
slog.Error(errValidation.Error(), "error", err)
os.Exit(0)
}
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
decodedPubKey := base64decode(licensePubKeyEncoded) decodedPubKey := base64decode(licensePubKeyEncoded)
return ncutils.ConvertBytesToKey(decodedPubKey) return ncutils.ConvertBytesToKey(decodedPubKey)

View File

@@ -3,10 +3,11 @@ package logic
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/gravitl/netmaker/logger"
"golang.org/x/exp/slog"
"sync" "sync"
"time" "time"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
@@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
logger.Log(0, "## Stopping Hook Manager") slog.Error("## Stopping Hook Manager")
return return
case newhook := <-HookManagerCh: case newhook := <-HookManagerCh:
wg.Add(1) wg.Add(1)
@@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
hook() if err := hook(); err != nil {
slog.Error(err.Error())
}
} }
} }
@@ -85,6 +88,7 @@ var timeHooks = []interface{}{
} }
func loggerDump() error { func loggerDump() error {
// TODO use slog?
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
return nil return nil
} }
@@ -93,7 +97,7 @@ func loggerDump() error {
func runHooks() { func runHooks() {
for _, hook := range timeHooks { for _, hook := range timeHooks {
if err := hook.(func() error)(); err != nil { if err := hook.(func() error)(); err != nil {
logger.Log(1, "error occurred when running timer function:", err.Error()) slog.Error("error occurred when running timer function", "error", err.Error())
} }
} }
} }

View File

@@ -18,8 +18,9 @@ import (
const EmqxBrokerType = "emqx" const EmqxBrokerType = "emqx"
var ( var (
Version = "dev" Version = "dev"
Is_EE = false Is_EE = false
ErrLicenseValidation error
) )
// SetHost - sets the host ip // SetHost - sets the host ip