mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-26 21:01:32 +08:00
add trial license logic
This commit is contained in:
@@ -124,29 +124,29 @@ func InitializeDatabase() error {
|
||||
}
|
||||
|
||||
func createTables() {
|
||||
createTable(NETWORKS_TABLE_NAME)
|
||||
createTable(NODES_TABLE_NAME)
|
||||
createTable(CERTS_TABLE_NAME)
|
||||
createTable(DELETED_NODES_TABLE_NAME)
|
||||
createTable(USERS_TABLE_NAME)
|
||||
createTable(DNS_TABLE_NAME)
|
||||
createTable(EXT_CLIENT_TABLE_NAME)
|
||||
createTable(PEERS_TABLE_NAME)
|
||||
createTable(SERVERCONF_TABLE_NAME)
|
||||
createTable(SERVER_UUID_TABLE_NAME)
|
||||
createTable(GENERATED_TABLE_NAME)
|
||||
createTable(NODE_ACLS_TABLE_NAME)
|
||||
createTable(SSO_STATE_CACHE)
|
||||
createTable(METRICS_TABLE_NAME)
|
||||
createTable(NETWORK_USER_TABLE_NAME)
|
||||
createTable(USER_GROUPS_TABLE_NAME)
|
||||
createTable(CACHE_TABLE_NAME)
|
||||
createTable(HOSTS_TABLE_NAME)
|
||||
createTable(ENROLLMENT_KEYS_TABLE_NAME)
|
||||
createTable(HOST_ACTIONS_TABLE_NAME)
|
||||
CreateTable(NETWORKS_TABLE_NAME)
|
||||
CreateTable(NODES_TABLE_NAME)
|
||||
CreateTable(CERTS_TABLE_NAME)
|
||||
CreateTable(DELETED_NODES_TABLE_NAME)
|
||||
CreateTable(USERS_TABLE_NAME)
|
||||
CreateTable(DNS_TABLE_NAME)
|
||||
CreateTable(EXT_CLIENT_TABLE_NAME)
|
||||
CreateTable(PEERS_TABLE_NAME)
|
||||
CreateTable(SERVERCONF_TABLE_NAME)
|
||||
CreateTable(SERVER_UUID_TABLE_NAME)
|
||||
CreateTable(GENERATED_TABLE_NAME)
|
||||
CreateTable(NODE_ACLS_TABLE_NAME)
|
||||
CreateTable(SSO_STATE_CACHE)
|
||||
CreateTable(METRICS_TABLE_NAME)
|
||||
CreateTable(NETWORK_USER_TABLE_NAME)
|
||||
CreateTable(USER_GROUPS_TABLE_NAME)
|
||||
CreateTable(CACHE_TABLE_NAME)
|
||||
CreateTable(HOSTS_TABLE_NAME)
|
||||
CreateTable(ENROLLMENT_KEYS_TABLE_NAME)
|
||||
CreateTable(HOST_ACTIONS_TABLE_NAME)
|
||||
}
|
||||
|
||||
func createTable(tableName string) error {
|
||||
func CreateTable(tableName string) error {
|
||||
return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func DeleteAllRecords(tableName string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = createTable(tableName)
|
||||
err = CreateTable(tableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -32,12 +32,12 @@ func sendTelemetry() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var telRecord, err = fetchTelemetryRecord()
|
||||
var telRecord, err = FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// get telemetry data
|
||||
d, err := fetchTelemetryData()
|
||||
d, err := FetchTelemetryData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -71,8 +71,8 @@ func sendTelemetry() error {
|
||||
})
|
||||
}
|
||||
|
||||
// fetchTelemetry - fetches telemetry data: count of various object types in DB
|
||||
func fetchTelemetryData() (telemetryData, error) {
|
||||
// FetchTelemetryData - fetches telemetry data: count of various object types in DB
|
||||
func FetchTelemetryData() (telemetryData, error) {
|
||||
var data telemetryData
|
||||
|
||||
data.IsPro = servercfg.IsPro
|
||||
@@ -138,8 +138,8 @@ func getClientCount(nodes []models.Node) clientCount {
|
||||
return count
|
||||
}
|
||||
|
||||
// fetchTelemetryRecord - get the existing UUID and Timestamp from the DB
|
||||
func fetchTelemetryRecord() (models.Telemetry, error) {
|
||||
// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB
|
||||
func FetchTelemetryRecord() (models.Telemetry, error) {
|
||||
var rawData string
|
||||
var telObj models.Telemetry
|
||||
var err error
|
||||
|
@@ -3,11 +3,12 @@ package logic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"golang.org/x/exp/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ var HookManagerCh = make(chan models.HookDetails, 3)
|
||||
// TimerCheckpoint - Checks if 24 hours has passed since telemetry was last sent. If so, sends telemetry data to posthog
|
||||
func TimerCheckpoint() error {
|
||||
// get the telemetry record in the DB, which contains a timestamp
|
||||
telRecord, err := fetchTelemetryRecord()
|
||||
telRecord, err := FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -2,7 +2,7 @@ package logic
|
||||
|
||||
// RetrievePrivateTrafficKey - retrieves private key of server
|
||||
func RetrievePrivateTrafficKey() ([]byte, error) {
|
||||
var telRecord, err = fetchTelemetryRecord()
|
||||
var telRecord, err = FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -12,7 +12,7 @@ func RetrievePrivateTrafficKey() ([]byte, error) {
|
||||
|
||||
// RetrievePublicTrafficKey - retrieves public key of server
|
||||
func RetrievePublicTrafficKey() ([]byte, error) {
|
||||
var telRecord, err = fetchTelemetryRecord()
|
||||
var telRecord, err = FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -4,6 +4,8 @@
|
||||
package pro
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
// InitPro - Initialize Pro Logic
|
||||
func InitPro() {
|
||||
servercfg.IsPro = true
|
||||
proLogic.InitTrial()
|
||||
models.SetLogo(retrieveProLogo())
|
||||
controller.HttpMiddlewares = append(
|
||||
controller.HttpMiddlewares,
|
||||
@@ -31,18 +34,36 @@ func InitPro() {
|
||||
)
|
||||
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
|
||||
// == License Handling ==
|
||||
ClearLicenseCache()
|
||||
if err := ValidateLicense(); err != nil {
|
||||
slog.Error(err.Error())
|
||||
return
|
||||
enableLicenseHook := false
|
||||
trialEndDate, err := getTrialEndDate()
|
||||
if err != nil {
|
||||
slog.Error("failed to get trial end date", "error", err)
|
||||
enableLicenseHook = true
|
||||
}
|
||||
slog.Info("proceeding with Paid Tier license")
|
||||
logic.SetFreeTierForTelemetry(false)
|
||||
// == End License Handling ==
|
||||
AddLicenseHooks()
|
||||
// check if trial ended
|
||||
if time.Now().After(trialEndDate) {
|
||||
// trial ended already
|
||||
enableLicenseHook = true
|
||||
}
|
||||
if enableLicenseHook {
|
||||
slog.Info("starting license checker")
|
||||
ClearLicenseCache()
|
||||
if err := ValidateLicense(); err != nil {
|
||||
slog.Error(err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("proceeding with Paid Tier license")
|
||||
logic.SetFreeTierForTelemetry(false)
|
||||
// == End License Handling ==
|
||||
AddLicenseHooks()
|
||||
} else {
|
||||
addTrialLicenseHook()
|
||||
}
|
||||
|
||||
if servercfg.GetServerConfig().RacAutoDisable {
|
||||
AddRacHooks()
|
||||
}
|
||||
|
||||
})
|
||||
logic.ResetFailOver = proLogic.ResetFailOver
|
||||
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
|
||||
|
146
pro/trial.go
Normal file
146
pro/trial.go
Normal file
@@ -0,0 +1,146 @@
|
||||
//go:build ee
|
||||
// +build ee
|
||||
|
||||
package pro
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type TrialInfo struct {
|
||||
PrivKey []byte `json:"priv_key"`
|
||||
PubKey []byte `json:"pub_key"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
|
||||
func addTrialLicenseHook() {
|
||||
logic.HookManagerCh <- models.HookDetails{
|
||||
Hook: TrialLicenseHook,
|
||||
Interval: time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
type TrialDates struct {
|
||||
TrialStartedAt time.Time `json:"trial_started_at"`
|
||||
TrialEndsAt time.Time `json:"trial_ends_at"`
|
||||
}
|
||||
|
||||
const trial_table_name = "trial"
|
||||
|
||||
const trial_data_key = "trialdata"
|
||||
|
||||
// store trial date
|
||||
func InitTrial() error {
|
||||
telData, err := logic.FetchTelemetryData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if telData.Hosts > 0 || telData.Networks > 0 || telData.Users > 0 {
|
||||
return nil
|
||||
}
|
||||
err = database.CreateTable(trial_table_name)
|
||||
if err != nil {
|
||||
slog.Error("failed to create table", "table name", trial_table_name, "err", err.Error())
|
||||
return err
|
||||
}
|
||||
// setup encryption keys
|
||||
trafficPubKey, trafficPrivKey, err := box.GenerateKey(rand.Reader) // generate traffic keys
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
trialDates := TrialDates{
|
||||
TrialStartedAt: time.Now(),
|
||||
TrialEndsAt: time.Now().Add(time.Hour * 24 * 30),
|
||||
}
|
||||
t := TrialInfo{
|
||||
PrivKey: tPriv,
|
||||
PubKey: tPub,
|
||||
}
|
||||
tel, err := logic.FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trialDatesData, err := json.Marshal(trialDates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
trialDatesSecret, err := ncutils.BoxEncrypt(trialDatesData, (*[32]byte)(tel.TrafficKeyPub), (*[32]byte)(t.PrivKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Secret = string(trialDatesSecret)
|
||||
trialData, err := json.Marshal(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.Insert(trial_data_key, string(trialData), trial_table_name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TrialLicenseHook() error {
|
||||
endDate, err := getTrialEndDate()
|
||||
if err != nil {
|
||||
logger.FatalLog0("failed to trial end date", err.Error())
|
||||
}
|
||||
if time.Now().After(endDate) {
|
||||
logger.FatalLog0("***IMPORTANT: Your Trial Has Ended, to continue using pro version, please visit https://app.netmaker.io/ and create on-prem tenant to obtain a license***\nIf you wish to downgrade to community version, please run this command `/root/nm-quick.sh -d`")
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get trial date
|
||||
func getTrialEndDate() (time.Time, error) {
|
||||
record, err := database.FetchRecord(trial_table_name, trial_data_key)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
var trialInfo TrialInfo
|
||||
err = json.Unmarshal([]byte(record), &trialInfo)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
tel, err := logic.FetchTelemetryRecord()
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
// decrypt secret
|
||||
secretDecrypt, err := ncutils.BoxDecrypt([]byte(trialInfo.Secret), (*[32]byte)(trialInfo.PubKey), (*[32]byte)(tel.TrafficKeyPriv))
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
trialDates := TrialDates{}
|
||||
err = json.Unmarshal(secretDecrypt, &trialDates)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if trialDates.TrialEndsAt.IsZero() {
|
||||
return time.Time{}, errors.New("invalid date")
|
||||
}
|
||||
return trialDates.TrialEndsAt, nil
|
||||
|
||||
}
|
Reference in New Issue
Block a user