Files
SugarDB/internal/modules/acl/acl.go
Kelvin Clement Mwinuka 0108444d69 Replaced fmt.Println statements with log.Println.
Return "empty command" error from handleCommand method if an empty command is passed to the server.
Wait until connection is no longer nil in acl package tests.
2024-05-27 11:45:48 +08:00

482 lines
13 KiB
Go

// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package acl
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/constants"
"github.com/gobwas/glob"
"gopkg.in/yaml.v3"
"log"
"net"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"time"
)
// Connection is the per-client state the ACL module keeps for a registered
// network connection.
type Connection struct {
Authenticated bool // Whether the connection has been authenticated
User *User // The user the connection is associated with
}
// ACL holds the access-control state: the known user profiles, the
// registered connections, the server configuration, and a cache of compiled
// glob patterns.
type ACL struct {
Users []*User // List of ACL user profiles
// Methods in this file also hold UsersMutex while accessing Connections.
UsersMutex sync.RWMutex // RWMutex for concurrency control when accessing ACL profile list
Connections map[*net.Conn]Connection // Connections to the echovault that are currently registered with the ACL module
Config config.Config // EchoVault configuration that contains the relevant ACL config options
// Compiled globs keyed by their pattern string; populated by CompileGlobs.
GlobPatterns map[string]glob.Glob
}
// NewACL builds the ACL module from the server configuration.
//
// Steps:
//  1. Create a fallback "default" user; when config.RequirePass is set it is
//     given config.Password as its only password.
//  2. If config.AclConfig names a file, decode the user list from it (JSON for
//     ".json", YAML for ".yaml"/".yml"; other extensions are ignored).
//  3. Prepend the fallback user unless the file already defined "default".
//  4. Normalise every user and compile the glob patterns they reference.
//
// The function panics if the ACL file cannot be opened and terminates the
// process via log.Fatal if it cannot be decoded.
func NewACL(config config.Config) *ACL {
	var users []*User

	// 1. Fallback default user.
	defaultUser := CreateUser("default")
	if config.RequirePass {
		defaultUser.NoPassword = false
		defaultUser.Passwords = []Password{{
			PasswordType:  GetPasswordType(config.Password),
			PasswordValue: config.Password,
		}}
	}

	// 2. Users declared in the ACL config file, if one is configured.
	if config.AclConfig != "" {
		f, err := os.Open(config.AclConfig)
		if err != nil {
			panic(err)
		}
		// The file stays open until NewACL returns.
		defer func() {
			if closeErr := f.Close(); closeErr != nil {
				log.Println("acl config file close error: ", closeErr)
			}
		}()

		switch path.Ext(f.Name()) {
		case ".json":
			if decodeErr := json.NewDecoder(f).Decode(&users); decodeErr != nil {
				log.Fatal("could not load JSON ACL config: ", decodeErr)
			}
		case ".yaml", ".yml":
			if decodeErr := yaml.NewDecoder(f).Decode(&users); decodeErr != nil {
				log.Fatal("could not load YAML ACL config: ", decodeErr)
			}
		}
	}

	// 3. Make sure a "default" user exists, placing the fallback first.
	hasDefault := slices.ContainsFunc(users, func(u *User) bool {
		return u.Username == "default"
	})
	if !hasDefault {
		users = append([]*User{defaultUser}, users...)
	}

	// 4. Normalise every profile before it is used.
	for i := range users {
		users[i].Normalise()
	}

	acl := &ACL{
		Users:        users,
		Connections:  make(map[*net.Conn]Connection),
		Config:       config,
		GlobPatterns: make(map[string]glob.Glob),
	}
	acl.CompileGlobs()
	return acl
}
// RegisterConnection attaches a newly established connection to the "default"
// user. The connection starts out authenticated only when the default user
// has NoPassword set.
//
// It takes the write lock because it mutates acl.Connections.
func (acl *ACL) RegisterConnection(conn *net.Conn) {
	acl.LockUsers()
	defer acl.UnlockUsers()

	isDefault := func(u *User) bool { return u.Username == "default" }
	fallback := acl.Users[slices.IndexFunc(acl.Users, isDefault)]

	acl.Connections[conn] = Connection{
		User:          fallback,
		Authenticated: fallback.NoPassword,
	}
}
// SetUser creates or updates the user named by cmd[0], applying the rules in
// cmd through User.UpdateUser.
//
// An existing user is updated in place. Otherwise a fresh user is created,
// updated, normalised and appended to the ACL. On success the glob cache is
// refreshed. Any error from UpdateUser is returned unchanged; when it occurs
// for a new user, that user is not added.
func (acl *ACL) SetUser(cmd []string) error {
	acl.LockUsers()
	defer acl.UnlockUsers()

	username := cmd[0]

	// Update in place when the username is already known.
	existing := slices.IndexFunc(acl.Users, func(u *User) bool {
		return u.Username == username
	})
	if existing >= 0 {
		if err := acl.Users[existing].UpdateUser(cmd); err != nil {
			return err
		}
		acl.CompileGlobs()
		return nil
	}

	// Otherwise build a new profile and register it.
	created := CreateUser(username)
	if err := created.UpdateUser(cmd); err != nil {
		return err
	}
	created.Normalise()

	acl.Users = append(acl.Users, created)
	acl.CompileGlobs()
	return nil
}
// AddUsers appends the given users to the ACL and compiles any glob patterns
// they reference.
//
// The glob cache must be refreshed here: AuthorizeConnection looks patterns up
// in acl.GlobPatterns, and a pattern that was never compiled would yield a nil
// glob.
func (acl *ACL) AddUsers(users []*User) {
	acl.LockUsers()
	defer acl.UnlockUsers()

	acl.Users = append(acl.Users, users...)
	acl.CompileGlobs()
}
// DeleteUser removes the named users from the ACL.
//
// For every username that exists (the "default" user is never deleted and
// unknown usernames are skipped), each connection currently attached to that
// user has its read deadline moved into the past, and the user is then removed
// from acl.Users. The context argument is unused. The function always returns
// nil.
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
	acl.LockUsers()
	defer acl.UnlockUsers()

	for _, username := range usernames {
		if username == "default" {
			// The default user can never be deleted.
			continue
		}

		// Look the user up afresh for every username so that a match from a
		// previous iteration cannot leak into this one.
		idx := slices.IndexFunc(acl.Users, func(u *User) bool {
			return u.Username == username
		})
		if idx < 0 {
			// Skip usernames that are not in the ACL.
			continue
		}

		// Expire pending reads on every connection attached to this user.
		for connRef, connection := range acl.Connections {
			if connection.User == nil || connection.User.Username != username {
				continue
			}
			// Best effort: a failure to set the deadline is deliberately ignored.
			_ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
		}

		// Remove the user from the ACL.
		acl.Users = slices.DeleteFunc(acl.Users, func(u *User) bool {
			return u.Username == username
		})
	}
	return nil
}
// AuthenticateConnection handles the AUTH command for conn.
//
// cmd is the full command: either [AUTH, password], which authenticates
// against the "default" user, or [AUTH, username, password]. Any other length
// is rejected with an error.
//
// The supplied password is compared both as plaintext and as the raw SHA-256
// digest of the password (converted to a string) against the user's stored
// passwords; password types are compared case-insensitively. A user with
// NoPassword set is authenticated without looking at the password.
//
// Errors are returned when the argument count is wrong, the user does not
// exist, the user is disabled, or no stored password matches. On success the
// connection is recorded in acl.Connections as authenticated for that user.
// The context argument is unused.
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
	// A successful authentication writes to acl.Connections, so the exclusive
	// lock is required; a read lock would allow concurrent map writes.
	acl.LockUsers()
	defer acl.UnlockUsers()

	var username, secret string
	switch len(cmd) {
	case 2:
		// AUTH <password>
		username, secret = "default", cmd[1]
	case 3:
		// AUTH <username> <password>
		username, secret = cmd[1], cmd[2]
	default:
		return errors.New("wrong number of arguments for auth command")
	}

	idx := slices.IndexFunc(acl.Users, func(u *User) bool {
		return u.Username == username
	})
	if idx < 0 {
		return fmt.Errorf("no user with username %s", username)
	}
	user := acl.Users[idx]

	// Disabled users can never authenticate.
	if !user.Enabled {
		return fmt.Errorf("user %s is disabled", user.Username)
	}

	authenticated := Connection{Authenticated: true, User: user}

	// NoPassword users are authenticated regardless of the supplied password.
	if user.NoPassword {
		acl.Connections[conn] = authenticated
		return nil
	}

	// Candidate representations of the supplied password.
	digest := sha256.Sum256([]byte(secret))
	candidates := []Password{
		{PasswordType: "plaintext", PasswordValue: secret},
		{PasswordType: "SHA256", PasswordValue: string(digest[:])},
	}

	for _, stored := range user.Passwords {
		for _, candidate := range candidates {
			if strings.EqualFold(stored.PasswordType, candidate.PasswordType) &&
				stored.PasswordValue == candidate.PasswordValue {
				acl.Connections[conn] = authenticated
				return nil
			}
		}
	}

	return errors.New("could not authenticate user")
}
// AuthorizeConnection decides whether the user attached to conn may run the
// given command (and optional sub-command) with the arguments in cmd. It
// returns nil when the command is allowed and a descriptive error otherwise.
//
// Order of evaluation:
//   - Keys and channels are extracted with the command's KeyExtractionFunc, or
//     the sub-command's when a sub-command is supplied; extraction errors are
//     returned as-is.
//   - The commands "ack", "connection" and "auth" are always allowed.
//   - When acl.Config.RequirePass is false, everything is allowed.
//   - The connection must be authenticated.
//   - At least one of the command's categories must be in the user's
//     IncludedCategories, and none may be in ExcludedCategories ("*" matches
//     every category).
//   - The command must be in IncludedCommands and not in ExcludedCommands
//     ("*" matches every command; sub-commands are named "command|subcommand").
//   - For commands in the pubsub category, every channel must match an
//     included channel glob and no excluded channel glob; no key checks follow.
//   - Otherwise, if the command touches keys: the user must not have NoKeys
//     set, every read key must match one of IncludedReadKeys and every write
//     key must match one of IncludedWriteKeys.
func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command internal.Command, subCommand internal.SubCommand) error {
	acl.RLockUsers()
	defer acl.RUnlockUsers()

	comm := command.Command
	// Clone so that appending the sub-command's categories can never write
	// into the backing array of command.Categories.
	categories := slices.Clone(command.Categories)
	keys, err := command.KeyExtractionFunc(cmd)
	if err != nil {
		return err
	}

	if !reflect.DeepEqual(subCommand, internal.SubCommand{}) {
		comm = fmt.Sprintf("%s|%s", comm, subCommand.Command)
		categories = append(categories, subCommand.Categories...)
		keys, err = subCommand.KeyExtractionFunc(cmd)
		if err != nil {
			return err
		}
	}

	// Read these only after the sub-command override so that a sub-command's
	// keys and channels are the ones that get checked.
	channels := keys.Channels
	readKeys := keys.ReadKeys
	writeKeys := keys.WriteKeys

	// These commands are always allowed.
	for _, exempt := range []string{"ack", "connection", "auth"} {
		if strings.EqualFold(comm, exempt) {
			return nil
		}
	}

	// If password is not required, allow the connection.
	if !acl.Config.RequirePass {
		return nil
	}

	// 1. The connection must be authenticated (and therefore have a user).
	connection := acl.Connections[conn]
	if !connection.Authenticated || connection.User == nil {
		return errors.New("user must be authenticated")
	}
	user := connection.User

	// 2. At least one category must be in IncludedCategories.
	var notAllowed []string
	if !slices.ContainsFunc(categories, func(category string) bool {
		return slices.ContainsFunc(user.IncludedCategories, func(includedCategory string) bool {
			if includedCategory == "*" || includedCategory == category {
				return true
			}
			notAllowed = append(notAllowed, fmt.Sprintf("@%s", category))
			return false
		})
	}) {
		if len(notAllowed) == 0 {
			notAllowed = []string{"@all"}
		}
		return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
	}

	// 3. No category may be in ExcludedCategories.
	if slices.ContainsFunc(categories, func(category string) bool {
		return slices.ContainsFunc(user.ExcludedCategories, func(excludedCategory string) bool {
			if excludedCategory == "*" || excludedCategory == category {
				notAllowed = []string{fmt.Sprintf("@%s", category)}
				return true
			}
			return false
		})
	}) {
		return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
	}

	// 4. The command must be in IncludedCommands.
	if !slices.ContainsFunc(user.IncludedCommands, func(includedCommand string) bool {
		return includedCommand == "*" || includedCommand == comm
	}) {
		return fmt.Errorf("not authorised to run %s command", comm)
	}

	// 5. The command must not be in ExcludedCommands.
	if slices.ContainsFunc(user.ExcludedCommands, func(excludedCommand string) bool {
		return excludedCommand == "*" || excludedCommand == comm
	}) {
		return fmt.Errorf("not authorised to run %s command", comm)
	}

	// 6. PUBSUB authorisation: every channel must be included and not excluded.
	if slices.Contains(categories, constants.PubSubCategory) {
		for _, channel := range channels {
			if !acl.matchesAnyGlob(user.IncludedPubSubChannels, channel) {
				return fmt.Errorf("not authorised to access channel &%s", channel)
			}
			if acl.matchesAnyGlob(user.ExcludedPubSubChannels, channel) {
				return fmt.Errorf("not authorised to access channel &%s", channel)
			}
		}
		return nil
	}

	// Commands that touch no keys need no key authorisation.
	if len(readKeys)+len(writeKeys) == 0 {
		return nil
	}

	// 7. Users flagged NoKeys may not touch any key.
	if user.NoKeys {
		return errors.New("not authorised to access any keys")
	}

	// 8. Every read key must match one of IncludedReadKeys. An empty readKeys
	// list passes trivially.
	var denied []string
	for _, key := range readKeys {
		if !acl.matchesAnyGlob(user.IncludedReadKeys, key) {
			denied = append(denied, fmt.Sprintf("%s~%s", "%R", key))
		}
	}
	if len(denied) > 0 {
		return fmt.Errorf("not authorised to access the following keys %+v", denied)
	}

	// 9. Every write key must match one of IncludedWriteKeys. An empty
	// writeKeys list passes trivially.
	for _, key := range writeKeys {
		if !acl.matchesAnyGlob(user.IncludedWriteKeys, key) {
			denied = append(denied, fmt.Sprintf("%s~%s", "%W", key))
		}
	}
	if len(denied) > 0 {
		return fmt.Errorf("not authorised to access the following keys %+v", denied)
	}

	return nil
}

// matchesAnyGlob reports whether value matches at least one of the given glob
// patterns. Compiled globs are taken from acl.GlobPatterns; a pattern missing
// from that cache is compiled on the fly without being stored, because callers
// hold only the read lock. A pattern that fails to compile matches nothing.
func (acl *ACL) matchesAnyGlob(patterns []string, value string) bool {
	return slices.ContainsFunc(patterns, func(pattern string) bool {
		g := acl.GlobPatterns[pattern]
		if g == nil {
			compiled, err := glob.Compile(pattern)
			if err != nil {
				return false
			}
			g = compiled
		}
		return g.Match(value)
	})
}
// CompileGlobs makes sure every glob pattern referenced by any user (included
// and excluded pub/sub channels, included read keys, included write keys) has
// a compiled entry in acl.GlobPatterns. Patterns that are already compiled are
// left untouched. glob.MustCompile is used, so an invalid pattern panics.
//
// The method does no locking of its own.
func (acl *ACL) CompileGlobs() {
	// Phase 1: gather the distinct patterns in first-seen order.
	seen := make(map[string]struct{})
	var ordered []string
	for _, user := range acl.Users {
		groups := [][]string{
			user.IncludedPubSubChannels,
			user.ExcludedPubSubChannels,
			user.IncludedReadKeys,
			user.IncludedWriteKeys,
		}
		for _, group := range groups {
			for _, pattern := range group {
				if _, dup := seen[pattern]; dup {
					continue
				}
				seen[pattern] = struct{}{}
				ordered = append(ordered, pattern)
			}
		}
	}

	// Phase 2: compile whatever is not in the cache yet.
	for _, pattern := range ordered {
		if acl.GlobPatterns[pattern] != nil {
			continue
		}
		acl.GlobPatterns[pattern] = glob.MustCompile(pattern)
	}
}
// LockUsers acquires the write lock on acl.UsersMutex.
func (acl *ACL) LockUsers() {
acl.UsersMutex.Lock()
}
// UnlockUsers releases the write lock on acl.UsersMutex.
func (acl *ACL) UnlockUsers() {
acl.UsersMutex.Unlock()
}
// RLockUsers acquires the read lock on acl.UsersMutex.
func (acl *ACL) RLockUsers() {
acl.UsersMutex.RLock()
}
// RUnlockUsers releases the read lock on acl.UsersMutex.
func (acl *ACL) RUnlockUsers() {
acl.UsersMutex.RUnlock()
}