diff --git a/README.md b/README.md index 677181a..d4bdbe5 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,6 @@ `Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang. -Please be advised, NEVER think about using this in production environment. - Gods implemented most features of redis, including 5 data structures, ttl, publish/subscribe, geo and AOF persistence. Godis can run as a server side cluster which is transparent to client. You can connect to any node in the cluster to access all data in the cluster: diff --git a/cluster/client_pool.go b/cluster/client_pool.go index 5e3423f..f99a538 100644 --- a/cluster/client_pool.go +++ b/cluster/client_pool.go @@ -3,6 +3,8 @@ package cluster import ( "context" "errors" + "github.com/hdt3213/godis/config" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/client" "github.com/jolestar/go-commons-pool/v2" ) @@ -17,6 +19,10 @@ func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, return nil, err } c.Start() + // all peers of cluster should use the same password + if config.Properties.RequirePass != "" { + c.Send(utils.ToBytesList("AUTH", config.Properties.RequirePass)) + } return pool.NewPooledObject(c), nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index 31fee2a..b990706 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -79,6 +79,13 @@ func (cluster *Cluster) Close() { var router = MakeRouter() +func isAuthenticated(c redis.Connection) bool { + if config.Properties.RequirePass == "" { + return true + } + return c.GetPassword() == config.Properties.RequirePass +} + func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { defer func() { if err := recover(); err != nil { @@ -86,8 +93,13 @@ func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Re result = &reply.UnknownErrReply{} } }() - cmd := strings.ToLower(string(args[0])) + if cmd == "auth" { + return db.Auth(cluster.db, c, args[1:]) + } + if !isAuthenticated(c) { + return reply.MakeErrReply("NOAUTH Authentication required") + } cmdFunc, ok := router[cmd] if !ok { return reply.MakeErrReply("ERR unknown command '" + cmd + "', or not supported in cluster mode") diff --git a/cluster/com_test.go b/cluster/com_test.go index 3759c6b..df4daed 100644 --- a/cluster/com_test.go +++ b/cluster/com_test.go @@ -1,6 +1,9 @@ package cluster import ( + "github.com/hdt3213/godis/config" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/reply/asserts" "testing" ) @@ -16,6 +19,21 @@ func TestExec(t *testing.T) { } } +func TestAuth(t *testing.T) { + passwd := utils.RandString(10) + config.Properties.RequirePass = passwd + defer func() { + config.Properties.RequirePass = "" + }() + conn := &connection.FakeConn{} + ret := testCluster.Exec(conn, toArgs("GET", "a")) + asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") + ret = testCluster.Exec(conn, toArgs("AUTH", passwd)) + asserts.AssertStatusReply(t, ret, "OK") + ret = testCluster.Exec(conn, toArgs("GET", "a")) + asserts.AssertNotError(t, ret) +} + func TestRelay(t *testing.T) { testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) key := RandString(4) diff --git a/config/config.go b/config/config.go index 64a7e65..03fab01 100644 --- a/config/config.go +++ b/config/config.go @@ -11,13 +11,15 @@ import ( ) type PropertyHolder struct { - Bind string `cfg:"bind"` - Port int `cfg:"port"` - AppendOnly bool `cfg:"appendOnly"` - AppendFilename string `cfg:"appendFilename"` - MaxClients int `cfg:"maxclients"` - Peers []string `cfg:"peers"` - Self string `cfg:"self"` + Bind string `cfg:"bind"` + Port int `cfg:"port"` + AppendOnly bool `cfg:"appendOnly"` + AppendFilename string `cfg:"appendFilename"` + MaxClients int `cfg:"maxclients"` + RequirePass string `cfg:"requirepass"` + + Peers []string `cfg:"peers"` + Self string `cfg:"self"` } var Properties *PropertyHolder diff --git a/db/db.go b/db/db.go index 54de57e..efbd349 100644 --- a/db/db.go +++ b/db/db.go @@ -112,7 +112,12 @@ func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { }() cmd := strings.ToLower(string(args[0])) - + if cmd == "auth" { + return Auth(db, c, args[1:]) + } + if !isAuthenticated(c) { + return reply.MakeErrReply("NOAUTH Authentication required") + } // special commands if cmd == "subscribe" { if len(args) < 2 { diff --git a/db/server.go b/db/server.go index 3a981b2..2630c61 100644 --- a/db/server.go +++ b/db/server.go @@ -1,6 +1,7 @@ package db import ( + "github.com/hdt3213/godis/config" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/redis/reply" ) @@ -15,3 +16,26 @@ func Ping(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") } } + +// Auth validate client's password +func Auth(db *DB, c redis.Connection, args [][]byte) redis.Reply { + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'auth' command") + } + if config.Properties.RequirePass == "" { + return reply.MakeErrReply("ERR Client sent AUTH, but no password is set") + } + passwd := string(args[0]) + c.SetPassword(passwd) + if config.Properties.RequirePass != passwd { + return reply.MakeErrReply("ERR invalid password") + } + return &reply.OkReply{} +} + +func isAuthenticated(c redis.Connection) bool { + if config.Properties.RequirePass == "" { + return true + } + return c.GetPassword() == config.Properties.RequirePass +} \ No newline at end of file diff --git a/db/server_test.go b/db/server_test.go index d06cb57..5bdfcf2 100644 --- a/db/server_test.go +++ b/db/server_test.go @@ -1,7 +1,9 @@ package db import ( + "github.com/hdt3213/godis/config" "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/reply/asserts" "testing" ) @@ -15,3 +17,22 @@ func TestPing(t *testing.T) { actual = Ping(testDB, utils.ToBytesList(val, val)) asserts.AssertErrReply(t, actual, "ERR wrong number of arguments for 'ping' command") } + +func TestAuth(t *testing.T) { + passwd := utils.RandString(10) + c := &connection.FakeConn{} + ret := Auth(testDB, c, utils.ToBytesList()) + asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command") + ret = Auth(testDB, c, utils.ToBytesList(passwd)) + asserts.AssertErrReply(t, ret, "ERR Client sent AUTH, but no password is set") + + config.Properties.RequirePass = passwd + defer func() { + config.Properties.RequirePass = "" + }() + ret = Auth(testDB, c, utils.ToBytesList(passwd+passwd)) + asserts.AssertErrReply(t, ret, "ERR invalid password") + ret = Auth(testDB, c, utils.ToBytesList(passwd)) + asserts.AssertStatusReply(t, ret, "OK") + +} \ No newline at end of file diff --git a/interface/redis/client.go b/interface/redis/client.go index d810a95..6441af9 100644 --- a/interface/redis/client.go +++ b/interface/redis/client.go @@ -2,7 +2,8 @@ package redis type Connection interface { Write([]byte) error - + SetPassword(string) + GetPassword() string // client should keep its subscribing channels Subscribe(channel string) UnSubscribe(channel string) diff --git a/redis/connection/conn.go b/redis/connection/conn.go index b6d9b8f..77a9951 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -20,6 +20,9 @@ type Connection struct { // subscribing channels subs map[string]bool + + // password may be changed by CONFIG command during runtime, so store the password + password string } // RemoteAddr returns the remote network address @@ -97,6 +100,14 @@ func (c *Connection) GetChannels() []string { return channels } +func (c *Connection) SetPassword(password string) { + c.password = password +} + +func (c *Connection) GetPassword() string { + return c.password +} + type FakeConn struct { Connection buf bytes.Buffer @@ -113,4 +124,4 @@ func (c *FakeConn) Clean() { func (c *FakeConn) Bytes() []byte { return c.buf.Bytes() -} \ No newline at end of file +}