diff --git a/conn.go b/conn.go index f73abd7a..97c58fac 100644 --- a/conn.go +++ b/conn.go @@ -5,47 +5,59 @@ import ( "fmt" "io" "net" + "time" +) + +const ( + _READ_DEADLINE = 10 * time.Second + _WRITE_DEADLINE = 10 * time.Second ) type Conn struct { - c net.Conn + nconn net.Conn writeBuf []byte } -func NewConn(c net.Conn) *Conn { +func NewConn(nconn net.Conn) *Conn { return &Conn{ - c: c, + nconn: nconn, writeBuf: make([]byte, 2048), } } func (c *Conn) Close() error { - return c.c.Close() + return c.nconn.Close() } func (c *Conn) RemoteAddr() net.Addr { - return c.c.RemoteAddr() + return c.nconn.RemoteAddr() } func (c *Conn) ReadRequest() (*Request, error) { - return requestDecode(c.c) + c.nconn.SetReadDeadline(time.Now().Add(_READ_DEADLINE)) + return requestDecode(c.nconn) } func (c *Conn) WriteRequest(req *Request) error { - return requestEncode(c.c, req) + c.nconn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)) + return requestEncode(c.nconn, req) } func (c *Conn) ReadResponse() (*Response, error) { - return responseDecode(c.c) + c.nconn.SetReadDeadline(time.Now().Add(_READ_DEADLINE)) + return responseDecode(c.nconn) } func (c *Conn) WriteResponse(res *Response) error { - return responseEncode(c.c, res) + c.nconn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)) + return responseEncode(c.nconn, res) } func (c *Conn) ReadInterleavedFrame(frame []byte) (int, int, error) { + c.nconn.SetReadDeadline(time.Now().Add(_READ_DEADLINE)) + var header [4]byte - _, err := io.ReadFull(c.c, header[:]) + _, err := io.ReadFull(c.nconn, header[:]) if err != nil { return 0, 0, err } @@ -64,7 +76,7 @@ func (c *Conn) ReadInterleavedFrame(frame []byte) (int, int, error) { return 0, 0, fmt.Errorf("frame length greater than 2048") } - _, err = io.ReadFull(c.c, frame[:framelen]) + _, err = io.ReadFull(c.nconn, frame[:framelen]) if err != nil { return 0, 0, err } @@ -73,12 +85,14 @@ func (c *Conn) ReadInterleavedFrame(frame []byte) (int, int, error) { } func (c *Conn) WriteInterleavedFrame(channel int, frame []byte) error { + c.nconn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)) + c.writeBuf[0] = 0x24 c.writeBuf[1] = byte(channel) binary.BigEndian.PutUint16(c.writeBuf[2:], uint16(len(frame))) n := copy(c.writeBuf[4:], frame) - _, err := c.c.Write(c.writeBuf[:4+n]) + _, err := c.nconn.Write(c.writeBuf[:4+n]) if err != nil { return err } @@ -86,5 +100,5 @@ func (c *Conn) WriteInterleavedFrame(channel int, frame []byte) error { } func (c *Conn) Read(buf []byte) (int, error) { - return c.c.Read(buf) + return c.nconn.Read(buf) }