added try to use func and edited tests

This commit is contained in:
0xdcarns
2023-02-15 15:52:58 -05:00
parent db4ea9faa4
commit 0e5e34ef0c
3 changed files with 76 additions and 27 deletions

View File

@@ -2,6 +2,7 @@ package logic
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"time" "time"
@@ -12,15 +13,15 @@ import (
// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages
var EnrollmentKeyErrors = struct { var EnrollmentKeyErrors = struct {
InvalidCreate string InvalidCreate error
NoKeyFound string NoKeyFound error
InvalidKey string InvalidKey error
NoUsesRemaining string NoUsesRemaining error
}{ }{
InvalidCreate: "invalid enrollment key created", InvalidCreate: fmt.Errorf("invalid enrollment key created"),
NoKeyFound: "no enrollmentkey found", NoKeyFound: fmt.Errorf("no enrollmentkey found"),
InvalidKey: "invalid key provided", InvalidKey: fmt.Errorf("invalid key provided"),
NoUsesRemaining: "no uses remaining", NoUsesRemaining: fmt.Errorf("no uses remaining"),
} }
// CreateEnrollmentKey - creates a new enrollment key in db // CreateEnrollmentKey - creates a new enrollment key in db
@@ -50,7 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
k.Tags = tags k.Tags = tags
} }
if ok := k.Validate(); !ok { if ok := k.Validate(); !ok {
return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate) return nil, EnrollmentKeyErrors.InvalidCreate
} }
if err = upsertEnrollmentKey(k); err != nil { if err = upsertEnrollmentKey(k); err != nil {
return nil, err return nil, err
@@ -81,7 +82,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
if key, ok := currentKeys[value]; ok { if key, ok := currentKeys[value]; ok {
return key, nil return key, nil
} }
return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound) return nil, EnrollmentKeyErrors.NoKeyFound
} }
// DeleteEnrollmentKey - delete's a given enrollment key by value // DeleteEnrollmentKey - delete's a given enrollment key by value
@@ -93,14 +94,31 @@ func DeleteEnrollmentKey(value string) error {
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
} }
// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining // TryToUseEnrollmentKey - checks first if key can be decremented
func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { // returns true if it is decremented or isvalid
func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
key, err := decrementEnrollmentKey(k.Value)
if err != nil {
if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) {
return k.IsValid()
}
} else {
k.UsesRemaining = key.UsesRemaining
return true
}
return false
}
// == private ==
// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
k, err := GetEnrollmentKey(value) k, err := GetEnrollmentKey(value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if k.UsesRemaining == 0 { if k.UsesRemaining == 0 {
return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining) return nil, EnrollmentKeyErrors.NoUsesRemaining
} }
k.UsesRemaining = k.UsesRemaining - 1 k.UsesRemaining = k.UsesRemaining - 1
if err = upsertEnrollmentKey(k); err != nil { if err = upsertEnrollmentKey(k); err != nil {
@@ -110,11 +128,9 @@ func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
return k, nil return k, nil
} }
// == private ==
func upsertEnrollmentKey(k *models.EnrollmentKey) error { func upsertEnrollmentKey(k *models.EnrollmentKey) error {
if k == nil { if k == nil {
return fmt.Errorf(EnrollmentKeyErrors.InvalidKey) return EnrollmentKeyErrors.InvalidKey
} }
data, err := json.Marshal(k) data, err := json.Marshal(k)
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
assert.Nil(t, newKey) assert.Nil(t, newKey)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate)
}) })
t.Run("Can_Create_Key_Uses", func(t *testing.T) { t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
@@ -59,12 +59,12 @@ func TestDelete_EnrollmentKey(t *testing.T) {
oldKey, err := GetEnrollmentKey(newKey.Value) oldKey, err := GetEnrollmentKey(newKey.Value)
assert.Nil(t, oldKey) assert.Nil(t, oldKey)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
}) })
t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
err := DeleteEnrollmentKey("notakey") err := DeleteEnrollmentKey("notakey")
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
}) })
removeAllEnrollments() removeAllEnrollments()
} }
@@ -72,32 +72,57 @@ func TestDelete_EnrollmentKey(t *testing.T) {
func TestDecrement_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase() database.InitializeDatabase()
defer database.CloseDB() defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true) newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
t.Run("Check_initial_uses", func(t *testing.T) { t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid()) assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1) assert.Equal(t, newKey.UsesRemaining, 1)
}) })
t.Run("Check can decrement", func(t *testing.T) { t.Run("Check can decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 1) assert.Equal(t, newKey.UsesRemaining, 1)
k, err := DecrementEnrollmentKey(newKey.Value) k, err := decrementEnrollmentKey(newKey.Value)
assert.Nil(t, err) assert.Nil(t, err)
newKey = k newKey = k
}) })
t.Run("Check can not decrement", func(t *testing.T) { t.Run("Check can not decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 0) assert.Equal(t, newKey.UsesRemaining, 0)
_, err := DecrementEnrollmentKey(newKey.Value) _, err := decrementEnrollmentKey(newKey.Value)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining) assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining)
}) })
removeAllEnrollments() removeAllEnrollments()
} }
// func TestValidity_EnrollmentKey(t *testing.T) { func TestUsability_EnrollmentKey(t *testing.T) {
// database.InitializeDatabase() database.InitializeDatabase()
// defer database.CloseDB() defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
assert.True(t, ok)
assert.Equal(t, 0, key1.UsesRemaining)
})
// } t.Run("Check if valid time key can be used", func(t *testing.T) {
assert.True(t, !key2.Expiration.IsZero())
ok := TryToUseEnrollmentKey(key2)
assert.True(t, ok)
})
t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
assert.True(t, key3.Unlimited)
ok := TryToUseEnrollmentKey(key3)
assert.True(t, ok)
})
t.Run("check invalid key can not be used", func(t *testing.T) {
ok := TryToUseEnrollmentKey(key1)
assert.False(t, ok)
})
}
func removeAllEnrollments() { func removeAllEnrollments() {
database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)

View File

@@ -1,6 +1,7 @@
package logic package logic
import ( import (
"context"
"net" "net"
"testing" "testing"
@@ -13,6 +14,13 @@ import (
func TestCheckPorts(t *testing.T) { func TestCheckPorts(t *testing.T) {
database.InitializeDatabase() database.InitializeDatabase()
defer database.CloseDB() defer database.CloseDB()
peerUpdate := make(chan *models.Node)
go ManageZombies(context.Background(), peerUpdate)
go func() {
for _ = range peerUpdate {
//do nothing
}
}()
h := models.Host{ h := models.Host{
ID: uuid.New(), ID: uuid.New(),