package client

import (
	"errors"
	"github.com/hdt3213/godis/interface/redis"
	"github.com/hdt3213/godis/lib/logger"
	"github.com/hdt3213/godis/lib/sync/wait"
	"github.com/hdt3213/godis/redis/parser"
	"github.com/hdt3213/godis/redis/protocol"
	"net"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// Life-cycle states stored in Client.status.
// The zero value of the field is created.
const (
	created = iota
	running
	closed
)

// Client is a pipeline mode redis client
type Client struct {
	conn        net.Conn
	pendingReqs chan *request // wait to send
	waitingReqs chan *request // waiting response
	ticker      *time.Ticker  // drives heartbeats; created by Start
	addr        string

	status  int32           // created, running or closed; accessed atomically
	working *sync.WaitGroup // its counter represents unfinished requests (pending and waiting)

	// stopHeartbeat is closed by Close to end the heartbeat goroutine.
	// Stopping the ticker alone is not enough because Ticker.Stop does not
	// close the ticker channel.
	stopHeartbeat chan struct{}
}

// request is a message sent to the redis server
type request struct {
	id        uint64
	args      [][]byte
	reply     redis.Reply
	heartbeat bool
	waiting   *wait.Wait // released once the request has a reply or an error
	err       error
}

const (
	// chanSize is the buffer size of pendingReqs and waitingReqs.
	chanSize = 256
	// maxWait is how long Send and doHeartbeat wait for a request to finish.
	maxWait = 3 * time.Second
)

// MakeClient creates a new client connected to addr over TCP.
// The returned client does nothing until Start is called.
func MakeClient(addr string) (*Client, error) {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}
	return &Client{
		addr:          addr,
		conn:          conn,
		pendingReqs:   make(chan *request, chanSize),
		waitingReqs:   make(chan *request, chanSize),
		working:       &sync.WaitGroup{},
		stopHeartbeat: make(chan struct{}),
	}, nil
}

// RemoteAddress returns the address this client was created with.
func (client *Client) RemoteAddress() string {
	return client.addr
}

// Start starts asynchronous goroutines
func (client *Client) Start() {
	client.ticker = time.NewTicker(10 * time.Second)
	go client.handleWrite()
	go client.handleRead()
	go client.heartbeat()
	atomic.StoreInt32(&client.status, running)
}

// Close stops asynchronous goroutines and closes the connection.
// It blocks until every unfinished request has returned. Only the first call
// has any effect; later calls return immediately.
func (client *Client) Close() {
	// Swap makes Close idempotent: reconnect calls Close after exhausting its
	// retries, and a second close of pendingReqs would panic.
	if atomic.SwapInt32(&client.status, closed) == closed {
		return
	}
	// ticker is nil when Close is called on a client that was never started
	if client.ticker != nil {
		client.ticker.Stop()
	}
	if client.stopHeartbeat != nil {
		close(client.stopHeartbeat)
	}
	// stop new request
	close(client.pendingReqs)

	// wait stop process
	client.working.Wait()

	// clean
	_ = client.conn.Close()
	close(client.waitingReqs)
}

// reconnect replaces the broken connection with a new one.
// It closes the old connection, fails every request that is still waiting for
// a response, then dials the server up to 3 times. On success it restarts
// handleRead on the new connection; otherwise it closes the client.
func (client *Client) reconnect() {
	logger.Info("reconnect with: " + client.addr)
	_ = client.conn.Close() // ignore possible errors from repeated closes

	// Requests in waitingReqs were written to the connection that has just
	// been closed, so they cannot be answered any more. Drain them before the
	// new connection is installed so that requests written to the new
	// connection are not failed by mistake.
	client.failWaitingRequests()

	var conn net.Conn
	for i := 0; i < 3; i++ {
		c, err := net.Dial("tcp", client.addr)
		if err == nil {
			conn = c
			break
		}
		logger.Error("reconnect error: " + err.Error())
		time.Sleep(time.Second)
	}
	if conn == nil { // reach max retry, abort
		client.Close()
		return
	}
	// NOTE(review): conn is swapped without synchronization with handleWrite,
	// which reads it in doRequest.
	client.conn = conn
	// restart handle read
	go client.handleRead()
}

// failWaitingRequests fails every request currently queued in waitingReqs
// with a "connection closed" error, without blocking. The channel itself is
// left open so that a concurrent doRequest never sends on a closed channel.
func (client *Client) failWaitingRequests() {
	for {
		select {
		case req, ok := <-client.waitingReqs:
			if !ok { // channel already closed by Close
				return
			}
			req.err = errors.New("connection closed")
			req.waiting.Done()
		default:
			return
		}
	}
}

// heartbeat sends a heartbeat on every tick until Close is called.
func (client *Client) heartbeat() {
	for {
		select {
		case <-client.stopHeartbeat:
			return
		case <-client.ticker.C:
			client.doHeartbeat()
		}
	}
}

// handleWrite sends pending requests one by one until pendingReqs is closed.
func (client *Client) handleWrite() {
	for req := range client.pendingReqs {
		client.doRequest(req)
	}
}

// Send sends a request to redis server
// It waits up to maxWait for the result and returns an error reply when the
// client is not running, the wait times out or the request fails.
func (client *Client) Send(args [][]byte) redis.Reply {
	if atomic.LoadInt32(&client.status) != running {
		return protocol.MakeErrReply("client closed")
	}
	req := &request{
		args:      args,
		heartbeat: false,
		waiting:   &wait.Wait{},
	}
	req.waiting.Add(1)
	client.working.Add(1)
	defer client.working.Done()
	client.pendingReqs <- req
	timeout := req.waiting.WaitWithTimeout(maxWait)
	if timeout {
		return protocol.MakeErrReply("server time out")
	}
	if req.err != nil {
		return protocol.MakeErrReply("request failed " + req.err.Error())
	}
	return req.reply
}

// doHeartbeat sends a PING request and waits up to maxWait for it to finish.
// The result is ignored.
func (client *Client) doHeartbeat() {
	// like Send, never enqueue once the client has left the running state:
	// pendingReqs is closed by Close
	if atomic.LoadInt32(&client.status) != running {
		return
	}
	req := &request{
		args:      [][]byte{[]byte("PING")},
		heartbeat: true,
		waiting:   &wait.Wait{},
	}
	req.waiting.Add(1)
	client.working.Add(1)
	defer client.working.Done()
	client.pendingReqs <- req
	req.waiting.WaitWithTimeout(maxWait)
}

// doRequest writes one request to the connection.
// On success the request is queued in waitingReqs until its response arrives;
// on failure the error is stored in the request and its waiter is released.
// Nil requests and requests without args are dropped.
func (client *Client) doRequest(req *request) {
	if req == nil || len(req.args) == 0 {
		return
	}
	payload := protocol.MakeMultiBulkReply(req.args).ToBytes()
	var err error
	for i := 0; i < 3; i++ { // only retry, waiting for handleRead
		_, err = client.conn.Write(payload)
		if err == nil {
			break
		}
		// only retry timeout
		msg := err.Error()
		if !strings.Contains(msg, "timeout") && !strings.Contains(msg, "deadline exceeded") {
			break
		}
	}
	if err != nil {
		req.err = err
		req.waiting.Done()
		return
	}
	client.waitingReqs <- req
}

// finishRequest hands reply to the oldest request in waitingReqs and releases
// its waiter. Panics are recovered and logged.
func (client *Client) finishRequest(reply redis.Reply) {
	defer func() {
		if err := recover(); err != nil {
			debug.PrintStack()
			logger.Error(err)
		}
	}()
	req := <-client.waitingReqs
	if req == nil { // waitingReqs has been closed
		return
	}
	req.reply = reply
	if req.waiting != nil {
		req.waiting.Done()
	}
}

// handleRead parses responses from the connection and matches each one with
// the oldest waiting request. On a read error it returns if the client is
// closed, otherwise it hands over to reconnect, which starts a new handleRead
// when it succeeds.
func (client *Client) handleRead() {
	ch := parser.ParseStream(client.conn)
	for payload := range ch {
		if payload.Err != nil {
			if atomic.LoadInt32(&client.status) == closed {
				return
			}
			client.reconnect()
			return
		}
		client.finishRequest(payload.Data)
	}
}