mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-16 13:00:40 +08:00
Implemented multiple server cert/key pairs. Implemented mTLS for verifying client using multiple pem files.
This commit is contained in:
12
Dockerfile
12
Dockerfile
@@ -2,22 +2,22 @@ FROM --platform=linux/amd64 alpine:latest
|
|||||||
|
|
||||||
RUN mkdir -p /usr/local/lib/echovault
|
RUN mkdir -p /usr/local/lib/echovault
|
||||||
RUN mkdir -p /opt/echovault/bin
|
RUN mkdir -p /opt/echovault/bin
|
||||||
RUN mkdir -p /etc/ssl/certs/echovault
|
RUN mkdir -p /etc/ssl/certs/echovault/server
|
||||||
|
RUN mkdir -p /etc/ssl/certs/echovault/client
|
||||||
|
|
||||||
COPY ./bin/linux/x86_64/server /opt/echovault/bin
|
COPY ./bin/linux/x86_64/server /opt/echovault/bin
|
||||||
COPY ./openssl/server /etc/ssl/certs/echovault
|
COPY ./openssl/server /etc/ssl/certs/echovault/server
|
||||||
|
COPY ./openssl/client /etc/ssl/certs/echovault/client
|
||||||
|
|
||||||
WORKDIR /opt/echovault/bin
|
WORKDIR /opt/echovault/bin
|
||||||
|
|
||||||
CMD "./server" \
|
CMD "./server" \
|
||||||
"--bindAddr" "${BIND_ADDR}" \
|
"--bindAddr" "${BIND_ADDR}" \
|
||||||
"--port" "${PORT}" \
|
"--port" "${PORT}" \
|
||||||
"--mlPort" "${ML_PORT}" \
|
"--mlPort" "${ML_PORT}" \
|
||||||
"--raftPort" "${RAFT_PORT}" \
|
"--raftPort" "${RAFT_PORT}" \
|
||||||
"--serverId" "${SERVER_ID}" \
|
"--serverId" "${SERVER_ID}" \
|
||||||
"--joinAddr" "${JOIN_ADDR}" \
|
"--joinAddr" "${JOIN_ADDR}" \
|
||||||
"--key" "${KEY}" \
|
|
||||||
"--cert" "${CERT}" \
|
|
||||||
"--pluginDir" "${PLUGIN_DIR}" \
|
"--pluginDir" "${PLUGIN_DIR}" \
|
||||||
"--dataDir" "${DATA_DIR}" \
|
"--dataDir" "${DATA_DIR}" \
|
||||||
"--snapshotThreshold" "${SNAPSHOT_THRESHOLD}" \
|
"--snapshotThreshold" "${SNAPSHOT_THRESHOLD}" \
|
||||||
@@ -31,3 +31,5 @@ CMD "./server" \
|
|||||||
"--forwardCommand=${FORWARD_COMMAND}" \
|
"--forwardCommand=${FORWARD_COMMAND}" \
|
||||||
"--restoreSnapshot=${RESTORE_SNAPSHOT}" \
|
"--restoreSnapshot=${RESTORE_SNAPSHOT}" \
|
||||||
"--restoreAOF=${RESTORE_AOF}" \
|
"--restoreAOF=${RESTORE_AOF}" \
|
||||||
|
"--certKeyPair=${CERT_KEY_PAIR}" \
|
||||||
|
"--clientCert=${CLIENT_CERT}" \
|
||||||
|
@@ -39,7 +39,7 @@ func (r *Raft) RaftInit(ctx context.Context) {
|
|||||||
raftConfig := raft.DefaultConfig()
|
raftConfig := raft.DefaultConfig()
|
||||||
raftConfig.LocalID = raft.ServerID(conf.ServerID)
|
raftConfig.LocalID = raft.ServerID(conf.ServerID)
|
||||||
raftConfig.SnapshotThreshold = conf.SnapShotThreshold
|
raftConfig.SnapshotThreshold = conf.SnapShotThreshold
|
||||||
raftConfig.SnapshotInterval = time.Duration(conf.SnapshotInterval) * time.Second
|
raftConfig.SnapshotInterval = conf.SnapshotInterval
|
||||||
|
|
||||||
var logStore raft.LogStore
|
var logStore raft.LogStore
|
||||||
var stableStore raft.StableStore
|
var stableStore raft.StableStore
|
||||||
|
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/src/memberlist"
|
"github.com/echovault/echovault/src/memberlist"
|
||||||
@@ -67,16 +68,43 @@ func (server *Server) StartTCP(ctx context.Context) {
|
|||||||
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.TLS {
|
if conf.TLS || conf.MTLS {
|
||||||
// TLS
|
// TLS
|
||||||
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||||
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
|
||||||
if err != nil {
|
var certificates []tls.Certificate
|
||||||
log.Fatal(err)
|
for _, certKeyPair := range conf.CertKeyPairs {
|
||||||
|
c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
certificates = append(certificates, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAuth := tls.NoClientCert
|
||||||
|
clientCerts := x509.NewCertPool()
|
||||||
|
|
||||||
|
if conf.MTLS {
|
||||||
|
clientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
for _, c := range conf.ClientCerts {
|
||||||
|
certFile, err := os.Open(c)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
certBytes, err := io.ReadAll(certFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
listener = tls.NewListener(listener, &tls.Config{
|
listener = tls.NewListener(listener, &tls.Config{
|
||||||
Certificates: []tls.Certificate{cer},
|
Certificates: certificates,
|
||||||
|
ClientAuth: clientAuth,
|
||||||
|
ClientCAs: clientCerts,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,8 +184,8 @@ func (server *Server) Start(ctx context.Context) {
|
|||||||
|
|
||||||
server.LoadModules(ctx)
|
server.LoadModules(ctx)
|
||||||
|
|
||||||
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
if conf.TLS && len(conf.CertKeyPairs) <= 0 {
|
||||||
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
log.Fatal("must provide certificate and key file paths for TLS mode")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -50,7 +50,7 @@ func (engine *Engine) Start(ctx context.Context) {
|
|||||||
if engine.options.Config.SnapshotInterval != 0 {
|
if engine.options.Config.SnapshotInterval != 0 {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
<-time.After(time.Duration(engine.options.Config.SnapshotInterval) * time.Second)
|
<-time.After(engine.options.Config.SnapshotInterval)
|
||||||
if engine.changeCount == engine.options.Config.SnapShotThreshold {
|
if engine.changeCount == engine.options.Config.SnapShotThreshold {
|
||||||
if err := engine.TakeSnapshot(); err != nil {
|
if err := engine.TakeSnapshot(); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
|
@@ -7,38 +7,61 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
TLS bool `json:"tls" yaml:"tls"`
|
TLS bool `json:"tls" yaml:"tls"`
|
||||||
Key string `json:"key" yaml:"key"`
|
MTLS bool `json:"mtls" yaml:"mtls"`
|
||||||
Cert string `json:"cert" yaml:"cert"`
|
Key string `json:"key" yaml:"key"`
|
||||||
Port uint16 `json:"port" yaml:"port"`
|
CertKeyPairs [][]string `json:"certKeyPairs" yaml:"certKeyPairs"`
|
||||||
PluginDir string `json:"plugins" yaml:"plugins"`
|
ClientCerts []string `json:"clientCerts" yaml:"clientCerts"`
|
||||||
ServerID string `json:"serverId" yaml:"serverId"`
|
Port uint16 `json:"port" yaml:"port"`
|
||||||
JoinAddr string `json:"joinAddr" yaml:"joinAddr"`
|
PluginDir string `json:"plugins" yaml:"plugins"`
|
||||||
BindAddr string `json:"bindAddr" yaml:"bindAddr"`
|
ServerID string `json:"serverId" yaml:"serverId"`
|
||||||
RaftBindPort uint16 `json:"raftPort" yaml:"raftPort"`
|
JoinAddr string `json:"joinAddr" yaml:"joinAddr"`
|
||||||
MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"`
|
BindAddr string `json:"bindAddr" yaml:"bindAddr"`
|
||||||
InMemory bool `json:"inMemory" yaml:"inMemory"`
|
RaftBindPort uint16 `json:"raftPort" yaml:"raftPort"`
|
||||||
DataDir string `json:"dataDir" yaml:"dataDir"`
|
MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"`
|
||||||
BootstrapCluster bool `json:"BootstrapCluster" yaml:"bootstrapCluster"`
|
InMemory bool `json:"inMemory" yaml:"inMemory"`
|
||||||
AclConfig string `json:"AclConfig" yaml:"AclConfig"`
|
DataDir string `json:"dataDir" yaml:"dataDir"`
|
||||||
ForwardCommand bool `json:"forwardCommand" yaml:"forwardCommand"`
|
BootstrapCluster bool `json:"BootstrapCluster" yaml:"bootstrapCluster"`
|
||||||
RequirePass bool `json:"requirePass" yaml:"requirePass"`
|
AclConfig string `json:"AclConfig" yaml:"AclConfig"`
|
||||||
Password string `json:"password" yaml:"password"`
|
ForwardCommand bool `json:"forwardCommand" yaml:"forwardCommand"`
|
||||||
SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"`
|
RequirePass bool `json:"requirePass" yaml:"requirePass"`
|
||||||
SnapshotInterval uint `json:"snapshotInterval" yaml:"snapshotInterval"`
|
Password string `json:"password" yaml:"password"`
|
||||||
RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"`
|
SnapShotThreshold uint64 `json:"snapshotThreshold" yaml:"snapshotThreshold"`
|
||||||
RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"`
|
SnapshotInterval time.Duration `json:"snapshotInterval" yaml:"snapshotInterval"`
|
||||||
|
RestoreSnapshot bool `json:"restoreSnapshot" yaml:"restoreSnapshot"`
|
||||||
|
RestoreAOF bool `json:"restoreAOF" yaml:"restoreAOF"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetConfig() (Config, error) {
|
func GetConfig() (Config, error) {
|
||||||
|
var certKeyPairs [][]string
|
||||||
|
var clientCerts []string
|
||||||
|
|
||||||
|
flag.Func("certKeyPair",
|
||||||
|
"A pair of file paths representing the signed certificate and it's corresponding key separated by a comma.",
|
||||||
|
func(s string) error {
|
||||||
|
pair := strings.Split(strings.TrimSpace(s), ",")
|
||||||
|
if len(pair) != 2 {
|
||||||
|
return errors.New("certKeyPair must be 2 comma separated strings in the format")
|
||||||
|
}
|
||||||
|
certKeyPairs = append(certKeyPairs, pair)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
flag.Func("clientCert", "Certificate file used to verify the client. ", func(s string) error {
|
||||||
|
clientCerts = append(clientCerts, s)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false")
|
tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false")
|
||||||
|
mtls := flag.Bool("mtls", false, "Use mTLS to verify the client.")
|
||||||
key := flag.String("key", "", "The private key file path.")
|
key := flag.String("key", "", "The private key file path.")
|
||||||
cert := flag.String("cert", "", "The signed certificate file path.")
|
|
||||||
port := flag.Int("port", 7480, "Port to use. Default is 7480")
|
port := flag.Int("port", 7480, "Port to use. Default is 7480")
|
||||||
pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.")
|
pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.")
|
||||||
serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.")
|
serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.")
|
||||||
@@ -51,7 +74,7 @@ func GetConfig() (Config, error) {
|
|||||||
bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.")
|
bootstrapCluster := flag.Bool("bootstrapCluster", false, "Whether this instance should bootstrap a new cluster.")
|
||||||
aclConfig := flag.String("aclConfig", "", "ACL config file path.")
|
aclConfig := flag.String("aclConfig", "", "ACL config file path.")
|
||||||
snapshotThreshold := flag.Uint64("snapshotThreshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.")
|
snapshotThreshold := flag.Uint64("snapshotThreshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.")
|
||||||
snapshotInterval := flag.Uint("snapshotInterval", 60*5, "The time interval between snapshots (in seconds). Default is 5 minutes.")
|
snapshotInterval := flag.Duration("snapshotInterval", 5*time.Minute, "The time interval between snapshots (in seconds). Default is 5 minutes.")
|
||||||
restoreSnapshot := flag.Bool("restoreSnapshot", false, "This flag prompts the server to restore state from snapshot when set to true. Only works in standalone mode. Higher priority than restoreAOF.")
|
restoreSnapshot := flag.Bool("restoreSnapshot", false, "This flag prompts the server to restore state from snapshot when set to true. Only works in standalone mode. Higher priority than restoreAOF.")
|
||||||
restoreAOF := flag.Bool("restoreAOF", false, "This flag prompts the server to restore state from append-only logs. Only works in standalone mode. Lower priority than restoreSnapshot.")
|
restoreAOF := flag.Bool("restoreAOF", false, "This flag prompts the server to restore state from append-only logs. Only works in standalone mode. Lower priority than restoreSnapshot.")
|
||||||
forwardCommand := flag.Bool(
|
forwardCommand := flag.Bool(
|
||||||
@@ -79,9 +102,11 @@ It is a plain text value by default but you can provide a SHA256 hash by adding
|
|||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
conf := Config{
|
conf := Config{
|
||||||
|
CertKeyPairs: certKeyPairs,
|
||||||
|
ClientCerts: clientCerts,
|
||||||
TLS: *tls,
|
TLS: *tls,
|
||||||
|
MTLS: *mtls,
|
||||||
Key: *key,
|
Key: *key,
|
||||||
Cert: *cert,
|
|
||||||
PluginDir: *pluginDir,
|
PluginDir: *pluginDir,
|
||||||
Port: uint16(*port),
|
Port: uint16(*port),
|
||||||
ServerID: *serverId,
|
ServerID: *serverId,
|
||||||
|
Reference in New Issue
Block a user