Implemented multiple server cert/key pairs. Implemented mTLS for verifying client using multiple pem files.

This commit is contained in:
Kelvin Clement Mwinuka
2024-02-03 00:24:59 +08:00
parent c0b3fe36a1
commit f9ae87862c
5 changed files with 93 additions and 38 deletions

View File

@@ -2,10 +2,12 @@ FROM --platform=linux/amd64 alpine:latest
RUN mkdir -p /usr/local/lib/echovault RUN mkdir -p /usr/local/lib/echovault
RUN mkdir -p /opt/echovault/bin RUN mkdir -p /opt/echovault/bin
RUN mkdir -p /etc/ssl/certs/echovault RUN mkdir -p /etc/ssl/certs/echovault/server
RUN mkdir -p /etc/ssl/certs/echovault/client
COPY ./bin/linux/x86_64/server /opt/echovault/bin COPY ./bin/linux/x86_64/server /opt/echovault/bin
COPY ./openssl/server /etc/ssl/certs/echovault COPY ./openssl/server /etc/ssl/certs/echovault/server
COPY ./openssl/client /etc/ssl/certs/echovault/client
WORKDIR /opt/echovault/bin WORKDIR /opt/echovault/bin
@@ -16,8 +18,6 @@ CMD "./server" \
"--raftPort" "${RAFT_PORT}" \ "--raftPort" "${RAFT_PORT}" \
"--serverId" "${SERVER_ID}" \ "--serverId" "${SERVER_ID}" \
"--joinAddr" "${JOIN_ADDR}" \ "--joinAddr" "${JOIN_ADDR}" \
"--key" "${KEY}" \
"--cert" "${CERT}" \
"--pluginDir" "${PLUGIN_DIR}" \ "--pluginDir" "${PLUGIN_DIR}" \
"--dataDir" "${DATA_DIR}" \ "--dataDir" "${DATA_DIR}" \
"--snapshotThreshold" "${SNAPSHOT_THRESHOLD}" \ "--snapshotThreshold" "${SNAPSHOT_THRESHOLD}" \
@@ -31,3 +31,5 @@ CMD "./server" \
"--forwardCommand=${FORWARD_COMMAND}" \ "--forwardCommand=${FORWARD_COMMAND}" \
"--restoreSnapshot=${RESTORE_SNAPSHOT}" \ "--restoreSnapshot=${RESTORE_SNAPSHOT}" \
"--restoreAOF=${RESTORE_AOF}" \ "--restoreAOF=${RESTORE_AOF}" \
"--certKeyPair=${CERT_KEY_PAIR}" \
"--clientCert=${CLIENT_CERT}" \

View File

@@ -39,7 +39,7 @@ func (r *Raft) RaftInit(ctx context.Context) {
raftConfig := raft.DefaultConfig() raftConfig := raft.DefaultConfig()
raftConfig.LocalID = raft.ServerID(conf.ServerID) raftConfig.LocalID = raft.ServerID(conf.ServerID)
raftConfig.SnapshotThreshold = conf.SnapShotThreshold raftConfig.SnapshotThreshold = conf.SnapShotThreshold
raftConfig.SnapshotInterval = time.Duration(conf.SnapshotInterval) * time.Second raftConfig.SnapshotInterval = conf.SnapshotInterval
var logStore raft.LogStore var logStore raft.LogStore
var stableStore raft.StableStore var stableStore raft.StableStore

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/src/memberlist" "github.com/echovault/echovault/src/memberlist"
@@ -67,16 +68,43 @@ func (server *Server) StartTCP(ctx context.Context) {
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
} }
if conf.TLS { if conf.TLS || conf.MTLS {
// TLS // TLS
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
var certificates []tls.Certificate
for _, certKeyPair := range conf.CertKeyPairs {
c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1])
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
certificates = append(certificates, c)
}
clientAuth := tls.NoClientCert
clientCerts := x509.NewCertPool()
if conf.MTLS {
clientAuth = tls.RequireAndVerifyClientCert
for _, c := range conf.ClientCerts {
certFile, err := os.Open(c)
if err != nil {
log.Fatal(err)
}
certBytes, err := io.ReadAll(certFile)
if err != nil {
log.Fatal(err)
}
if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok {
log.Fatal(err)
}
}
}
listener = tls.NewListener(listener, &tls.Config{ listener = tls.NewListener(listener, &tls.Config{
Certificates: []tls.Certificate{cer}, Certificates: certificates,
ClientAuth: clientAuth,
ClientCAs: clientCerts,
}) })
} }
@@ -156,8 +184,8 @@ func (server *Server) Start(ctx context.Context) {
server.LoadModules(ctx) server.LoadModules(ctx)
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) { if conf.TLS && len(conf.CertKeyPairs) <= 0 {
fmt.Println("Must provide key and certificate file paths for TLS mode.") log.Fatal("must provide certificate and key file paths for TLS mode")
return return
} }

View File

@@ -50,7 +50,7 @@ func (engine *Engine) Start(ctx context.Context) {
if engine.options.Config.SnapshotInterval != 0 { if engine.options.Config.SnapshotInterval != 0 {
go func() { go func() {
for { for {
<-time.After(time.Duration(engine.options.Config.SnapshotInterval) * time.Second) <-time.After(engine.options.Config.SnapshotInterval)
if engine.changeCount == engine.options.Config.SnapShotThreshold { if engine.changeCount == engine.options.Config.SnapShotThreshold {
if err := engine.TakeSnapshot(); err != nil { if err := engine.TakeSnapshot(); err != nil {
log.Println(err) log.Println(err)

View File

@@ -7,14 +7,18 @@ import (
"log" "log"
"os" "os"
"path" "path"
"strings"
"time"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Config struct { type Config struct {
TLS bool `json:"tls" yaml:"tls"` TLS bool `json:"tls" yaml:"tls"`
MTLS bool `json:"mtls" yaml:"mtls"`
Key string `json:"key" yaml:"key"` Key string `json:"key" yaml:"key"`
Cert string `json:"cert" yaml:"cert"` CertKeyPairs [][]string `json:"certKeyPairs" yaml:"certKeyPairs"`
ClientCerts []string `json:"clientCerts" yaml:"clientCerts"`
Port uint16 `json:"port" yaml:"port"` Port uint16 `json:"port" yaml:"port"`
PluginDir string `json:"plugins" yaml:"plugins"` PluginDir string `json:"plugins" yaml:"plugins"`
ServerID string `json:"serverId" yaml:"serverId"` ServerID string `json:"serverId" yaml:"serverId"`
@@ -30,15 +34,34 @@ type Config struct {
RequirePass bool `json:"requirePass" yaml:"requirePass"` RequirePass bool `json:"requirePass" yaml:"requirePass"`
Password string `json:"password" yaml:"password"` Password string `json:"password" yaml:"password"`
SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"` SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"`
SnapshotInterval uint `json:"snapshotInterval" yaml:"snapshotInterval"` SnapshotInterval time.Duration `json:"snapshotInterval" yaml:"snapshotInterval"`
RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"` RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"`
RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"` RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"`
} }
func GetConfig() (Config, error) { func GetConfig() (Config, error) {
var certKeyPairs [][]string
var clientCerts []string
flag.Func("certKeyPair",
"A pair of file paths representing the signed certificate and it's corresponding key separated by a comma.",
func(s string) error {
pair := strings.Split(strings.TrimSpace(s), ",")
if len(pair) != 2 {
return errors.New("certKeyPair must be 2 comma separated strings in the format")
}
certKeyPairs = append(certKeyPairs, pair)
return nil
})
flag.Func("clientCert", "Certificate file used to verify the client. ", func(s string) error {
clientCerts = append(clientCerts, s)
return nil
})
tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false") tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false")
mtls := flag.Bool("mtls", false, "Use mTLS to verify the client.")
key := flag.String("key", "", "The private key file path.") key := flag.String("key", "", "The private key file path.")
cert := flag.String("cert", "", "The signed certificate file path.")
port := flag.Int("port", 7480, "Port to use. Default is 7480") port := flag.Int("port", 7480, "Port to use. Default is 7480")
pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.") pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.")
serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.") serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.")
@@ -51,7 +74,7 @@ func GetConfig() (Config, error) {
bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.") bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.")
aclConfig := flag.String("aclConfig", "", "ACL config file path.") aclConfig := flag.String("aclConfig", "", "ACL config file path.")
snapshotThreshold := flag.Uint64("snapshotThreshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.") snapshotThreshold := flag.Uint64("snapshotThreshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.")
snapshotInterval := flag.Uint("snapshotInterval", 60*5, "The time interval between snapshots (in seconds). Default is 5 minutes.") snapshotInterval := flag.Duration("snapshotInterval", 5*time.Minute, "The time interval between snapshots (in seconds). Default is 5 minutes.")
restoreSnapshot := flag.Bool("restoreSnapshot", false, "This flag prompts the server to restore state from snapshot when set to true. Only works in standalone mode. Higher priority than restoreAOF.") restoreSnapshot := flag.Bool("restoreSnapshot", false, "This flag prompts the server to restore state from snapshot when set to true. Only works in standalone mode. Higher priority than restoreAOF.")
restoreAOF := flag.Bool("restoreAOF", false, "This flag prompts the server to restore state from append-only logs. Only works in standalone mode. Lower priority than restoreSnapshot.") restoreAOF := flag.Bool("restoreAOF", false, "This flag prompts the server to restore state from append-only logs. Only works in standalone mode. Lower priority than restoreSnapshot.")
forwardCommand := flag.Bool( forwardCommand := flag.Bool(
@@ -79,9 +102,11 @@ It is a plain text value by default but you can provide a SHA256 hash by adding
flag.Parse() flag.Parse()
conf := Config{ conf := Config{
CertKeyPairs: certKeyPairs,
ClientCerts: clientCerts,
TLS: *tls, TLS: *tls,
MTLS: *mtls,
Key: *key, Key: *key,
Cert: *cert,
PluginDir: *pluginDir, PluginDir: *pluginDir,
Port: uint16(*port), Port: uint16(*port),
ServerID: *serverId, ServerID: *serverId,