diff --git a/codec/codec.go b/codec/codec.go index a61f0b5..92e7d36 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -4,7 +4,9 @@ import ( "bufio" "context" "encoding/binary" + "io" // added for full reads "net" + "time" // added for handling deadlines "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/internal/bpool" @@ -62,20 +64,29 @@ func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data))) copy(buffer.B[4:], data) writer := bufio.NewWriter(conn) - select { - case <-ctx.Done(): - return ctx.Err() - default: - if _, err := writer.Write(buffer.B[:totalLength]); err != nil { - return err - } + + // Set write deadline if context has one + if deadline, ok := ctx.Deadline(); ok { + conn.SetWriteDeadline(deadline) + defer conn.SetWriteDeadline(time.Time{}) + } + + // Write full data + if _, err := writer.Write(buffer.B[:totalLength]); err != nil { + return err } return writer.Flush() } func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) { lengthBytes := make([]byte, 4) - if _, err := conn.Read(lengthBytes); err != nil { + // Set read deadline if context has one + if deadline, ok := ctx.Deadline(); ok { + conn.SetReadDeadline(deadline) + defer conn.SetReadDeadline(time.Time{}) + } + // Use io.ReadFull to ensure all header bytes are read + if _, err := io.ReadFull(conn, lengthBytes); err != nil { return nil, err } length := binary.BigEndian.Uint32(lengthBytes) @@ -87,18 +98,9 @@ func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) { } else { buffer.B = buffer.B[:length] } - totalRead := 0 - for totalRead < int(length) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - n, err := conn.Read(buffer.B[totalRead:]) - if err != nil { - return nil, err - } - totalRead += n - } + // Read the entire message payload + if _, err := io.ReadFull(conn, buffer.B[:length]); err != nil { + return nil, err } return Deserialize(buffer.B[:length]) }