diff --git a/redis/client/client.go b/redis/client/client.go index 250de19..55452fa 100644 --- a/redis/client/client.go +++ b/redis/client/client.go @@ -1,6 +1,7 @@ package client import ( + "errors" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/sync/wait" @@ -8,6 +9,7 @@ import ( "github.com/hdt3213/godis/redis/protocol" "net" "runtime/debug" + "strings" "sync" "time" ) @@ -57,12 +59,7 @@ func MakeClient(addr string) (*Client, error) { func (client *Client) Start() { client.ticker = time.NewTicker(10 * time.Second) go client.handleWrite() - go func() { - err := client.handleRead() - if err != nil { - logger.Error(err) - } - }() + go client.handleRead() go client.heartbeat() } @@ -80,27 +77,35 @@ func (client *Client) Close() { close(client.waitingReqs) } -func (client *Client) handleConnectionError(err error) error { - err1 := client.conn.Close() - if err1 != nil { - if opErr, ok := err1.(*net.OpError); ok { - if opErr.Err.Error() != "use of closed network connection" { - return err1 - } +func (client *Client) reconnect() { + _ = client.conn.Close() // ignore possible errors from repeated closes + + var conn net.Conn + for i := 0; i < 3; i++ { + var err error + conn, err = net.Dial("tcp", client.addr) + if err != nil { + logger.Error("reconnect error: " + err.Error()) + time.Sleep(time.Second) + continue } else { - return err1 + break } } - conn, err1 := net.Dial("tcp", client.addr) - if err1 != nil { - logger.Error(err1) - return err1 + if conn == nil { // reach max retry, abort + client.Close() + return } client.conn = conn - go func() { - _ = client.handleRead() - }() - return nil + // + close(client.waitingReqs) + for req := range client.waitingReqs { + req.err = errors.New("connection closed") + req.waiting.Done() + } + client.waitingReqs = make(chan *request, chanSize) + // restart handle read + go client.handleRead() } func (client *Client) heartbeat() { @@ -155,14 +160,14 @@ func (client *Client) doRequest(req *request) { } re := protocol.MakeMultiBulkReply(req.args) bytes := re.ToBytes() - _, err := client.conn.Write(bytes) - i := 0 - for err != nil && i < 3 { - err = client.handleConnectionError(err) - if err == nil { - _, err = client.conn.Write(bytes) + var err error + for i := 0; i < 3; i++ { // only retry, waiting for handleRead + _, err = client.conn.Write(bytes) + if err == nil || + (!strings.Contains(err.Error(), "timeout") && // only retry timeout + !strings.Contains(err.Error(), "deadline exceeded")) { + break } - i++ } if err == nil { client.waitingReqs <- req @@ -189,14 +194,13 @@ func (client *Client) finishRequest(reply redis.Reply) { } } -func (client *Client) handleRead() error { +func (client *Client) handleRead() { ch := parser.ParseStream(client.conn) for payload := range ch { if payload.Err != nil { - client.finishRequest(protocol.MakeErrReply(payload.Err.Error())) - continue + client.reconnect() + return } client.finishRequest(payload.Data) } - return nil } diff --git a/redis/client/client_test.go b/redis/client/client_test.go index da414f3..ff7458f 100644 --- a/redis/client/client_test.go +++ b/redis/client/client_test.go @@ -1,10 +1,12 @@ package client import ( + "bytes" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/redis/protocol" "strconv" "testing" + "time" ) func TestClient(t *testing.T) { @@ -104,3 +106,36 @@ func TestClient(t *testing.T) { client.Close() } + +func TestReconnect(t *testing.T) { + logger.Setup(&logger.Settings{ + Path: "logs", + Name: "godis", + Ext: ".log", + TimeFormat: "2006-01-02", + }) + client, err := MakeClient("localhost:6379") + if err != nil { + t.Error(err) + } + client.Start() + + _ = client.conn.Close() + time.Sleep(time.Second) // wait for reconnecting + success := false + for i := 0; i < 3; i++ { + result := client.Send([][]byte{ + []byte("PING"), + }) + if bytes.Equal(result.ToBytes(), []byte("+PONG\r\n")) { + success = true + break + } + } + if !success { + t.Error("reconnect error") + } + //var wg sync.WaitGroup + //wg.Add(1) + //wg.Wait() +}