diff --git a/Dockerfile b/Dockerfile index b5a4395..da8f672 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,4 +23,5 @@ CMD "./server" \ "--dataDir" "${DATADIR}" \ "--http=${HTTP}" \ "--tls=${TLS}" \ - "--inMemory=${INMEMORY}" \ \ No newline at end of file + "--inMemory=${INMEMORY}" \ + "--bootstrapCluster=${BOOTSTRAP_CLUSTER}" \ \ No newline at end of file diff --git a/src/main.go b/src/main.go index 98d233f..b81960b 100644 --- a/src/main.go +++ b/src/main.go @@ -37,8 +37,9 @@ type Server struct { connID atomic.Uint64 - store map[string]interface{} - keyLocks map[string]*sync.RWMutex + store map[string]interface{} + keyLocks map[string]*sync.RWMutex + keyCreationLock *sync.Mutex plugins []Plugin @@ -97,16 +98,25 @@ func (server *Server) KeyExists(key string) bool { return server.keyLocks[key] != nil } -func (server *Server) CreateKey(key string, value interface{}) { - server.keyLocks[key] = &sync.RWMutex{} - server.store[key] = value +func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) { + server.keyCreationLock.Lock() + defer server.keyCreationLock.Unlock() + + if !server.KeyExists(key) { + keyLock := &sync.RWMutex{} + keyLock.Lock() + server.keyLocks[key] = keyLock + return true, nil + } + + return server.KeyLock(ctx, key) } func (server *Server) GetValue(key string) interface{} { return server.store[key] } -func (server *Server) SetValue(key string, value interface{}) { +func (server *Server) SetValue(ctx context.Context, key string, value interface{}) { server.store[key] = value } @@ -358,6 +368,7 @@ func (server *Server) Start(ctx context.Context) { server.store = make(map[string]interface{}) server.keyLocks = make(map[string]*sync.RWMutex) + server.keyCreationLock = &sync.Mutex{} server.LoadPlugins(ctx) diff --git a/src/plugins/commands/get/command.go b/src/plugins/commands/get/command.go index be6039e..cd8d71e 100644 --- a/src/plugins/commands/get/command.go +++ b/src/plugins/commands/get/command.go @@ -13,9 +13,9 @@ type Server interface { KeyRLock(ctx context.Context, key string) (bool, error) KeyRUnlock(key string) KeyExists(key string) bool - CreateKey(key string, value interface{}) + CreateKeyAndLock(ctx context.Context, key string) (bool, error) GetValue(key string) interface{} - SetValue(key string, value interface{}) + SetValue(ctx context.Context, key string, value interface{}) } type plugin struct { @@ -54,9 +54,11 @@ func handleGet(ctx context.Context, cmd []string, s Server) ([]byte, error) { return nil, errors.New("wrong number of args for GET command") } - s.KeyRLock(ctx, cmd[1]) - value := s.GetValue(cmd[1]) - s.KeyRUnlock(cmd[1]) + key := cmd[1] + + s.KeyRLock(ctx, key) + value := s.GetValue(key) + s.KeyRUnlock(key) switch value.(type) { default: @@ -74,14 +76,17 @@ func handleMGet(ctx context.Context, cmd []string, s Server) ([]byte, error) { vals := []string{} for _, key := range cmd[1:] { - s.KeyRLock(ctx, key) - switch s.GetValue(key).(type) { - default: - vals = append(vals, fmt.Sprintf("%v", s.GetValue(key))) - case nil: - vals = append(vals, "nil") - } - s.KeyRUnlock(key) + func(key string) { + s.KeyRLock(ctx, key) + switch s.GetValue(key).(type) { + default: + vals = append(vals, fmt.Sprintf("%v", s.GetValue(key))) + case nil: + vals = append(vals, "nil") + } + s.KeyRUnlock(key) + + }(key) } var bytes []byte = []byte(fmt.Sprintf("*%d\r\n", len(vals))) diff --git a/src/plugins/commands/list/command.go b/src/plugins/commands/list/command.go index 4873f70..363d56d 100644 --- a/src/plugins/commands/list/command.go +++ b/src/plugins/commands/list/command.go @@ -21,9 +21,9 @@ type Server interface { KeyRLock(ctx context.Context, key string) (bool, error) KeyRUnlock(key string) KeyExists(key string) bool - CreateKey(key string, value interface{}) + CreateKeyAndLock(ctx context.Context, key string) (bool, error) GetValue(key string) interface{} - SetValue(key string, value interface{}) + SetValue(ctx context.Context, key string, value interface{}) } type plugin struct { @@ -246,7 +246,7 @@ func handleLSet(ctx context.Context, cmd []string, server Server) ([]byte, error } list[index] = utils.AdaptType(cmd[3]) - server.SetValue(cmd[1], list) + server.SetValue(ctx, cmd[1], list) server.KeyUnlock(cmd[1]) return []byte(OK), nil @@ -284,12 +284,12 @@ func handleLTrim(ctx context.Context, cmd []string, server Server) ([]byte, erro } if end == -1 || int(end) > len(list) { - server.SetValue(cmd[1], list[start:]) + server.SetValue(ctx, cmd[1], list[start:]) server.KeyUnlock(cmd[1]) return []byte(OK), nil } - server.SetValue(cmd[1], list[start:end]) + server.SetValue(ctx, cmd[1], list[start:end]) server.KeyUnlock(cmd[1]) return []byte(OK), nil } @@ -350,7 +350,7 @@ func handleLRem(ctx context.Context, cmd []string, server Server) ([]byte, error return elem != nil }) - server.SetValue(cmd[1], list) + server.SetValue(ctx, cmd[1], list) server.KeyUnlock(cmd[1]) return []byte(OK), nil @@ -384,18 +384,18 @@ func handleLMove(ctx context.Context, cmd []string, server Server) ([]byte, erro switch whereFrom { case "left": - server.SetValue(cmd[1], append([]interface{}{}, source[1:]...)) + server.SetValue(ctx, cmd[1], append([]interface{}{}, source[1:]...)) if whereTo == "left" { - server.SetValue(cmd[2], append(source[0:1], destination...)) + server.SetValue(ctx, cmd[2], append(source[0:1], destination...)) } else if whereTo == "right" { - server.SetValue(cmd[2], append(destination, source[0])) + server.SetValue(ctx, cmd[2], append(destination, source[0])) } case "right": - server.SetValue(cmd[1], append([]interface{}{}, source[:len(source)-1]...)) + server.SetValue(ctx, cmd[1], append([]interface{}{}, source[:len(source)-1]...)) if whereTo == "left" { - server.SetValue(cmd[2], append(source[len(source)-1:], destination...)) + server.SetValue(ctx, cmd[2], append(source[len(source)-1:], destination...)) } else if whereTo == "right" { - server.SetValue(cmd[2], append(destination, source[len(source)-1])) + server.SetValue(ctx, cmd[2], append(destination, source[len(source)-1])) } } @@ -416,19 +416,22 @@ func handleLPush(ctx context.Context, cmd []string, server Server) ([]byte, erro newElems = append(newElems, utils.AdaptType(elem)) } - if !server.KeyExists(cmd[1]) { + key := cmd[1] + + if !server.KeyExists(key) { switch strings.ToLower(cmd[0]) { case "lpushx": return nil, fmt.Errorf("%s command on non-list item", cmd[0]) default: - server.CreateKey(cmd[1], []interface{}{}) + // TODO: Retry CreateKeyAndLock until we obtain the key lock + server.CreateKeyAndLock(ctx, key) + server.SetValue(ctx, key, []interface{}{}) } } - server.KeyLock(ctx, cmd[1]) - defer server.KeyUnlock(cmd[1]) + defer server.KeyUnlock(key) - currentList := server.GetValue(cmd[1]) + currentList := server.GetValue(key) l, ok := currentList.([]interface{}) @@ -436,7 +439,7 @@ func handleLPush(ctx context.Context, cmd []string, server Server) ([]byte, erro return nil, fmt.Errorf("%s command on non-list item", cmd[0]) } - server.SetValue(cmd[1], append(newElems, l...)) + server.SetValue(ctx, key, append(newElems, l...)) return []byte(OK), nil } @@ -456,11 +459,12 @@ func handleRPush(ctx context.Context, cmd []string, server Server) ([]byte, erro case "rpushx": return nil, fmt.Errorf("%s command on non-list item", cmd[0]) default: - server.CreateKey(cmd[1], []interface{}{}) + // TODO: Retry CreateKeyAndLock until we managed to obtain the key + server.CreateKeyAndLock(ctx, cmd[1]) + server.SetValue(ctx, cmd[1], []interface{}{}) } } - server.KeyLock(ctx, cmd[1]) defer server.KeyUnlock(cmd[1]) currentList := server.GetValue(cmd[1]) @@ -471,7 +475,7 @@ func handleRPush(ctx context.Context, cmd []string, server Server) ([]byte, erro return nil, errors.New("RPUSH command on non-list item") } - server.SetValue(cmd[1], append(l, newElems...)) + server.SetValue(ctx, cmd[1], append(l, newElems...)) return []byte(OK), nil } @@ -499,10 +503,10 @@ func handlePop(ctx context.Context, cmd []string, server Server) ([]byte, error) switch strings.ToLower(cmd[0]) { default: - server.SetValue(cmd[1], list[1:]) + server.SetValue(ctx, cmd[1], list[1:]) return []byte(fmt.Sprintf("+%v\r\n\n", list[0])), nil case "rpop": - server.SetValue(cmd[1], list[:len(list)-1]) + server.SetValue(ctx, cmd[1], list[:len(list)-1]) return []byte(fmt.Sprintf("+%v\r\n\n", list[len(list)-1])), nil } diff --git a/src/plugins/commands/ping/command.go b/src/plugins/commands/ping/command.go index d083413..c8d61a3 100644 --- a/src/plugins/commands/ping/command.go +++ b/src/plugins/commands/ping/command.go @@ -10,13 +10,14 @@ const ( ) type Server interface { - KeyLock(key string) + KeyLock(ctx context.Context, key string) (bool, error) KeyUnlock(key string) - KeyRLock(key string) + KeyRLock(ctx context.Context, key string) (bool, error) KeyRUnlock(key string) - CreateKey(key string, value interface{}) + KeyExists(key string) bool + CreateKeyAndLock(ctx context.Context, key string) (bool, error) GetValue(key string) interface{} - SetValue(key string, value interface{}) + SetValue(ctx context.Context, key string, value interface{}) } type plugin struct { diff --git a/src/plugins/commands/set/command.go b/src/plugins/commands/set/command.go index 985da36..beadfa8 100644 --- a/src/plugins/commands/set/command.go +++ b/src/plugins/commands/set/command.go @@ -15,9 +15,9 @@ type Server interface { KeyRLock(ctx context.Context, key string) (bool, error) KeyRUnlock(key string) KeyExists(key string) bool - CreateKey(key string, value interface{}) + CreateKeyAndLock(ctx context.Context, key string) (bool, error) GetValue(key string) interface{} - SetValue(key string, value interface{}) + SetValue(ctx context.Context, key string, value interface{}) } type plugin struct { @@ -61,17 +61,22 @@ func handleSet(ctx context.Context, cmd []string, s Server) ([]byte, error) { default: return nil, errors.New("wrong number of args for SET command") case x == 3: - if !s.KeyExists(cmd[1]) { - s.CreateKey(cmd[1], utils.AdaptType(cmd[2])) + key := cmd[1] + + if !s.KeyExists(key) { + // TODO: Retry CreateKeyAndLock until we manage to obtain the key + s.CreateKeyAndLock(ctx, key) + s.SetValue(ctx, key, utils.AdaptType(cmd[2])) + s.KeyUnlock(key) return []byte("+OK\r\n\n"), nil } - if _, err := s.KeyLock(ctx, cmd[1]); err != nil { + if _, err := s.KeyLock(ctx, key); err != nil { return nil, err } - s.SetValue(cmd[1], utils.AdaptType(cmd[2])) - s.KeyUnlock(cmd[1]) + s.SetValue(ctx, key, utils.AdaptType(cmd[2])) + s.KeyUnlock(key) return []byte("+OK\r\n\n"), nil } } @@ -81,10 +86,14 @@ func handleSetNX(ctx context.Context, cmd []string, s Server) ([]byte, error) { default: return nil, errors.New("wrong number of args for SETNX command") case x == 3: - if s.KeyExists(cmd[1]) { + key := cmd[1] + if s.KeyExists(key) { return nil, fmt.Errorf("key %s already exists", cmd[1]) } - s.CreateKey(cmd[1], utils.AdaptType(cmd[2])) + // TODO: Retry CreateKeyAndLock until we manage to obtain the key + s.CreateKeyAndLock(ctx, key) + s.SetValue(ctx, key, utils.AdaptType(cmd[2])) + s.KeyUnlock(key) } return []byte("+OK\r\n\n"), nil } diff --git a/src/raft.go b/src/raft.go index 749b217..ef94b67 100644 --- a/src/raft.go +++ b/src/raft.go @@ -88,8 +88,7 @@ func (server *Server) RaftInit(ctx context.Context) { server.raft = raftServer - // TODO: Only bootstrap cluster if --bootstrapCluster=true config is set - if conf.JoinAddr == "" { + if conf.BootstrapCluster { // Bootstrap raft cluster if err := server.raft.BootstrapCluster(raft.Configuration{ Servers: []raft.Server{ @@ -106,7 +105,7 @@ func (server *Server) RaftInit(ctx context.Context) { } -// Implement raft.FSM interface +// Apply Implements raft.FSM interface func (server *Server) Apply(log *raft.Log) interface{} { switch log.Type { case raft.LogCommand: @@ -191,7 +190,7 @@ func (server *Server) Restore(snapshot io.ReadCloser) error { for k, v := range data { server.keyLocks[k].Lock() - server.SetValue(k, v) + server.SetValue(context.Background(), k, v) server.keyLocks[k].Unlock() } @@ -256,7 +255,7 @@ func (server *Server) addVoter( } for _, s := range raftConfig.Configuration().Servers { - // Check if a server already exists with the current attribtues + // Check if a server already exists with the current attributes if s.ID == id && s.Address == address { return fmt.Errorf("server with id %s and address %s already exists", id, address) } diff --git a/src/utils/config.go b/src/utils/config.go index b0f5ef1..a413acb 100644 --- a/src/utils/config.go +++ b/src/utils/config.go @@ -23,6 +23,7 @@ type Config struct { MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"` InMemory bool `json:"inMemory" yaml:"inMemory"` DataDir string `json:"dataDir" yaml:"dataDir"` + BootstrapCluster bool `json:"BootstrapCluster" yaml:"bootstrapCluster"` } func GetConfig() Config { @@ -37,8 +38,10 @@ func GetConfig() Config { bindAddr := flag.String("bindAddr", "", "Address to bind the server to.") raftBindPort := flag.Int("raftPort", 7481, "Port to use for intra-cluster communication. Leave on the client.") mlBindPort := flag.Int("mlPort", 7946, "Port to use for memberlist communication.") - inMemory := flag.Bool("inMemory", false, "Wether to use memory or persisten storage for raft logs and snapshots.") + inMemory := flag.Bool("inMemory", false, "Whether to use memory or persisten storage for raft logs and snapshots.") dataDir := flag.String("dataDir", "/var/lib/memstore", "Directory to store raft snapshots and logs.") + bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.") + config := flag.String( "config", "", @@ -61,6 +64,7 @@ func GetConfig() Config { MemberListBindPort: uint16(*mlBindPort), InMemory: *inMemory, DataDir: *dataDir, + BootstrapCluster: *bootstrapCluster, } if len(*config) > 0 {