diff --git a/Dockerfile b/Dockerfile index db11846..70f22a2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,22 +2,22 @@ FROM --platform=linux/amd64 alpine:latest RUN mkdir -p /usr/local/lib/echovault RUN mkdir -p /opt/echovault/bin -RUN mkdir -p /etc/ssl/certs/echovault +RUN mkdir -p /etc/ssl/certs/echovault/server +RUN mkdir -p /etc/ssl/certs/echovault/client COPY ./bin/linux/x86_64/server /opt/echovault/bin -COPY ./openssl/server /etc/ssl/certs/echovault +COPY ./openssl/server /etc/ssl/certs/echovault/server +COPY ./openssl/client /etc/ssl/certs/echovault/client WORKDIR /opt/echovault/bin -CMD "./server" \ +CMD "./server" \ "--bindAddr" "${BIND_ADDR}" \ "--port" "${PORT}" \ "--mlPort" "${ML_PORT}" \ "--raftPort" "${RAFT_PORT}" \ "--serverId" "${SERVER_ID}" \ "--joinAddr" "${JOIN_ADDR}" \ - "--key" "${KEY}" \ - "--cert" "${CERT}" \ "--pluginDir" "${PLUGIN_DIR}" \ "--dataDir" "${DATA_DIR}" \ "--snapshotThreshold" "${SNAPSHOT_THRESHOLD}" \ @@ -31,3 +31,5 @@ CMD "./server" \ "--forwardCommand=${FORWARD_COMMAND}" \ "--restoreSnapshot=${RESTORE_SNAPSHOT}" \ "--restoreAOF=${RESTORE_AOF}" \ + "--certKeyPair=${CERT_KEY_PAIR}" \ + "--clientCert=${CLIENT_CERT}" \ diff --git a/src/raft/raft.go b/src/raft/raft.go index 544e55a..e2c0841 100644 --- a/src/raft/raft.go +++ b/src/raft/raft.go @@ -39,7 +39,7 @@ func (r *Raft) RaftInit(ctx context.Context) { raftConfig := raft.DefaultConfig() raftConfig.LocalID = raft.ServerID(conf.ServerID) raftConfig.SnapshotThreshold = conf.SnapShotThreshold - raftConfig.SnapshotInterval = time.Duration(conf.SnapshotInterval) * time.Second + raftConfig.SnapshotInterval = conf.SnapshotInterval var logStore raft.LogStore var stableStore raft.StableStore diff --git a/src/server/server.go b/src/server/server.go index 7bf7bfe..0202f2d 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "github.com/echovault/echovault/src/memberlist" @@ -67,16 +68,43 @@ func (server *Server) StartTCP(ctx context.Context) { fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) } - if conf.TLS { + if conf.TLS || conf.MTLS { // TLS fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) - cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key) - if err != nil { - log.Fatal(err) + + var certificates []tls.Certificate + for _, certKeyPair := range conf.CertKeyPairs { + c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1]) + if err != nil { + log.Fatal(err) + } + certificates = append(certificates, c) + } + + clientAuth := tls.NoClientCert + clientCerts := x509.NewCertPool() + + if conf.MTLS { + clientAuth = tls.RequireAndVerifyClientCert + for _, c := range conf.ClientCerts { + certFile, err := os.Open(c) + if err != nil { + log.Fatal(err) + } + certBytes, err := io.ReadAll(certFile) + if err != nil { + log.Fatal(err) + } + if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok { + log.Fatal(err) + } + } } listener = tls.NewListener(listener, &tls.Config{ - Certificates: []tls.Certificate{cer}, + Certificates: certificates, + ClientAuth: clientAuth, + ClientCAs: clientCerts, }) } @@ -156,8 +184,8 @@ func (server *Server) Start(ctx context.Context) { server.LoadModules(ctx) - if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) { - fmt.Println("Must provide key and certificate file paths for TLS mode.") + if conf.TLS && len(conf.CertKeyPairs) <= 0 { + log.Fatal("must provide certificate and key file paths for TLS mode") return } diff --git a/src/server/snapshot/snapshot.go b/src/server/snapshot/snapshot.go index 798cf74..f1fb680 100644 --- a/src/server/snapshot/snapshot.go +++ b/src/server/snapshot/snapshot.go @@ -50,7 +50,7 @@ func (engine *Engine) Start(ctx context.Context) { if engine.options.Config.SnapshotInterval != 0 { go func() { for { - <-time.After(time.Duration(engine.options.Config.SnapshotInterval) * time.Second) + <-time.After(engine.options.Config.SnapshotInterval) if engine.changeCount == engine.options.Config.SnapShotThreshold { if err := engine.TakeSnapshot(); err != nil { log.Println(err) diff --git a/src/utils/config.go b/src/utils/config.go index f994b66..1207132 100644 --- a/src/utils/config.go +++ b/src/utils/config.go @@ -7,38 +7,61 @@ import ( "log" "os" "path" + "strings" + "time" "gopkg.in/yaml.v3" ) type Config struct { - TLS bool `json:"tls" yaml:"tls"` - Key string `json:"key" yaml:"key"` - Cert string `json:"cert" yaml:"cert"` - Port uint16 `json:"port" yaml:"port"` - PluginDir string `json:"plugins" yaml:"plugins"` - ServerID string `json:"serverId" yaml:"serverId"` - JoinAddr string `json:"joinAddr" yaml:"joinAddr"` - BindAddr string `json:"bindAddr" yaml:"bindAddr"` - RaftBindPort uint16 `json:"raftPort" yaml:"raftPort"` - MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"` - InMemory bool `json:"inMemory" yaml:"inMemory"` - DataDir string `json:"dataDir" yaml:"dataDir"` - BootstrapCluster bool `json:"BootstrapCluster" yaml:"bootstrapCluster"` - AclConfig string `json:"AclConfig" yaml:"AclConfig"` - ForwardCommand bool `json:"forwardCommand" yaml:"forwardCommand"` - RequirePass bool `json:"requirePass" yaml:"requirePass"` - Password string `json:"password" yaml:"password"` - SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"` - SnapshotInterval uint `json:"snapshotInterval" yaml:"snapshotInterval"` - RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"` - RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"` + TLS bool `json:"tls" yaml:"tls"` + MTLS bool `json:"mtls" yaml:"mtls"` + Key string `json:"key" yaml:"key"` + CertKeyPairs [][]string `json:"certKeyPairs" yaml:"certKeyPairs"` + ClientCerts []string `json:"clientCerts" yaml:"clientCerts"` + Port uint16 `json:"port" yaml:"port"` + PluginDir string `json:"plugins" yaml:"plugins"` + ServerID string `json:"serverId" yaml:"serverId"` + JoinAddr string `json:"joinAddr" yaml:"joinAddr"` + BindAddr string `json:"bindAddr" yaml:"bindAddr"` + RaftBindPort uint16 `json:"raftPort" yaml:"raftPort"` + MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"` + InMemory bool `json:"inMemory" yaml:"inMemory"` + DataDir string `json:"dataDir" yaml:"dataDir"` + BootstrapCluster bool `json:"BootstrapCluster" yaml:"bootstrapCluster"` + AclConfig string `json:"AclConfig" yaml:"AclConfig"` + ForwardCommand bool `json:"forwardCommand" yaml:"forwardCommand"` + RequirePass bool `json:"requirePass" yaml:"requirePass"` + Password string `json:"password" yaml:"password"` + SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"` + SnapshotInterval time.Duration `json:"snapshotInterval" yaml:"snapshotInterval"` + RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"` + RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"` } func GetConfig() (Config, error) { + var certKeyPairs [][]string + var clientCerts []string + + flag.Func("certKeyPair", + "A pair of file paths representing the signed certificate and it's corresponding key separated by a comma.", + func(s string) error { + pair := strings.Split(strings.TrimSpace(s), ",") + if len(pair) != 2 { + return errors.New("certKeyPair must be 2 comma separated strings in the format") + } + certKeyPairs = append(certKeyPairs, pair) + return nil + }) + + flag.Func("clientCert", "Certificate file used to verify the client. ", func(s string) error { + clientCerts = append(clientCerts, s) + return nil + }) + tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false") + mtls := flag.Bool("mtls", false, "Use mTLS to verify the client.") key := flag.String("key", "", "The private key file path.") - cert := flag.String("cert", "", "The signed certificate file path.") port := flag.Int("port", 7480, "Port to use. Default is 7480") pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.") serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.") @@ -51,7 +74,7 @@ func GetConfig() (Config, error) { bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.") aclConfig := flag.String("aclConfig", "", "ACL config file path.") snapshotThreshold := flag.Uint64("snapshotThreshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.") - snapshotInterval := flag.Uint("snapshotInterval", 60*5, "The time interval between snapshots (in seconds). Default is 5 minutes.") + snapshotInterval := flag.Duration("snapshotInterval", 5*time.Minute, "The time interval between snapshots (in seconds). Default is 5 minutes.") restoreSnapshot := flag.Bool("restoreSnapshot", false, "This flag prompts the server to restore state from snapshot when set to true. Only works in standalone mode. Higher priority than restoreAOF.") restoreAOF := flag.Bool("restoreAOF", false, "This flag prompts the server to restore state from append-only logs. Only works in standalone mode. Lower priority than restoreSnapshot.") forwardCommand := flag.Bool( @@ -79,9 +102,11 @@ It is a plain text value by default but you can provide a SHA256 hash by adding flag.Parse() conf := Config{ + CertKeyPairs: certKeyPairs, + ClientCerts: clientCerts, TLS: *tls, + MTLS: *mtls, Key: *key, - Cert: *cert, PluginDir: *pluginDir, Port: uint16(*port), ServerID: *serverId,