diff --git a/echovault/modules.go b/echovault/modules.go index 0596909..7e9f080 100644 --- a/echovault/modules.go +++ b/echovault/modules.go @@ -21,6 +21,7 @@ import ( "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/constants" + "github.com/echovault/echovault/internal/eviction" "io" "net" "strings" @@ -67,14 +68,43 @@ func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, defer server.connInfo.mut.RUnlock() return server.connInfo.tcpClients[conn] }, - SetConnectionInfo: func(conn *net.Conn, protocol int, clientname string) { + SetConnectionInfo: func(conn *net.Conn, clientname string, protocol int, database int) { server.connInfo.mut.Lock() defer server.connInfo.mut.Unlock() + info := server.connInfo.tcpClients[conn] + + // Set protocol. info.Protocol = protocol + + // Set connection name. if clientname != "" { info.Name = clientname } + + // If the database index does not exist, create the new database. + server.storeLock.Lock() + if server.store[database] == nil { + // Database does not exist. + server.store[database] = make(map[string]internal.KeyData) + // Create volatile key tracker for the database. + server.keysWithExpiry.rwMutex.Lock() + server.keysWithExpiry.keys[database] = make([]string, 0) + server.keysWithExpiry.rwMutex.Unlock() + // Create LFU cache for the database. + server.lfuCache.mutex.Lock() + server.lfuCache.cache[database] = eviction.NewCacheLFU() + server.lfuCache.mutex.Unlock() + // Create LRU cache for the database. + server.lruCache.mutex.Lock() + server.lruCache.cache[database] = eviction.NewCacheLRU() + server.lruCache.mutex.Unlock() + } + server.storeLock.Unlock() + + // Set database index for the current connection. + info.Database = database + server.connInfo.tcpClients[conn] = info }, GetServerInfo: func() internal.ServerInfo { diff --git a/internal/modules/connection/commands.go b/internal/modules/connection/commands.go index b3079a5..f370b28 100644 --- a/internal/modules/connection/commands.go +++ b/internal/modules/connection/commands.go @@ -119,14 +119,31 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) { } // Set the connection details. - params.SetConnectionInfo(params.Connection, options.protocol, options.clientname) + connectionInfo := params.GetConnectionInfo(params.Connection) + params.SetConnectionInfo(params.Connection, options.clientname, options.protocol, connectionInfo.Database) // Get the new connection details and server info to return to the client. serverInfo := params.GetServerInfo() - connectionInfo := params.GetConnectionInfo(params.Connection) + connectionInfo = params.GetConnectionInfo(params.Connection) return buildHelloResponse(serverInfo, connectionInfo), nil } +func handleSelect(params internal.HandlerFuncParams) ([]byte, error) { + if len(params.Command) != 2 { + return nil, errors.New(constants.WrongArgsResponse) + } + + database, err := strconv.Atoi(params.Command[1]) + if err != nil { + return nil, err + } + + connectionInfo := params.GetConnectionInfo(params.Connection) + params.SetConnectionInfo(params.Connection, connectionInfo.Name, connectionInfo.Protocol, database) + + return []byte(constants.OkResponse), nil +} + func Commands() []internal.Command { return []internal.Command{ { @@ -193,5 +210,20 @@ Otherwise, the server will return "PONG".`, }, HandlerFunc: handleHello, }, + { + Command: "select", + Module: constants.ConnectionModule, + Categories: []string{constants.FastCategory, constants.ConnectionCategory}, + Description: `(SELECT index) Change the logical database that the current connection is operating from.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: make([]string, 0), + WriteKeys: make([]string, 0), + }, nil + }, + HandlerFunc: handleSelect, + }, } } diff --git a/internal/types.go b/internal/types.go index 4b8b2de..403f5e1 100644 --- a/internal/types.go +++ b/internal/types.go @@ -133,7 +133,7 @@ type HandlerFuncParams struct { // ListModules returns the list of modules loaded in the EchoVault instance. ListModules func() []string // SetConnectionInfo sets the connection's protocol and clientname. - SetConnectionInfo func(conn *net.Conn, protocol int, clientname string) + SetConnectionInfo func(conn *net.Conn, clientname string, protocol int, database int) // GetConnectionInfo returns information about the current connection. GetConnectionInfo func(conn *net.Conn) ConnectionInfo // GetServerInfo returns information about the server when requested by commands such as HELLO.