Refactored memberlist and raft layers into their own packages

Kelvin Clement Mwinuka
2024-01-10 02:37:48 +03:00
parent 1f91ac1ac9
commit c82560294d
27 changed files with 685 additions and 524 deletions
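The net effect on the server wiring, condensed from the server hunks at the top of this diff into a minimal sketch (the rl and ml package aliases, option structs, and method names are taken directly from the diff; error handling and unrelated fields are omitted):

// Inside (server *Server) Start(ctx context.Context), when the node runs in cluster mode:
if server.IsInCluster() {
	// The raft layer now owns the raft.Raft instance, the FSM, and the snapshot logic.
	server.raft = rl.NewRaft(rl.RaftOpts{
		Config:     conf,
		Server:     server,
		GetCommand: server.getCommand,
	})
	// The memberlist layer reaches the raft layer only through injected callbacks,
	// so it has no compile-time dependency on the raft_layer package.
	server.memberList = ml.NewMemberList(ml.MemberlistOpts{
		Config:           conf,
		HasJoinedCluster: server.raft.HasJoinedCluster,
		AddVoter:         server.raft.AddVoter,
		RemoveRaftServer: server.raft.RemoveServer,
	})
	server.raft.RaftInit(ctx)
	server.memberList.MemberListInit(ctx)
}

Shutdown mirrors this wiring: server.raft.RaftShutdown(ctx) followed by server.memberList.MemberListShutdown(ctx).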

View File

@@ -6,16 +6,18 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/kelvinmwinuka/memstore/src/modules/acl"
"github.com/kelvinmwinuka/memstore/src/modules/etc"
"github.com/kelvinmwinuka/memstore/src/modules/get"
"github.com/kelvinmwinuka/memstore/src/modules/hash"
"github.com/kelvinmwinuka/memstore/src/modules/list"
"github.com/kelvinmwinuka/memstore/src/modules/ping"
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
"github.com/kelvinmwinuka/memstore/src/modules/set"
"github.com/kelvinmwinuka/memstore/src/modules/sorted_set"
str "github.com/kelvinmwinuka/memstore/src/modules/string"
"github.com/kelvinmwinuka/memstore/src/command_modules/acl"
"github.com/kelvinmwinuka/memstore/src/command_modules/etc"
"github.com/kelvinmwinuka/memstore/src/command_modules/get"
"github.com/kelvinmwinuka/memstore/src/command_modules/hash"
"github.com/kelvinmwinuka/memstore/src/command_modules/list"
"github.com/kelvinmwinuka/memstore/src/command_modules/ping"
"github.com/kelvinmwinuka/memstore/src/command_modules/pubsub"
"github.com/kelvinmwinuka/memstore/src/command_modules/set"
"github.com/kelvinmwinuka/memstore/src/command_modules/sorted_set"
str "github.com/kelvinmwinuka/memstore/src/command_modules/string"
ml "github.com/kelvinmwinuka/memstore/src/memberlist_layer"
rl "github.com/kelvinmwinuka/memstore/src/raft_layer"
"io"
"log"
"net"
@@ -28,8 +30,6 @@ import (
"syscall"
"time"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
)
@@ -44,11 +44,8 @@ type Server struct {
commands []utils.Command
raft *raft.Raft
memberList *memberlist.Memberlist
broadcastQueue *memberlist.TransmitLimitedQueue
numOfNodes int
raft *rl.Raft
memberList *ml.MemberList
cancelCh *chan os.Signal
@@ -232,7 +229,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
continue
}
if server.isRaftLeader() {
if server.raft.IsRaftLeader() {
applyFuture := server.raft.Apply(b, 500*time.Millisecond)
if err := applyFuture.Error(); err != nil {
@@ -370,8 +367,20 @@ func (server *Server) Start(ctx context.Context) {
}
if server.IsInCluster() {
server.RaftInit(ctx)
server.MemberListInit(ctx)
// Initialise raft and memberlist
server.raft = rl.NewRaft(rl.RaftOpts{
Config: conf,
Server: server,
GetCommand: server.getCommand,
})
server.memberList = ml.NewMemberList(ml.MemberlistOpts{
Config: conf,
HasJoinedCluster: server.raft.HasJoinedCluster,
AddVoter: server.raft.AddVoter,
RemoveRaftServer: server.raft.RemoveServer,
})
server.raft.RaftInit(ctx)
server.memberList.MemberListInit(ctx)
}
if conf.HTTP {
@@ -387,8 +396,8 @@ func (server *Server) IsInCluster() bool {
func (server *Server) ShutDown(ctx context.Context) {
if server.IsInCluster() {
server.RaftShutdown(ctx)
server.MemberListShutdown(ctx)
server.raft.RaftShutdown(ctx)
server.memberList.MemberListShutdown(ctx)
}
}
@@ -418,9 +427,6 @@ func main() {
connID: atomic.Uint64{},
broadcastQueue: new(memberlist.TransmitLimitedQueue),
numOfNodes: 0,
ACL: acl.NewACL(config),
PubSub: pubsub.NewPubSub(),

View File

@@ -1,221 +0,0 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
"github.com/sethvargo/go-retry"
)
type NodeMeta struct {
ServerID raft.ServerID `json:"ServerID"`
MemberlistAddr string `json:"MemberlistAddr"`
RaftAddr raft.ServerAddress `json:"RaftAddr"`
}
type BroadcastMessage struct {
NodeMeta
Action string `json:"Action"`
Content string `json:"Content"`
}
// Invalidates Implements Broadcast interface
func (broadcastMessage *BroadcastMessage) Invalidates(other memberlist.Broadcast) bool {
mb, ok := other.(*BroadcastMessage)
if !ok {
return false
}
if mb.ServerID == broadcastMessage.ServerID && mb.Action == "RaftJoin" {
return true
}
return false
}
// Message Implements Broadcast interface
func (broadcastMessage *BroadcastMessage) Message() []byte {
msg, err := json.Marshal(broadcastMessage)
if err != nil {
fmt.Println(err)
return []byte{}
}
return msg
}
// Finished Implements Broadcast interface
func (broadcastMessage *BroadcastMessage) Finished() {
// No-Op
}
func (server *Server) MemberListInit(ctx context.Context) {
cfg := memberlist.DefaultLocalConfig()
cfg.BindAddr = server.config.BindAddr
cfg.BindPort = int(server.config.MemberListBindPort)
cfg.Events = server
cfg.Delegate = server
server.broadcastQueue.RetransmitMult = 1
server.broadcastQueue.NumNodes = func() int {
return server.numOfNodes
}
list, err := memberlist.Create(cfg)
server.memberList = list
if err != nil {
log.Fatal(err)
}
if server.config.JoinAddr != "" {
backoffPolicy := utils.RetryBackoff(retry.NewFibonacci(1*time.Second), 5, 200*time.Millisecond, 0, 0)
err := retry.Do(ctx, backoffPolicy, func(ctx context.Context) error {
_, err := list.Join([]string{server.config.JoinAddr})
if err != nil {
return retry.RetryableError(err)
}
return nil
})
if err != nil {
log.Fatal(err)
}
go server.broadcastRaftAddress(ctx)
}
}
func (server *Server) broadcastRaftAddress(ctx context.Context) {
ticker := time.NewTicker(5 * time.Second)
for {
msg := BroadcastMessage{
Action: "RaftJoin",
NodeMeta: NodeMeta{
ServerID: raft.ServerID(server.config.ServerID),
RaftAddr: raft.ServerAddress(fmt.Sprintf("%s:%d", server.config.BindAddr, server.config.RaftBindPort)),
},
}
if server.hasJoinedCluster() {
return
}
server.broadcastQueue.QueueBroadcast(&msg)
<-ticker.C
}
}
// NodeMeta implements Delegate interface
func (server *Server) NodeMeta(limit int) []byte {
meta := NodeMeta{
ServerID: raft.ServerID(server.config.ServerID),
RaftAddr: raft.ServerAddress(fmt.Sprintf("%s:%d", server.config.BindAddr, server.config.RaftBindPort)),
MemberlistAddr: fmt.Sprintf("%s:%d", server.config.BindAddr, server.config.MemberListBindPort),
}
b, err := json.Marshal(&meta)
if err != nil {
return []byte("")
}
return b
}
// NotifyMsg implements Delegate interface
func (server *Server) NotifyMsg(msgBytes []byte) {
var msg BroadcastMessage
if err := json.Unmarshal(msgBytes, &msg); err != nil {
fmt.Print(err)
return
}
switch msg.Action {
case "RaftJoin":
if err := server.addVoter(
raft.ServerID(msg.NodeMeta.ServerID),
raft.ServerAddress(msg.NodeMeta.RaftAddr),
0, 0,
); err != nil {
fmt.Println(err)
}
case "MutateData":
// Mutate the value at a given key
}
}
// GetBroadcasts implements Delegate interface
func (server *Server) GetBroadcasts(overhead, limit int) [][]byte {
return server.broadcastQueue.GetBroadcasts(overhead, limit)
}
// LocalState implements Delegate interface
func (server *Server) LocalState(join bool) []byte {
// No-Op
return []byte("")
}
// MergeRemoteState implements Delegate interface
func (server *Server) MergeRemoteState(buf []byte, join bool) {
// No-Op
}
// NotifyJoin implements EventDelegate interface
func (server *Server) NotifyJoin(node *memberlist.Node) {
server.numOfNodes += 1
}
// NotifyLeave implements EventDelegate interface
func (server *Server) NotifyLeave(node *memberlist.Node) {
server.numOfNodes -= 1
var meta NodeMeta
err := json.Unmarshal(node.Meta, &meta)
if err != nil {
fmt.Println("Could not get leaving node's metadata.")
return
}
err = server.removeServer(meta)
if err != nil {
fmt.Println(err)
}
}
// NotifyUpdate implements EventDelegate interface
func (server *Server) NotifyUpdate(node *memberlist.Node) {
// No-Op
}
func (server *Server) MemberListShutdown(ctx context.Context) {
// Gracefully leave memberlist cluster
err := server.memberList.Leave(500 * time.Millisecond)
if err != nil {
log.Fatal("Could not gracefully leave memberlist cluster")
}
err = server.memberList.Shutdown()
if err != nil {
log.Fatal("Could not gracefully shutdown memberlist background maintanance")
}
fmt.Println("Successfully shutdown memberlist")
}

View File

@@ -0,0 +1,45 @@
package memberlist_layer
import (
"encoding/json"
"fmt"
"github.com/hashicorp/memberlist"
)
type BroadcastMessage struct {
NodeMeta
Action string `json:"Action"`
Content string `json:"Content"`
}
// Invalidates implements the Broadcast interface
func (broadcastMessage *BroadcastMessage) Invalidates(other memberlist.Broadcast) bool {
mb, ok := other.(*BroadcastMessage)
if !ok {
return false
}
if mb.ServerID == broadcastMessage.ServerID && mb.Action == "RaftJoin" {
return true
}
return false
}
// Message implements the Broadcast interface
func (broadcastMessage *BroadcastMessage) Message() []byte {
msg, err := json.Marshal(broadcastMessage)
if err != nil {
fmt.Println(err)
return []byte{}
}
return msg
}
// Finished implements the Broadcast interface
func (broadcastMessage *BroadcastMessage) Finished() {
// No-Op
}

View File

@@ -0,0 +1,79 @@
package memberlist_layer
import (
"encoding/json"
"fmt"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
"time"
)
type Delegate struct {
options DelegateOpts
}
type DelegateOpts struct {
config utils.Config
broadcastQueue *memberlist.TransmitLimitedQueue
addVoter func(id raft.ServerID, address raft.ServerAddress, prevIndex uint64, timeout time.Duration) error
}
func NewDelegate(opts DelegateOpts) *Delegate {
return &Delegate{
options: opts,
}
}
// NodeMeta implements Delegate interface
func (delegate *Delegate) NodeMeta(limit int) []byte {
meta := NodeMeta{
ServerID: raft.ServerID(delegate.options.config.ServerID),
RaftAddr: raft.ServerAddress(
fmt.Sprintf("%s:%d", delegate.options.config.BindAddr, delegate.options.config.RaftBindPort)),
MemberlistAddr: fmt.Sprintf("%s:%d", delegate.options.config.BindAddr, delegate.options.config.MemberListBindPort),
}
b, err := json.Marshal(&meta)
if err != nil {
return []byte("")
}
return b
}
// NotifyMsg implements Delegate interface
func (delegate *Delegate) NotifyMsg(msgBytes []byte) {
var msg BroadcastMessage
if err := json.Unmarshal(msgBytes, &msg); err != nil {
fmt.Print(err)
return
}
switch msg.Action {
case "RaftJoin":
if err := delegate.options.addVoter(msg.NodeMeta.ServerID, msg.NodeMeta.RaftAddr, 0, 0); err != nil {
fmt.Println(err)
}
case "MutateData":
// Mutate the value at a given key
}
}
// GetBroadcasts implements Delegate interface
func (delegate *Delegate) GetBroadcasts(overhead, limit int) [][]byte {
return delegate.options.broadcastQueue.GetBroadcasts(overhead, limit)
}
// LocalState implements Delegate interface
func (delegate *Delegate) LocalState(join bool) []byte {
// No-Op
return []byte("")
}
// MergeRemoteState implements Delegate interface
func (delegate *Delegate) MergeRemoteState(buf []byte, join bool) {
// No-Op
}

View File

@@ -0,0 +1,53 @@
package memberlist_layer
import (
"encoding/json"
"fmt"
"github.com/hashicorp/memberlist"
)
type EventDelegate struct {
options EventDelegateOpts
}
type EventDelegateOpts struct {
IncrementNodes func()
DecrementNodes func()
RemoveRaftServer func(meta NodeMeta) error
}
func NewEventDelegate(opts EventDelegateOpts) *EventDelegate {
return &EventDelegate{
options: opts,
}
}
// NotifyJoin implements EventDelegate interface
func (eventDelegate *EventDelegate) NotifyJoin(node *memberlist.Node) {
eventDelegate.options.IncrementNodes()
}
// NotifyLeave implements EventDelegate interface
func (eventDelegate *EventDelegate) NotifyLeave(node *memberlist.Node) {
eventDelegate.options.DecrementNodes()
var meta NodeMeta
err := json.Unmarshal(node.Meta, &meta)
if err != nil {
fmt.Println("Could not get leaving node's metadata.")
return
}
err = eventDelegate.options.RemoveRaftServer(meta)
if err != nil {
fmt.Println(err)
}
}
// NotifyUpdate implements EventDelegate interface
func (eventDelegate *EventDelegate) NotifyUpdate(node *memberlist.Node) {
// No-Op
}

View File

@@ -0,0 +1,127 @@
package memberlist_layer
import (
"context"
"fmt"
"log"
"time"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
"github.com/sethvargo/go-retry"
)
type NodeMeta struct {
ServerID raft.ServerID `json:"ServerID"`
MemberlistAddr string `json:"MemberlistAddr"`
RaftAddr raft.ServerAddress `json:"RaftAddr"`
}
type MemberlistOpts struct {
Config utils.Config
HasJoinedCluster func() bool
AddVoter func(id raft.ServerID, address raft.ServerAddress, prevIndex uint64, timeout time.Duration) error
RemoveRaftServer func(meta NodeMeta) error
}
type MemberList struct {
options MemberlistOpts
broadcastQueue *memberlist.TransmitLimitedQueue
numOfNodes int
memberList *memberlist.Memberlist
}
func NewMemberList(opts MemberlistOpts) *MemberList {
return &MemberList{
options: opts,
broadcastQueue: new(memberlist.TransmitLimitedQueue),
numOfNodes: 0,
}
}
func (m *MemberList) MemberListInit(ctx context.Context) {
cfg := memberlist.DefaultLocalConfig()
cfg.BindAddr = m.options.Config.BindAddr
cfg.BindPort = int(m.options.Config.MemberListBindPort)
cfg.Delegate = NewDelegate(DelegateOpts{
config: m.options.Config,
broadcastQueue: m.broadcastQueue,
addVoter: m.options.AddVoter,
})
cfg.Events = NewEventDelegate(EventDelegateOpts{
IncrementNodes: func() { m.numOfNodes += 1 },
DecrementNodes: func() { m.numOfNodes -= 1 },
RemoveRaftServer: m.options.RemoveRaftServer,
})
m.broadcastQueue.RetransmitMult = 1
m.broadcastQueue.NumNodes = func() int {
return m.numOfNodes
}
list, err := memberlist.Create(cfg)
m.memberList = list
if err != nil {
log.Fatal(err)
}
if m.options.Config.JoinAddr != "" {
backoffPolicy := utils.RetryBackoff(retry.NewFibonacci(1*time.Second), 5, 200*time.Millisecond, 0, 0)
err := retry.Do(ctx, backoffPolicy, func(ctx context.Context) error {
_, err := list.Join([]string{m.options.Config.JoinAddr})
if err != nil {
return retry.RetryableError(err)
}
return nil
})
if err != nil {
log.Fatal(err)
}
go m.broadcastRaftAddress(ctx)
}
}
func (m *MemberList) broadcastRaftAddress(ctx context.Context) {
ticker := time.NewTicker(5 * time.Second)
for {
msg := BroadcastMessage{
Action: "RaftJoin",
NodeMeta: NodeMeta{
ServerID: raft.ServerID(m.options.Config.ServerID),
RaftAddr: raft.ServerAddress(fmt.Sprintf("%s:%d",
m.options.Config.BindAddr, m.options.Config.RaftBindPort)),
},
}
if m.options.HasJoinedCluster() {
return
}
m.broadcastQueue.QueueBroadcast(&msg)
<-ticker.C
}
}
func (m *MemberList) MemberListShutdown(ctx context.Context) {
// Gracefully leave memberlist cluster
err := m.memberList.Leave(500 * time.Millisecond)
if err != nil {
log.Fatal("Could not gracefully leave memberlist cluster")
}
err = m.memberList.Shutdown()
if err != nil {
log.Fatal("Could not gracefully shutdown memberlist background maintanance")
}
fmt.Println("Successfully shutdown memberlist")
}

View File

@@ -1,278 +0,0 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"path"
"path/filepath"
"time"
"github.com/hashicorp/raft"
raftboltdb "github.com/hashicorp/raft-boltdb"
"github.com/kelvinmwinuka/memstore/src/utils"
)
func (server *Server) RaftInit(ctx context.Context) {
conf := server.config
raftConfig := raft.DefaultConfig()
raftConfig.LocalID = raft.ServerID(conf.ServerID)
raftConfig.SnapshotThreshold = 5
var logStore raft.LogStore
var stableStore raft.StableStore
var snapshotStore raft.SnapshotStore
if conf.InMemory {
logStore = raft.NewInmemStore()
stableStore = raft.NewInmemStore()
snapshotStore = raft.NewInmemSnapshotStore()
} else {
boltdb, err := raftboltdb.NewBoltStore(filepath.Join(conf.DataDir, "logs.db"))
if err != nil {
log.Fatal(err)
}
logStore, err = raft.NewLogCache(512, boltdb)
if err != nil {
log.Fatal(err)
}
stableStore = raft.StableStore(boltdb)
snapshotStore, err = raft.NewFileSnapshotStore(path.Join(conf.DataDir, "snapshots"), 2, os.Stdout)
if err != nil {
log.Fatal(err)
}
}
addr := fmt.Sprintf("%s:%d", conf.BindAddr, conf.RaftBindPort)
advertiseAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
log.Fatal(err)
}
raftTransport, err := raft.NewTCPTransport(
addr,
advertiseAddr,
10,
500*time.Millisecond,
os.Stdout,
)
if err != nil {
log.Fatal(err)
}
// Start raft server
raftServer, err := raft.NewRaft(
raftConfig,
raft.FSM(server),
logStore,
stableStore,
snapshotStore,
raftTransport,
)
if err != nil {
log.Fatalf("Could not start node with error; %s", err)
}
server.raft = raftServer
if conf.BootstrapCluster {
// Bootstrap raft cluster
if err := server.raft.BootstrapCluster(raft.Configuration{
Servers: []raft.Server{
{
Suffrage: raft.Voter,
ID: raft.ServerID(conf.ServerID),
Address: raft.ServerAddress(addr),
},
},
}).Error(); err != nil {
log.Fatal(err)
}
}
}
// Apply Implements raft.FSM interface
func (server *Server) Apply(log *raft.Log) interface{} {
switch log.Type {
case raft.LogCommand:
var request utils.ApplyRequest
if err := json.Unmarshal(log.Data, &request); err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
}
ctx := context.WithValue(context.Background(), utils.ContextServerID("ServerID"), request.ServerID)
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"), request.ConnectionID)
// Handle command
command, err := server.getCommand(request.CMD[0])
if err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
}
handler := command.HandlerFunc
subCommand, ok := utils.GetSubCommand(command, request.CMD).(utils.SubCommand)
if ok {
handler = subCommand.HandlerFunc
}
if res, err := handler(ctx, request.CMD, server, nil); err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
} else {
return utils.ApplyResponse{
Error: nil,
Response: res,
}
}
}
return nil
}
// Implements raft.FSM interface
func (server *Server) Snapshot() (raft.FSMSnapshot, error) {
return server, nil
}
// Implements raft.FSM interface
func (server *Server) Restore(snapshot io.ReadCloser) error {
b, err := io.ReadAll(snapshot)
if err != nil {
return err
}
data := make(map[string]interface{})
if err := json.Unmarshal(b, &data); err != nil {
return err
}
for k, v := range data {
server.keyLocks[k].Lock()
server.SetValue(context.Background(), k, v)
server.keyLocks[k].Unlock()
}
return nil
}
// Implements FSMSnapshot interface
func (server *Server) Persist(sink raft.SnapshotSink) error {
data := map[string]interface{}{}
// TODO: Copy current store contents
o, err := json.Marshal(data)
if err != nil {
sink.Cancel()
return err
}
if _, err = sink.Write(o); err != nil {
sink.Cancel()
return err
}
// TODO: Store data in separate snapshot file
return nil
}
// Implements FSMSnapshot interface
func (server *Server) Release() {}
func (server *Server) isRaftLeader() bool {
return server.raft.State() == raft.Leader
}
func (server *Server) isRaftFollower() bool {
return server.raft.State() == raft.Follower
}
func (server *Server) hasJoinedCluster() bool {
isFollower := server.isRaftFollower()
leaderAddr, leaderID := server.raft.LeaderWithID()
hasLeader := leaderAddr != "" && leaderID != ""
return isFollower && hasLeader
}
func (server *Server) addVoter(
id raft.ServerID,
address raft.ServerAddress,
prevIndex uint64,
timeout time.Duration,
) error {
if !server.isRaftLeader() {
return errors.New("not leader, cannot add voter")
}
raftConfig := server.raft.GetConfiguration()
if err := raftConfig.Error(); err != nil {
return errors.New("could not retrieve raft config")
}
for _, s := range raftConfig.Configuration().Servers {
// Check if a server already exists with the current attributes
if s.ID == id && s.Address == address {
return fmt.Errorf("server with id %s and address %s already exists", id, address)
}
}
err := server.raft.AddVoter(id, address, prevIndex, timeout).Error()
if err != nil {
return err
}
return nil
}
func (server *Server) removeServer(meta NodeMeta) error {
if !server.isRaftLeader() {
return errors.New("not leader, could not remove server")
}
if err := server.raft.RemoveServer(meta.ServerID, 0, 0).Error(); err != nil {
return err
}
return nil
}
func (server *Server) RaftShutdown(ctx context.Context) {
// Leadership transfer if current node is the leader
if server.isRaftLeader() {
err := server.raft.LeadershipTransfer().Error()
if err != nil {
log.Fatal(err)
}
fmt.Println("Leadership transfer successful.")
}
}

View File

@@ -0,0 +1,46 @@
package raft_layer
import (
"encoding/json"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
)
type SnapshotOpts struct {
Config utils.Config
}
type Snapshot struct {
options SnapshotOpts
}
func NewFSMSnapshot(opts SnapshotOpts) *Snapshot {
return &Snapshot{
options: opts,
}
}
// Persist implements FSMSnapshot interface
func (s *Snapshot) Persist(sink raft.SnapshotSink) error {
data := map[string]interface{}{}
// TODO: Copy current store contents
o, err := json.Marshal(data)
if err != nil {
sink.Cancel()
return err
}
if _, err = sink.Write(o); err != nil {
sink.Cancel()
return err
}
// TODO: Store data in separate snapshot file
return nil
}
// Release implements FSMSnapshot interface
func (s *Snapshot) Release() {}

src/raft_layer/fsm.go Normal file (102 lines)
View File

@@ -0,0 +1,102 @@
package raft_layer
import (
"context"
"encoding/json"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
"io"
)
type FSMOpts struct {
Config utils.Config
Server utils.Server
Snapshot raft.FSMSnapshot
GetCommand func(command string) (utils.Command, error)
}
type FSM struct {
options FSMOpts
}
func NewFSM(opts FSMOpts) raft.FSM {
return raft.FSM(&FSM{
options: opts,
})
}
// Apply implements the raft.FSM interface
func (fsm *FSM) Apply(log *raft.Log) interface{} {
switch log.Type {
case raft.LogCommand:
var request utils.ApplyRequest
if err := json.Unmarshal(log.Data, &request); err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
}
ctx := context.WithValue(context.Background(), utils.ContextServerID("ServerID"), request.ServerID)
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"), request.ConnectionID)
// Handle command
command, err := fsm.options.GetCommand(request.CMD[0])
if err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
}
handler := command.HandlerFunc
subCommand, ok := utils.GetSubCommand(command, request.CMD).(utils.SubCommand)
if ok {
handler = subCommand.HandlerFunc
}
if res, err := handler(ctx, request.CMD, fsm.options.Server, nil); err != nil {
return utils.ApplyResponse{
Error: err,
Response: nil,
}
} else {
return utils.ApplyResponse{
Error: nil,
Response: res,
}
}
}
return nil
}
// Snapshot implements raft.FSM interface
func (fsm *FSM) Snapshot() (raft.FSMSnapshot, error) {
return fsm.options.Snapshot, nil
}
// Restore implements raft.FSM interface
func (fsm *FSM) Restore(snapshot io.ReadCloser) error {
b, err := io.ReadAll(snapshot)
if err != nil {
return err
}
data := make(map[string]interface{})
if err := json.Unmarshal(b, &data); err != nil {
return err
}
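// Writing the restored data back into the store is not wired up in this refactor;
// the loop from the pre-refactor Restore implementation is kept commented out below for reference.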
// for k, v := range data {
// server.keyLocks[k].Lock()
// server.SetValue(context.Background(), k, v)
// server.keyLocks[k].Unlock()
// }
return nil
}

src/raft_layer/raft.go Normal file (202 lines)
View File

@@ -0,0 +1,202 @@
package raft_layer
import (
"context"
"errors"
"fmt"
"github.com/kelvinmwinuka/memstore/src/memberlist_layer"
"log"
"net"
"os"
"path"
"path/filepath"
"time"
"github.com/hashicorp/raft"
raftboltdb "github.com/hashicorp/raft-boltdb"
"github.com/kelvinmwinuka/memstore/src/utils"
)
type RaftOpts struct {
Config utils.Config
Server utils.Server
GetCommand func(command string) (utils.Command, error)
}
type Raft struct {
options RaftOpts
raft *raft.Raft
}
func NewRaft(opts RaftOpts) *Raft {
return &Raft{
options: opts,
}
}
func (r *Raft) RaftInit(ctx context.Context) {
conf := r.options.Config
raftConfig := raft.DefaultConfig()
raftConfig.LocalID = raft.ServerID(conf.ServerID)
raftConfig.SnapshotThreshold = 5
var logStore raft.LogStore
var stableStore raft.StableStore
var snapshotStore raft.SnapshotStore
if conf.InMemory {
logStore = raft.NewInmemStore()
stableStore = raft.NewInmemStore()
snapshotStore = raft.NewInmemSnapshotStore()
} else {
boltdb, err := raftboltdb.NewBoltStore(filepath.Join(conf.DataDir, "logs.db"))
if err != nil {
log.Fatal(err)
}
logStore, err = raft.NewLogCache(512, boltdb)
if err != nil {
log.Fatal(err)
}
stableStore = raft.StableStore(boltdb)
snapshotStore, err = raft.NewFileSnapshotStore(path.Join(conf.DataDir, "snapshots"), 2, os.Stdout)
if err != nil {
log.Fatal(err)
}
}
addr := fmt.Sprintf("%s:%d", conf.BindAddr, conf.RaftBindPort)
advertiseAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
log.Fatal(err)
}
raftTransport, err := raft.NewTCPTransport(
addr,
advertiseAddr,
10,
500*time.Millisecond,
os.Stdout,
)
if err != nil {
log.Fatal(err)
}
// Start raft server
raftServer, err := raft.NewRaft(
raftConfig,
NewFSM(FSMOpts{
Config: r.options.Config,
Server: r.options.Server,
Snapshot: NewFSMSnapshot(SnapshotOpts{
Config: r.options.Config,
}),
GetCommand: r.options.GetCommand,
}),
logStore,
stableStore,
snapshotStore,
raftTransport,
)
if err != nil {
log.Fatalf("Could not start node with error; %s", err)
}
r.raft = raftServer
if conf.BootstrapCluster {
// Bootstrap raft cluster
if err := r.raft.BootstrapCluster(raft.Configuration{
Servers: []raft.Server{
{
Suffrage: raft.Voter,
ID: raft.ServerID(conf.ServerID),
Address: raft.ServerAddress(addr),
},
},
}).Error(); err != nil {
log.Fatal(err)
}
}
}
func (r *Raft) Apply(cmd []byte, timeout time.Duration) raft.ApplyFuture {
return r.raft.Apply(cmd, timeout)
}
func (r *Raft) IsRaftLeader() bool {
return r.raft.State() == raft.Leader
}
func (r *Raft) isRaftFollower() bool {
return r.raft.State() == raft.Follower
}
func (r *Raft) HasJoinedCluster() bool {
isFollower := r.isRaftFollower()
leaderAddr, leaderID := r.raft.LeaderWithID()
hasLeader := leaderAddr != "" && leaderID != ""
return isFollower && hasLeader
}
func (r *Raft) AddVoter(
id raft.ServerID,
address raft.ServerAddress,
prevIndex uint64,
timeout time.Duration,
) error {
if !r.IsRaftLeader() {
return errors.New("not leader, cannot add voter")
}
raftConfig := r.raft.GetConfiguration()
if err := raftConfig.Error(); err != nil {
return errors.New("could not retrieve raft config")
}
for _, s := range raftConfig.Configuration().Servers {
// Check if a server already exists with the current attributes
if s.ID == id && s.Address == address {
return fmt.Errorf("server with id %s and address %s already exists", id, address)
}
}
err := r.raft.AddVoter(id, address, prevIndex, timeout).Error()
if err != nil {
return err
}
return nil
}
func (r *Raft) RemoveServer(meta memberlist_layer.NodeMeta) error {
if !r.IsRaftLeader() {
return errors.New("not leader, could not remove server")
}
if err := r.raft.RemoveServer(meta.ServerID, 0, 0).Error(); err != nil {
return err
}
return nil
}
func (r *Raft) RaftShutdown(ctx context.Context) {
// Transfer leadership if the current node is the leader
if r.IsRaftLeader() {
err := r.raft.LeadershipTransfer().Error()
if err != nil {
log.Fatal(err)
}
fmt.Println("Leadership transfer successful.")
}
}