mirror of
https://github.com/gravitl/netmaker.git
synced 2025-10-25 01:40:46 +08:00
[NET-404] Run in limited mode when ee checks fail (#2474)
* Add limited http handlers functionality to rest handler * Export ee.errValidation (ee.ErrValidation) * Export a fatal error handled by the hook manager * Export a new status variable for unlicensed server * Mark server as unlicensed when ee checks fail * Handle license validation failures with a (re)boot in a limited state * Revert "Export a fatal error handled by the hook manager" This reverts commit 069c21974a8d36e889c73ad78023448d787d62a5. * Revert "Export ee.errValidation (ee.ErrValidation)" This reverts commit 59dbab8c79773ca5d879f28cbaf53f3dd4297b9b. * Revert "Add limited http handlers functionality to rest handler" This reverts commit e2f1f28facaca54713db76a588839cd2733cf673. * Revert "Handle license validation failures with a (re)boot in a limited state" This reverts commit 58cfbbaf522a1345aac1fa67964ebff0a6d60cd8. * Revert "Mark server as unlicensed when ee checks fail" This reverts commit 77c6dbdd3c9cfa6e7d6becedef6251e8617ae367. * Handle license validation failures with a middleware * Forbid responses if unlicensed ee and not in status api * Remove unused func
This commit is contained in:
committed by
GitHub
parent
a021e2659e
commit
922e7dbf2c
@@ -14,6 +14,9 @@ import (
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
// HttpMiddlewares - middleware functions for REST interactions
|
||||
var HttpMiddlewares []mux.MiddlewareFunc
|
||||
|
||||
// HttpHandlers - handler functions for REST interactions
|
||||
var HttpHandlers = []interface{}{
|
||||
nodeHandlers,
|
||||
@@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
|
||||
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
|
||||
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
|
||||
|
||||
for _, middleware := range HttpMiddlewares {
|
||||
r.Use(middleware)
|
||||
}
|
||||
|
||||
for _, handler := range HttpHandlers {
|
||||
handler.(func(*mux.Router))(r)
|
||||
}
|
||||
|
||||
@@ -68,22 +68,21 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
|
||||
// Responses:
|
||||
// 200: serverConfigResponse
|
||||
func getStatus(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO
|
||||
// - check health of broker
|
||||
type status struct {
|
||||
DB bool `json:"db_connected"`
|
||||
Broker bool `json:"broker_connected"`
|
||||
Usage struct {
|
||||
Hosts int `json:"hosts"`
|
||||
Clients int `json:"clients"`
|
||||
Networks int `json:"networks"`
|
||||
Users int `json:"users"`
|
||||
} `json:"usage"`
|
||||
LicenseError string `json:"license_error"`
|
||||
}
|
||||
|
||||
licenseErr := ""
|
||||
if servercfg.ErrLicenseValidation != nil {
|
||||
licenseErr = servercfg.ErrLicenseValidation.Error()
|
||||
}
|
||||
|
||||
currentServerStatus := status{
|
||||
DB: database.IsConnected(),
|
||||
Broker: mq.IsConnected(),
|
||||
LicenseError: licenseErr,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
17
ee/ee_controllers/middleware.go
Normal file
17
ee/ee_controllers/middleware.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ee_controllers
|
||||
|
||||
import (
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
if servercfg.ErrLicenseValidation != nil && request.URL.Path != "/api/server/status" {
|
||||
logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "forbidden"))
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(writer, request)
|
||||
})
|
||||
}
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/ee/ee_controllers"
|
||||
eelogic "github.com/gravitl/netmaker/ee/logic"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// InitEE - Initialize EE Logic
|
||||
@@ -18,6 +18,10 @@ func InitEE() {
|
||||
setIsEnterprise()
|
||||
servercfg.Is_EE = true
|
||||
models.SetLogo(retrieveEELogo())
|
||||
controller.HttpMiddlewares = append(
|
||||
controller.HttpMiddlewares,
|
||||
ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
|
||||
)
|
||||
controller.HttpHandlers = append(
|
||||
controller.HttpHandlers,
|
||||
ee_controllers.MetricHandlers,
|
||||
@@ -27,8 +31,11 @@ func InitEE() {
|
||||
)
|
||||
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
|
||||
// == License Handling ==
|
||||
ValidateLicense()
|
||||
logger.Log(0, "proceeding with Paid Tier license")
|
||||
if err := ValidateLicense(); err != nil {
|
||||
slog.Error(err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("proceeding with Paid Tier license")
|
||||
logic.SetFreeTierForTelemetry(false)
|
||||
// == End License Handling ==
|
||||
AddLicenseHooks()
|
||||
@@ -48,7 +55,7 @@ func resetFailover() {
|
||||
for _, net := range nets {
|
||||
err = eelogic.ResetFailover(net.NetID)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
|
||||
slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"golang.org/x/exp/slog"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
@@ -44,29 +43,40 @@ func AddLicenseHooks() {
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateLicense - the initial license check for netmaker server
|
||||
// ValidateLicense - the initial and periodic license check for netmaker server
|
||||
// checks if a license is valid + limits are not exceeded
|
||||
// if license is free_tier and limits exceeds, then server should terminate
|
||||
// if license is not valid, server should terminate
|
||||
func ValidateLicense() error {
|
||||
// if license is free_tier and limits exceeds, then function should error
|
||||
// if license is not valid, function should error
|
||||
func ValidateLicense() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%w: %s", errValidation, err.Error())
|
||||
servercfg.ErrLicenseValidation = err
|
||||
}
|
||||
}()
|
||||
|
||||
licenseKeyValue := servercfg.GetLicenseKey()
|
||||
netmakerTenantID := servercfg.GetNetmakerTenantID()
|
||||
slog.Info("proceeding with Netmaker license validation...")
|
||||
if len(licenseKeyValue) == 0 {
|
||||
failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
|
||||
err = errors.New("empty license-key (LICENSE_KEY environment variable)")
|
||||
return err
|
||||
}
|
||||
if len(netmakerTenantID) == 0 {
|
||||
failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
|
||||
err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
|
||||
return err
|
||||
}
|
||||
|
||||
apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to get license public key: %w", err))
|
||||
err = fmt.Errorf("failed to get license public key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
tempPubKey, tempPrivKey, err := FetchApiServerKeys()
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
|
||||
err = fmt.Errorf("failed to fetch api server keys: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
licenseSecret := LicenseSecret{
|
||||
@@ -76,35 +86,42 @@ func ValidateLicense() error {
|
||||
|
||||
secretData, err := json.Marshal(&licenseSecret)
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
|
||||
err = fmt.Errorf("failed to marshal license secret: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
|
||||
err = fmt.Errorf("failed to encrypt license secret data: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to validate license key: %w", err))
|
||||
err = fmt.Errorf("failed to validate license key: %w", err)
|
||||
return err
|
||||
}
|
||||
if len(validationResponse) == 0 {
|
||||
failValidation(errors.New("empty validation response"))
|
||||
err = errors.New("empty validation response")
|
||||
return err
|
||||
}
|
||||
|
||||
var licenseResponse ValidatedLicense
|
||||
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
|
||||
failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
|
||||
err = fmt.Errorf("failed to unmarshal validation response: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
|
||||
if err != nil {
|
||||
failValidation(fmt.Errorf("failed to decrypt license: %w", err))
|
||||
err = fmt.Errorf("failed to decrypt license: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
license := LicenseKey{}
|
||||
if err = json.Unmarshal(respData, &license); err != nil {
|
||||
failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
|
||||
err = fmt.Errorf("failed to unmarshal license key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("License validation succeeded!")
|
||||
@@ -158,11 +175,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
|
||||
return pub, priv, nil
|
||||
}
|
||||
|
||||
func failValidation(err error) {
|
||||
slog.Error(errValidation.Error(), "error", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
|
||||
decodedPubKey := base64decode(licensePubKeyEncoded)
|
||||
return ncutils.ConvertBytesToKey(decodedPubKey)
|
||||
|
||||
@@ -3,10 +3,11 @@ package logic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"golang.org/x/exp/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
@@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Log(0, "## Stopping Hook Manager")
|
||||
slog.Error("## Stopping Hook Manager")
|
||||
return
|
||||
case newhook := <-HookManagerCh:
|
||||
wg.Add(1)
|
||||
@@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
hook()
|
||||
if err := hook(); err != nil {
|
||||
slog.Error(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +88,7 @@ var timeHooks = []interface{}{
|
||||
}
|
||||
|
||||
func loggerDump() error {
|
||||
// TODO use slog?
|
||||
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
|
||||
return nil
|
||||
}
|
||||
@@ -93,7 +97,7 @@ func loggerDump() error {
|
||||
func runHooks() {
|
||||
for _, hook := range timeHooks {
|
||||
if err := hook.(func() error)(); err != nil {
|
||||
logger.Log(1, "error occurred when running timer function:", err.Error())
|
||||
slog.Error("error occurred when running timer function", "error", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ const EmqxBrokerType = "emqx"
|
||||
var (
|
||||
Version = "dev"
|
||||
Is_EE = false
|
||||
ErrLicenseValidation error
|
||||
)
|
||||
|
||||
// SetHost - sets the host ip
|
||||
|
||||
Reference in New Issue
Block a user