Close client connection on quit command

This commit is contained in:
Kelvin Clement Mwinuka
2024-05-30 20:01:24 +08:00
parent 502e804459
commit 6f8511632e
3 changed files with 15 additions and 8 deletions

View File

@@ -399,6 +399,13 @@ func (server *EchoVault) handleConnection(conn net.Conn) {
ctx := context.WithValue(server.context, internal.ContextConnID("ConnectionID"), ctx := context.WithValue(server.context, internal.ContextConnID("ConnectionID"),
fmt.Sprintf("%s-%d", server.context.Value(internal.ContextServerID("ServerID")), cid)) fmt.Sprintf("%s-%d", server.context.Value(internal.ContextServerID("ServerID")), cid))
defer func() {
log.Printf("closing connection %d...", cid)
if err := conn.Close(); err != nil {
log.Println(err)
}
}()
for { for {
message, err := internal.ReadMessage(r) message, err := internal.ReadMessage(r)
@@ -414,11 +421,9 @@ func (server *EchoVault) handleConnection(conn net.Conn) {
} }
res, err := server.handleCommand(ctx, message, &conn, false, false) res, err := server.handleCommand(ctx, message, &conn, false, false)
if err != nil && errors.Is(err, io.EOF) { if err != nil && errors.Is(err, io.EOF) {
break break
} }
if err != nil { if err != nil {
if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil { if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
log.Println(err) log.Println(err)
@@ -428,7 +433,7 @@ func (server *EchoVault) handleConnection(conn net.Conn) {
chunkSize := 1024 chunkSize := 1024
// If the length of the response is 0, return nothing to the client // If the length of the response is 0, return nothing to the client.
if len(res) == 0 { if len(res) == 0 {
continue continue
} }
@@ -456,10 +461,6 @@ func (server *EchoVault) handleConnection(conn net.Conn) {
startIndex += chunkSize startIndex += chunkSize
} }
} }
if err := conn.Close(); err != nil {
log.Println(err)
}
} }
// Start starts the EchoVault instance's TCP listener. // Start starts the EchoVault instance's TCP listener.

View File

@@ -21,6 +21,7 @@ import (
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/constants" "github.com/echovault/echovault/internal/constants"
"io"
"net" "net"
"strings" "strings"
) )
@@ -74,6 +75,11 @@ func (server *EchoVault) handleCommand(ctx context.Context, message []byte, conn
return nil, errors.New("empty command") return nil, errors.New("empty command")
} }
// If quit command is passed, EOF error.
if strings.EqualFold(cmd[0], "quit") {
return nil, io.EOF
}
command, err := server.getCommand(cmd[0]) command, err := server.getCommand(cmd[0])
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -62,7 +62,7 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
echovault.WithConfig(conf), echovault.WithConfig(conf),
) )
// Add the initial test users to the ACL module // Add the initial test users to the ACL module.
// a.AddUsers(generateInitialTestUsers()) // a.AddUsers(generateInitialTestUsers())
return mockServer return mockServer