diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 49462d0..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index 4bc60d7..d09d700 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ # Go workspace file go.work -.idea \ No newline at end of file +.idea +.DS_Store diff --git a/broker/.DS_Store b/broker/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/broker/.DS_Store and /dev/null differ diff --git a/codec/codec.go b/codec/codec.go index 8aa5953..aa13a87 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -12,7 +12,7 @@ import ( "encoding/json" "fmt" "io" - + "github.com/oarkflow/mq/consts" ) @@ -63,29 +63,19 @@ func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC)) } -func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) { +func (m *Message) Serialize(aesKey, hmacKey []byte) ([]byte, string, error) { var buf bytes.Buffer - headersBytes, err := json.Marshal(m.Headers) - if err != nil { - return nil, "", fmt.Errorf("error marshalling headers: %v", err) - } - headersLen := uint32(len(headersBytes)) - if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil { + + if err := writeHeaders(&buf, m.Headers); err != nil { return nil, "", err } - buf.Write(headersBytes) - topicBytes := []byte(m.Topic) - topicLen := uint8(len(topicBytes)) - if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil { + if err := writeString(&buf, m.Topic); err != nil { return nil, "", err } - buf.Write(topicBytes) - if !m.Command.IsValid() { - return nil, "", fmt.Errorf("invalid command: %s", m.Command) - } - if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil { + if err := writeCommand(&buf, m.Command); err != nil { return nil, "", err } + payloadBytes, err := json.Marshal(m.Payload) if err != nil { return nil, "", fmt.Errorf("error marshalling payload: %v", err) @@ -94,24 +84,96 @@ func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, erro if err != nil { return nil, "", err } - payloadLen := uint32(len(encryptedPayload)) - if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil { + if err := writePayload(&buf, encryptedPayload); err != nil { return nil, "", err } - buf.Write(encryptedPayload) buf.Write(nonce) + messageBytes := buf.Bytes() hmacSignature := CalculateHMAC(hmacKey, messageBytes) return messageBytes, hmacSignature, nil } -func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) { +func Deserialize(data []byte, aesKey, hmacKey []byte, receivedHMAC string) (*Message, error) { if !VerifyHMAC(hmacKey, data, receivedHMAC) { return nil, fmt.Errorf("HMAC verification failed") } + buf := bytes.NewReader(data) - var topicLen uint8 - var payloadLen uint32 + + headers, err := readHeaders(buf) + if err != nil { + return nil, err + } + topic, err := readString(buf) + if err != nil { + return nil, err + } + command, err := readCommand(buf) + if err != nil { + return nil, err + } + encryptedPayload, nonce, err := readPayload(buf) + if err != nil { + return nil, err + } + + payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce) + if err != nil { + return nil, err + } + var payload json.RawMessage + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return nil, fmt.Errorf("error unmarshalling payload: %v", err) + } + + return &Message{ + Headers: headers, + Topic: topic, + Command: command, + Payload: payload, + }, nil +} + +func writeHeaders(buf *bytes.Buffer, headers map[string]string) error { + headersBytes, err := json.Marshal(headers) + if err != nil { + return fmt.Errorf("error marshalling headers: %v", err) + } + headersLen := uint32(len(headersBytes)) + if err := binary.Write(buf, binary.LittleEndian, headersLen); err != nil { + return err + } + buf.Write(headersBytes) + return nil +} + +func writeString(buf *bytes.Buffer, str string) error { + strBytes := []byte(str) + if err := binary.Write(buf, binary.LittleEndian, uint8(len(strBytes))); err != nil { + return err + } + buf.Write(strBytes) + return nil +} + +func writeCommand(buf *bytes.Buffer, command consts.CMD) error { + if !command.IsValid() { + return fmt.Errorf("invalid command: %s", command) + } + return binary.Write(buf, binary.LittleEndian, command) +} + +func writePayload(buf *bytes.Buffer, encryptedPayload []byte) error { + payloadLen := uint32(len(encryptedPayload)) + if err := binary.Write(buf, binary.LittleEndian, payloadLen); err != nil { + return err + } + buf.Write(encryptedPayload) + return nil +} + +func readHeaders(buf *bytes.Reader) (map[string]string, error) { var headersLen uint32 if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil { return nil, err @@ -124,46 +186,46 @@ func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string if err := json.Unmarshal(headersBytes, &headers); err != nil { return nil, err } + return headers, nil +} + +func readString(buf *bytes.Reader) (string, error) { + var topicLen uint8 if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil { - return nil, err + return "", err } topicBytes := make([]byte, topicLen) if _, err := buf.Read(topicBytes); err != nil { - return nil, err + return "", err } - topic := string(topicBytes) + return string(topicBytes), nil +} + +func readCommand(buf *bytes.Reader) (consts.CMD, error) { var command consts.CMD if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { - return nil, err + return command, err } if !command.IsValid() { - return nil, fmt.Errorf("invalid command: %s", command) + return command, fmt.Errorf("invalid command: %s", command) } + return command, nil +} + +func readPayload(buf *bytes.Reader) ([]byte, []byte, error) { + var payloadLen uint32 if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil { - return nil, err + return nil, nil, err } encryptedPayload := make([]byte, payloadLen) if _, err := buf.Read(encryptedPayload); err != nil { - return nil, err + return nil, nil, err } - nonce := make([]byte, 12) + nonce := make([]byte, 12) // Adjust size as needed if _, err := buf.Read(nonce); err != nil { - return nil, err + return nil, nil, err } - payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce) - if err != nil { - return nil, err - } - var payload json.RawMessage - if err := json.Unmarshal(payloadBytes, &payload); err != nil { - return nil, fmt.Errorf("error unmarshalling payload: %v", err) - } - return &Message{ - Headers: headers, - Topic: topic, - Command: command, - Payload: payload, - }, nil + return encryptedPayload, nonce, nil } func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error { @@ -174,7 +236,7 @@ func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) er if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil { return err } - + if _, err := conn.Write(sentData); err != nil { return err } @@ -193,7 +255,7 @@ func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error if _, err := io.ReadFull(conn, data); err != nil { return nil, err } - + hmacBytes := make([]byte, 64) if _, err := io.ReadFull(conn, hmacBytes); err != nil { return nil, err