add trial license logic

This commit is contained in:
abhishek9686
2024-01-19 14:51:51 +05:30
parent abe7f4cf52
commit 6749fb4516
6 changed files with 209 additions and 41 deletions

View File

@@ -124,29 +124,29 @@ func InitializeDatabase() error {
}
func createTables() {
createTable(NETWORKS_TABLE_NAME)
createTable(NODES_TABLE_NAME)
createTable(CERTS_TABLE_NAME)
createTable(DELETED_NODES_TABLE_NAME)
createTable(USERS_TABLE_NAME)
createTable(DNS_TABLE_NAME)
createTable(EXT_CLIENT_TABLE_NAME)
createTable(PEERS_TABLE_NAME)
createTable(SERVERCONF_TABLE_NAME)
createTable(SERVER_UUID_TABLE_NAME)
createTable(GENERATED_TABLE_NAME)
createTable(NODE_ACLS_TABLE_NAME)
createTable(SSO_STATE_CACHE)
createTable(METRICS_TABLE_NAME)
createTable(NETWORK_USER_TABLE_NAME)
createTable(USER_GROUPS_TABLE_NAME)
createTable(CACHE_TABLE_NAME)
createTable(HOSTS_TABLE_NAME)
createTable(ENROLLMENT_KEYS_TABLE_NAME)
createTable(HOST_ACTIONS_TABLE_NAME)
CreateTable(NETWORKS_TABLE_NAME)
CreateTable(NODES_TABLE_NAME)
CreateTable(CERTS_TABLE_NAME)
CreateTable(DELETED_NODES_TABLE_NAME)
CreateTable(USERS_TABLE_NAME)
CreateTable(DNS_TABLE_NAME)
CreateTable(EXT_CLIENT_TABLE_NAME)
CreateTable(PEERS_TABLE_NAME)
CreateTable(SERVERCONF_TABLE_NAME)
CreateTable(SERVER_UUID_TABLE_NAME)
CreateTable(GENERATED_TABLE_NAME)
CreateTable(NODE_ACLS_TABLE_NAME)
CreateTable(SSO_STATE_CACHE)
CreateTable(METRICS_TABLE_NAME)
CreateTable(NETWORK_USER_TABLE_NAME)
CreateTable(USER_GROUPS_TABLE_NAME)
CreateTable(CACHE_TABLE_NAME)
CreateTable(HOSTS_TABLE_NAME)
CreateTable(ENROLLMENT_KEYS_TABLE_NAME)
CreateTable(HOST_ACTIONS_TABLE_NAME)
}
func createTable(tableName string) error {
func CreateTable(tableName string) error {
return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
}
@@ -194,7 +194,7 @@ func DeleteAllRecords(tableName string) error {
if err != nil {
return err
}
err = createTable(tableName)
err = CreateTable(tableName)
if err != nil {
return err
}

View File

@@ -32,12 +32,12 @@ func sendTelemetry() error {
return nil
}
var telRecord, err = fetchTelemetryRecord()
var telRecord, err = FetchTelemetryRecord()
if err != nil {
return err
}
// get telemetry data
d, err := fetchTelemetryData()
d, err := FetchTelemetryData()
if err != nil {
return err
}
@@ -71,8 +71,8 @@ func sendTelemetry() error {
})
}
// fetchTelemetry - fetches telemetry data: count of various object types in DB
func fetchTelemetryData() (telemetryData, error) {
// FetchTelemetryData - fetches telemetry data: count of various object types in DB
func FetchTelemetryData() (telemetryData, error) {
var data telemetryData
data.IsPro = servercfg.IsPro
@@ -138,8 +138,8 @@ func getClientCount(nodes []models.Node) clientCount {
return count
}
// fetchTelemetryRecord - get the existing UUID and Timestamp from the DB
func fetchTelemetryRecord() (models.Telemetry, error) {
// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB
func FetchTelemetryRecord() (models.Telemetry, error) {
var rawData string
var telObj models.Telemetry
var err error

View File

@@ -3,11 +3,12 @@ package logic
import (
"context"
"fmt"
"github.com/gravitl/netmaker/logger"
"golang.org/x/exp/slog"
"sync"
"time"
"github.com/gravitl/netmaker/logger"
"golang.org/x/exp/slog"
"github.com/gravitl/netmaker/models"
)
@@ -24,7 +25,7 @@ var HookManagerCh = make(chan models.HookDetails, 3)
// TimerCheckpoint - Checks if 24 hours has passed since telemetry was last sent. If so, sends telemetry data to posthog
func TimerCheckpoint() error {
// get the telemetry record in the DB, which contains a timestamp
telRecord, err := fetchTelemetryRecord()
telRecord, err := FetchTelemetryRecord()
if err != nil {
return err
}

View File

@@ -2,7 +2,7 @@ package logic
// RetrievePrivateTrafficKey - retrieves private key of server
func RetrievePrivateTrafficKey() ([]byte, error) {
var telRecord, err = fetchTelemetryRecord()
var telRecord, err = FetchTelemetryRecord()
if err != nil {
return nil, err
}
@@ -12,7 +12,7 @@ func RetrievePrivateTrafficKey() ([]byte, error) {
// RetrievePublicTrafficKey - retrieves public key of server
func RetrievePublicTrafficKey() ([]byte, error) {
var telRecord, err = fetchTelemetryRecord()
var telRecord, err = FetchTelemetryRecord()
if err != nil {
return nil, err
}

View File

@@ -4,6 +4,8 @@
package pro
import (
"time"
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
@@ -17,6 +19,7 @@ import (
// InitPro - Initialize Pro Logic
func InitPro() {
servercfg.IsPro = true
proLogic.InitTrial()
models.SetLogo(retrieveProLogo())
controller.HttpMiddlewares = append(
controller.HttpMiddlewares,
@@ -31,18 +34,36 @@ func InitPro() {
)
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling ==
ClearLicenseCache()
if err := ValidateLicense(); err != nil {
slog.Error(err.Error())
return
enableLicenseHook := false
trialEndDate, err := getTrialEndDate()
if err != nil {
slog.Error("failed to get trial end date", "error", err)
enableLicenseHook = true
}
slog.Info("proceeding with Paid Tier license")
logic.SetFreeTierForTelemetry(false)
// == End License Handling ==
AddLicenseHooks()
// check if trial ended
if time.Now().After(trialEndDate) {
// trial ended already
enableLicenseHook = true
}
if enableLicenseHook {
slog.Info("starting license checker")
ClearLicenseCache()
if err := ValidateLicense(); err != nil {
slog.Error(err.Error())
return
}
slog.Info("proceeding with Paid Tier license")
logic.SetFreeTierForTelemetry(false)
// == End License Handling ==
AddLicenseHooks()
} else {
addTrialLicenseHook()
}
if servercfg.GetServerConfig().RacAutoDisable {
AddRacHooks()
}
})
logic.ResetFailOver = proLogic.ResetFailOver
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer

146
pro/trial.go Normal file
View File

@@ -0,0 +1,146 @@
//go:build ee
// +build ee
package pro
import (
"crypto/rand"
"encoding/json"
"errors"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"golang.org/x/crypto/nacl/box"
"golang.org/x/exp/slog"
)
type TrialInfo struct {
PrivKey []byte `json:"priv_key"`
PubKey []byte `json:"pub_key"`
Secret string `json:"secret"`
}
func addTrialLicenseHook() {
logic.HookManagerCh <- models.HookDetails{
Hook: TrialLicenseHook,
Interval: time.Hour,
}
}
type TrialDates struct {
TrialStartedAt time.Time `json:"trial_started_at"`
TrialEndsAt time.Time `json:"trial_ends_at"`
}
const trial_table_name = "trial"
const trial_data_key = "trialdata"
// store trial date
func InitTrial() error {
telData, err := logic.FetchTelemetryData()
if err != nil {
return err
}
if telData.Hosts > 0 || telData.Networks > 0 || telData.Users > 0 {
return nil
}
err = database.CreateTable(trial_table_name)
if err != nil {
slog.Error("failed to create table", "table name", trial_table_name, "err", err.Error())
return err
}
// setup encryption keys
trafficPubKey, trafficPrivKey, err := box.GenerateKey(rand.Reader) // generate traffic keys
if err != nil {
return err
}
tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
if err != nil {
return err
}
tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
if err != nil {
return err
}
trialDates := TrialDates{
TrialStartedAt: time.Now(),
TrialEndsAt: time.Now().Add(time.Hour * 24 * 30),
}
t := TrialInfo{
PrivKey: tPriv,
PubKey: tPub,
}
tel, err := logic.FetchTelemetryRecord()
if err != nil {
return err
}
trialDatesData, err := json.Marshal(trialDates)
if err != nil {
return err
}
trialDatesSecret, err := ncutils.BoxEncrypt(trialDatesData, (*[32]byte)(tel.TrafficKeyPub), (*[32]byte)(t.PrivKey))
if err != nil {
return err
}
t.Secret = string(trialDatesSecret)
trialData, err := json.Marshal(t)
if err != nil {
return err
}
err = database.Insert(trial_data_key, string(trialData), trial_table_name)
if err != nil {
return err
}
return nil
}
func TrialLicenseHook() error {
endDate, err := getTrialEndDate()
if err != nil {
logger.FatalLog0("failed to trial end date", err.Error())
}
if time.Now().After(endDate) {
logger.FatalLog0("***IMPORTANT: Your Trial Has Ended, to continue using pro version, please visit https://app.netmaker.io/ and create on-prem tenant to obtain a license***\nIf you wish to downgrade to community version, please run this command `/root/nm-quick.sh -d`")
}
return nil
}
// get trial date
func getTrialEndDate() (time.Time, error) {
record, err := database.FetchRecord(trial_table_name, trial_data_key)
if err != nil {
return time.Time{}, err
}
var trialInfo TrialInfo
err = json.Unmarshal([]byte(record), &trialInfo)
if err != nil {
return time.Time{}, err
}
tel, err := logic.FetchTelemetryRecord()
if err != nil {
return time.Time{}, err
}
// decrypt secret
secretDecrypt, err := ncutils.BoxDecrypt([]byte(trialInfo.Secret), (*[32]byte)(trialInfo.PubKey), (*[32]byte)(tel.TrafficKeyPriv))
if err != nil {
return time.Time{}, err
}
trialDates := TrialDates{}
err = json.Unmarshal(secretDecrypt, &trialDates)
if err != nil {
return time.Time{}, err
}
if trialDates.TrialEndsAt.IsZero() {
return time.Time{}, errors.New("invalid date")
}
return trialDates.TrialEndsAt, nil
}