Added ability to set the connection's database in SetConnectionInfo function. Implemented SELECT command to allow TCP connections to select a different database.

This commit is contained in:
Kelvin Mwinuka
2024-06-24 09:06:07 +08:00
parent 21aabda04d
commit dc9b33bc15
3 changed files with 66 additions and 4 deletions

View File

@@ -21,6 +21,7 @@ import (
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/constants"
"github.com/echovault/echovault/internal/eviction"
"io"
"net"
"strings"
@@ -67,14 +68,43 @@ func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string,
defer server.connInfo.mut.RUnlock()
return server.connInfo.tcpClients[conn]
},
SetConnectionInfo: func(conn *net.Conn, protocol int, clientname string) {
SetConnectionInfo: func(conn *net.Conn, clientname string, protocol int, database int) {
server.connInfo.mut.Lock()
defer server.connInfo.mut.Unlock()
info := server.connInfo.tcpClients[conn]
// Set protocol.
info.Protocol = protocol
// Set connection name.
if clientname != "" {
info.Name = clientname
}
// If the database index does not exist, create the new database.
server.storeLock.Lock()
if server.store[database] == nil {
// Database does not exist.
server.store[database] = make(map[string]internal.KeyData)
// Create volatile key tracker for the database.
server.keysWithExpiry.rwMutex.Lock()
server.keysWithExpiry.keys[database] = make([]string, 0)
server.keysWithExpiry.rwMutex.Unlock()
// Create LFU cache for the database.
server.lfuCache.mutex.Lock()
server.lfuCache.cache[database] = eviction.NewCacheLFU()
server.lfuCache.mutex.Unlock()
// Create LRU cache for the database.
server.lruCache.mutex.Lock()
server.lruCache.cache[database] = eviction.NewCacheLRU()
server.lruCache.mutex.Unlock()
}
server.storeLock.Unlock()
// Set database index for the current connection.
info.Database = database
server.connInfo.tcpClients[conn] = info
},
GetServerInfo: func() internal.ServerInfo {

View File

@@ -119,14 +119,31 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
}
// Set the connection details.
params.SetConnectionInfo(params.Connection, options.protocol, options.clientname)
connectionInfo := params.GetConnectionInfo(params.Connection)
params.SetConnectionInfo(params.Connection, options.clientname, options.protocol, connectionInfo.Database)
// Get the new connection details and server info to return to the client.
serverInfo := params.GetServerInfo()
connectionInfo := params.GetConnectionInfo(params.Connection)
connectionInfo = params.GetConnectionInfo(params.Connection)
return buildHelloResponse(serverInfo, connectionInfo), nil
}
func handleSelect(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
}
database, err := strconv.Atoi(params.Command[1])
if err != nil {
return nil, err
}
connectionInfo := params.GetConnectionInfo(params.Connection)
params.SetConnectionInfo(params.Connection, connectionInfo.Name, connectionInfo.Protocol, database)
return []byte(constants.OkResponse), nil
}
func Commands() []internal.Command {
return []internal.Command{
{
@@ -193,5 +210,20 @@ Otherwise, the server will return "PONG".`,
},
HandlerFunc: handleHello,
},
{
Command: "select",
Module: constants.ConnectionModule,
Categories: []string{constants.FastCategory, constants.ConnectionCategory},
Description: `(SELECT index) Change the logical database that the current connection is operating from.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) {
return internal.KeyExtractionFuncResult{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleSelect,
},
}
}

View File

@@ -133,7 +133,7 @@ type HandlerFuncParams struct {
// ListModules returns the list of modules loaded in the EchoVault instance.
ListModules func() []string
// SetConnectionInfo sets the connection's protocol and clientname.
SetConnectionInfo func(conn *net.Conn, protocol int, clientname string)
SetConnectionInfo func(conn *net.Conn, clientname string, protocol int, database int)
// GetConnectionInfo returns information about the current connection.
GetConnectionInfo func(conn *net.Conn) ConnectionInfo
// GetServerInfo returns information about the server when requested by commands such as HELLO.