mirror of
https://github.com/gravitl/netmaker.git
synced 2025-10-22 00:19:39 +08:00
[NET-404] Run in limited mode when ee checks fail (#2474)
* Add limited http handlers functionality to rest handler * Export ee.errValidation (ee.ErrValidation) * Export a fatal error handled by the hook manager * Export a new status variable for unlicensed server * Mark server as unlicensed when ee checks fail * Handle license validation failures with a (re)boot in a limited state * Revert "Export a fatal error handled by the hook manager" This reverts commit 069c21974a8d36e889c73ad78023448d787d62a5. * Revert "Export ee.errValidation (ee.ErrValidation)" This reverts commit 59dbab8c79773ca5d879f28cbaf53f3dd4297b9b. * Revert "Add limited http handlers functionality to rest handler" This reverts commit e2f1f28facaca54713db76a588839cd2733cf673. * Revert "Handle license validation failures with a (re)boot in a limited state" This reverts commit 58cfbbaf522a1345aac1fa67964ebff0a6d60cd8. * Revert "Mark server as unlicensed when ee checks fail" This reverts commit 77c6dbdd3c9cfa6e7d6becedef6251e8617ae367. * Handle license validation failures with a middleware * Forbid responses if unlicensed ee and not in status api * Remove unused func
This commit is contained in:

committed by
GitHub

parent
a021e2659e
commit
922e7dbf2c
@@ -14,6 +14,9 @@ import (
|
|||||||
"github.com/gravitl/netmaker/servercfg"
|
"github.com/gravitl/netmaker/servercfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HttpMiddlewares - middleware functions for REST interactions
|
||||||
|
var HttpMiddlewares []mux.MiddlewareFunc
|
||||||
|
|
||||||
// HttpHandlers - handler functions for REST interactions
|
// HttpHandlers - handler functions for REST interactions
|
||||||
var HttpHandlers = []interface{}{
|
var HttpHandlers = []interface{}{
|
||||||
nodeHandlers,
|
nodeHandlers,
|
||||||
@@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
|
|||||||
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
|
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
|
||||||
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
|
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
|
||||||
|
|
||||||
|
for _, middleware := range HttpMiddlewares {
|
||||||
|
r.Use(middleware)
|
||||||
|
}
|
||||||
|
|
||||||
for _, handler := range HttpHandlers {
|
for _, handler := range HttpHandlers {
|
||||||
handler.(func(*mux.Router))(r)
|
handler.(func(*mux.Router))(r)
|
||||||
}
|
}
|
||||||
|
@@ -68,22 +68,21 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Responses:
|
// Responses:
|
||||||
// 200: serverConfigResponse
|
// 200: serverConfigResponse
|
||||||
func getStatus(w http.ResponseWriter, r *http.Request) {
|
func getStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
// TODO
|
|
||||||
// - check health of broker
|
|
||||||
type status struct {
|
type status struct {
|
||||||
DB bool `json:"db_connected"`
|
DB bool `json:"db_connected"`
|
||||||
Broker bool `json:"broker_connected"`
|
Broker bool `json:"broker_connected"`
|
||||||
Usage struct {
|
LicenseError string `json:"license_error"`
|
||||||
Hosts int `json:"hosts"`
|
}
|
||||||
Clients int `json:"clients"`
|
|
||||||
Networks int `json:"networks"`
|
licenseErr := ""
|
||||||
Users int `json:"users"`
|
if servercfg.ErrLicenseValidation != nil {
|
||||||
} `json:"usage"`
|
licenseErr = servercfg.ErrLicenseValidation.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
currentServerStatus := status{
|
currentServerStatus := status{
|
||||||
DB: database.IsConnected(),
|
DB: database.IsConnected(),
|
||||||
Broker: mq.IsConnected(),
|
Broker: mq.IsConnected(),
|
||||||
|
LicenseError: licenseErr,
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
17
ee/ee_controllers/middleware.go
Normal file
17
ee/ee_controllers/middleware.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package ee_controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gravitl/netmaker/logic"
|
||||||
|
"github.com/gravitl/netmaker/servercfg"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
if servercfg.ErrLicenseValidation != nil && request.URL.Path != "/api/server/status" {
|
||||||
|
logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "forbidden"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler.ServeHTTP(writer, request)
|
||||||
|
})
|
||||||
|
}
|
@@ -7,10 +7,10 @@ import (
|
|||||||
controller "github.com/gravitl/netmaker/controllers"
|
controller "github.com/gravitl/netmaker/controllers"
|
||||||
"github.com/gravitl/netmaker/ee/ee_controllers"
|
"github.com/gravitl/netmaker/ee/ee_controllers"
|
||||||
eelogic "github.com/gravitl/netmaker/ee/logic"
|
eelogic "github.com/gravitl/netmaker/ee/logic"
|
||||||
"github.com/gravitl/netmaker/logger"
|
|
||||||
"github.com/gravitl/netmaker/logic"
|
"github.com/gravitl/netmaker/logic"
|
||||||
"github.com/gravitl/netmaker/models"
|
"github.com/gravitl/netmaker/models"
|
||||||
"github.com/gravitl/netmaker/servercfg"
|
"github.com/gravitl/netmaker/servercfg"
|
||||||
|
"golang.org/x/exp/slog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitEE - Initialize EE Logic
|
// InitEE - Initialize EE Logic
|
||||||
@@ -18,6 +18,10 @@ func InitEE() {
|
|||||||
setIsEnterprise()
|
setIsEnterprise()
|
||||||
servercfg.Is_EE = true
|
servercfg.Is_EE = true
|
||||||
models.SetLogo(retrieveEELogo())
|
models.SetLogo(retrieveEELogo())
|
||||||
|
controller.HttpMiddlewares = append(
|
||||||
|
controller.HttpMiddlewares,
|
||||||
|
ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
|
||||||
|
)
|
||||||
controller.HttpHandlers = append(
|
controller.HttpHandlers = append(
|
||||||
controller.HttpHandlers,
|
controller.HttpHandlers,
|
||||||
ee_controllers.MetricHandlers,
|
ee_controllers.MetricHandlers,
|
||||||
@@ -27,8 +31,11 @@ func InitEE() {
|
|||||||
)
|
)
|
||||||
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
|
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
|
||||||
// == License Handling ==
|
// == License Handling ==
|
||||||
ValidateLicense()
|
if err := ValidateLicense(); err != nil {
|
||||||
logger.Log(0, "proceeding with Paid Tier license")
|
slog.Error(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Info("proceeding with Paid Tier license")
|
||||||
logic.SetFreeTierForTelemetry(false)
|
logic.SetFreeTierForTelemetry(false)
|
||||||
// == End License Handling ==
|
// == End License Handling ==
|
||||||
AddLicenseHooks()
|
AddLicenseHooks()
|
||||||
@@ -48,7 +55,7 @@ func resetFailover() {
|
|||||||
for _, net := range nets {
|
for _, net := range nets {
|
||||||
err = eelogic.ResetFailover(net.NetID)
|
err = eelogic.ResetFailover(net.NetID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
|
slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -12,7 +12,6 @@ import (
|
|||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gravitl/netmaker/database"
|
"github.com/gravitl/netmaker/database"
|
||||||
@@ -44,29 +43,40 @@ func AddLicenseHooks() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateLicense - the initial license check for netmaker server
|
// ValidateLicense - the initial and periodic license check for netmaker server
|
||||||
// checks if a license is valid + limits are not exceeded
|
// checks if a license is valid + limits are not exceeded
|
||||||
// if license is free_tier and limits exceeds, then server should terminate
|
// if license is free_tier and limits exceeds, then function should error
|
||||||
// if license is not valid, server should terminate
|
// if license is not valid, function should error
|
||||||
func ValidateLicense() error {
|
func ValidateLicense() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("%w: %s", errValidation, err.Error())
|
||||||
|
servercfg.ErrLicenseValidation = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
licenseKeyValue := servercfg.GetLicenseKey()
|
licenseKeyValue := servercfg.GetLicenseKey()
|
||||||
netmakerTenantID := servercfg.GetNetmakerTenantID()
|
netmakerTenantID := servercfg.GetNetmakerTenantID()
|
||||||
slog.Info("proceeding with Netmaker license validation...")
|
slog.Info("proceeding with Netmaker license validation...")
|
||||||
if len(licenseKeyValue) == 0 {
|
if len(licenseKeyValue) == 0 {
|
||||||
failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
|
err = errors.New("empty license-key (LICENSE_KEY environment variable)")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if len(netmakerTenantID) == 0 {
|
if len(netmakerTenantID) == 0 {
|
||||||
failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
|
err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
|
apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to get license public key: %w", err))
|
err = fmt.Errorf("failed to get license public key: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tempPubKey, tempPrivKey, err := FetchApiServerKeys()
|
tempPubKey, tempPrivKey, err := FetchApiServerKeys()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
|
err = fmt.Errorf("failed to fetch api server keys: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
licenseSecret := LicenseSecret{
|
licenseSecret := LicenseSecret{
|
||||||
@@ -76,35 +86,42 @@ func ValidateLicense() error {
|
|||||||
|
|
||||||
secretData, err := json.Marshal(&licenseSecret)
|
secretData, err := json.Marshal(&licenseSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
|
err = fmt.Errorf("failed to marshal license secret: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
|
encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
|
err = fmt.Errorf("failed to encrypt license secret data: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
|
validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to validate license key: %w", err))
|
err = fmt.Errorf("failed to validate license key: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if len(validationResponse) == 0 {
|
if len(validationResponse) == 0 {
|
||||||
failValidation(errors.New("empty validation response"))
|
err = errors.New("empty validation response")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var licenseResponse ValidatedLicense
|
var licenseResponse ValidatedLicense
|
||||||
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
|
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
|
||||||
failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
|
err = fmt.Errorf("failed to unmarshal validation response: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
|
respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failValidation(fmt.Errorf("failed to decrypt license: %w", err))
|
err = fmt.Errorf("failed to decrypt license: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
license := LicenseKey{}
|
license := LicenseKey{}
|
||||||
if err = json.Unmarshal(respData, &license); err != nil {
|
if err = json.Unmarshal(respData, &license); err != nil {
|
||||||
failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
|
err = fmt.Errorf("failed to unmarshal license key: %w", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("License validation succeeded!")
|
slog.Info("License validation succeeded!")
|
||||||
@@ -158,11 +175,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
|
|||||||
return pub, priv, nil
|
return pub, priv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func failValidation(err error) {
|
|
||||||
slog.Error(errValidation.Error(), "error", err)
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
|
func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
|
||||||
decodedPubKey := base64decode(licensePubKeyEncoded)
|
decodedPubKey := base64decode(licensePubKeyEncoded)
|
||||||
return ncutils.ConvertBytesToKey(decodedPubKey)
|
return ncutils.ConvertBytesToKey(decodedPubKey)
|
||||||
|
@@ -3,10 +3,11 @@ package logic
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gravitl/netmaker/logger"
|
||||||
|
"golang.org/x/exp/slog"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gravitl/netmaker/logger"
|
|
||||||
"github.com/gravitl/netmaker/models"
|
"github.com/gravitl/netmaker/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Log(0, "## Stopping Hook Manager")
|
slog.Error("## Stopping Hook Manager")
|
||||||
return
|
return
|
||||||
case newhook := <-HookManagerCh:
|
case newhook := <-HookManagerCh:
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
hook()
|
if err := hook(); err != nil {
|
||||||
|
slog.Error(err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,6 +88,7 @@ var timeHooks = []interface{}{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loggerDump() error {
|
func loggerDump() error {
|
||||||
|
// TODO use slog?
|
||||||
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
|
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -93,7 +97,7 @@ func loggerDump() error {
|
|||||||
func runHooks() {
|
func runHooks() {
|
||||||
for _, hook := range timeHooks {
|
for _, hook := range timeHooks {
|
||||||
if err := hook.(func() error)(); err != nil {
|
if err := hook.(func() error)(); err != nil {
|
||||||
logger.Log(1, "error occurred when running timer function:", err.Error())
|
slog.Error("error occurred when running timer function", "error", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -18,8 +18,9 @@ import (
|
|||||||
const EmqxBrokerType = "emqx"
|
const EmqxBrokerType = "emqx"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
Version = "dev"
|
Version = "dev"
|
||||||
Is_EE = false
|
Is_EE = false
|
||||||
|
ErrLicenseValidation error
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetHost - sets the host ip
|
// SetHost - sets the host ip
|
||||||
|
Reference in New Issue
Block a user