package connection import ( "net" "sync" "time" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/sync/wait" ) const ( // flagSlave means this a connection with slave flagSlave = uint64(1 << iota) // flagSlave means this a connection with master flagMaster // flagMulti means this connection is within a transaction flagMulti ) // Connection represents a connection with a redis-cli type Connection struct { conn net.Conn // wait until finish sending data, used for graceful shutdown sendingData wait.Wait // lock while server sending response mu sync.Mutex flags uint64 // subscribing channels subs map[string]bool // password may be changed by CONFIG command during runtime, so store the password password string // queued commands for `multi` queue [][][]byte watching map[string]uint32 txErrors []error // selected db selectedDB int } var connPool = sync.Pool{ New: func() interface{} { return &Connection{} }, } // RemoteAddr returns the remote network address func (c *Connection) RemoteAddr() string { return c.conn.RemoteAddr().String() } // Close disconnect with the client func (c *Connection) Close() error { c.sendingData.WaitWithTimeout(10 * time.Second) if c.conn != nil { // may be a fake conn for tests _ = c.conn.Close() } c.subs = nil c.password = "" c.queue = nil c.watching = nil c.txErrors = nil c.selectedDB = 0 connPool.Put(c) return nil } // NewConn creates Connection instance func NewConn(conn net.Conn) *Connection { c, ok := connPool.Get().(*Connection) if !ok { logger.Error("connection pool make wrong type") return &Connection{ conn: conn, } } c.conn = conn return c } // Write sends response to client over tcp connection func (c *Connection) Write(b []byte) (int, error) { if len(b) == 0 { return 0, nil } c.sendingData.Add(1) defer func() { c.sendingData.Done() }() return c.conn.Write(b) } func (c *Connection) Name() string { if c.conn != nil { return c.conn.RemoteAddr().String() } return "" } // Subscribe add current connection into subscribers of the given channel func (c *Connection) Subscribe(channel string) { c.mu.Lock() defer c.mu.Unlock() if c.subs == nil { c.subs = make(map[string]bool) } c.subs[channel] = true } // UnSubscribe removes current connection into subscribers of the given channel func (c *Connection) UnSubscribe(channel string) { c.mu.Lock() defer c.mu.Unlock() if len(c.subs) == 0 { return } delete(c.subs, channel) } // SubsCount returns the number of subscribing channels func (c *Connection) SubsCount() int { return len(c.subs) } // GetChannels returns all subscribing channels func (c *Connection) GetChannels() []string { if c.subs == nil { return make([]string, 0) } channels := make([]string, len(c.subs)) i := 0 for channel := range c.subs { channels[i] = channel i++ } return channels } // SetPassword stores password for authentication func (c *Connection) SetPassword(password string) { c.password = password } // GetPassword get password for authentication func (c *Connection) GetPassword() string { return c.password } // InMultiState tells is connection in an uncommitted transaction func (c *Connection) InMultiState() bool { return c.flags&flagMulti > 0 } // SetMultiState sets transaction flag func (c *Connection) SetMultiState(state bool) { if !state { // reset data when cancel multi c.watching = nil c.queue = nil c.flags &= ^flagMulti // clean multi flag return } c.flags |= flagMulti } // GetQueuedCmdLine returns queued commands of current transaction func (c *Connection) GetQueuedCmdLine() [][][]byte { return c.queue } // EnqueueCmd enqueues command of current transaction func (c *Connection) EnqueueCmd(cmdLine [][]byte) { c.queue = append(c.queue, cmdLine) } // AddTxError stores syntax error within transaction func (c *Connection) AddTxError(err error) { c.txErrors = append(c.txErrors, err) } // GetTxErrors returns syntax error within transaction func (c *Connection) GetTxErrors() []error { return c.txErrors } // ClearQueuedCmds clears queued commands of current transaction func (c *Connection) ClearQueuedCmds() { c.queue = nil } // GetWatching returns watching keys and their version code when started watching func (c *Connection) GetWatching() map[string]uint32 { if c.watching == nil { c.watching = make(map[string]uint32) } return c.watching } // GetDBIndex returns selected db func (c *Connection) GetDBIndex() int { return c.selectedDB } // SelectDB selects a database func (c *Connection) SelectDB(dbNum int) { c.selectedDB = dbNum } func (c *Connection) SetSlave() { c.flags |= flagSlave } func (c *Connection) IsSlave() bool { return c.flags&flagSlave > 0 } // SetMaster marks c as a connection with master func (c *Connection) SetMaster() { c.flags |= flagMaster } func (c *Connection) IsMaster() bool { return c.flags&flagMaster > 0 }