Files
mochi-mqtt/packets/connect.go
2019-09-22 16:17:42 +01:00

178 lines
4.6 KiB
Go

package packets
import (
"bytes"
"errors"
)
// ConnectPacket contains the values of an MQTT CONNECT packet.
type ConnectPacket struct {
FixedHeader
ProtocolName string
ProtocolVersion byte
CleanSession bool
WillFlag bool
WillQos byte
WillRetain bool
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
Keepalive uint16
ClientIdentifier string
WillTopic string
WillMessage []byte // WillMessage is a payload, so store as byte array.
Username string
Password string
}
// Encode encodes and writes the packet data values to the buffer.
func (pk *ConnectPacket) Encode(buf *bytes.Buffer) error {
var body bytes.Buffer
// Write flags to packet body.
body.Write(encodeString(pk.ProtocolName))
body.WriteByte(pk.ProtocolVersion)
body.WriteByte(encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7)
body.Write(encodeUint16(pk.Keepalive))
body.Write(encodeString(pk.ClientIdentifier))
// If will flag is set, add topic and message.
if pk.WillFlag {
body.Write(encodeString(pk.WillTopic))
body.Write(encodeBytes(pk.WillMessage))
}
// If username flag is set, add username.
if pk.UsernameFlag {
body.Write(encodeString(pk.Username))
}
// If password flag is set, add password.
if pk.PasswordFlag {
body.Write(encodeString(pk.Password))
}
pk.FixedHeader.Remaining = body.Len()
pk.FixedHeader.encode(buf)
buf.Write(body.Bytes())
return nil
}
// Decode extracts the data values from the packet.
func (pk *ConnectPacket) Decode(buf []byte) error {
var offset int
var err error
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeString(buf, 0)
if err != nil {
return errors.New(ErrMalformedProtocolName)
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return errors.New(ErrMalformedProtocolVersion)
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return errors.New(ErrMalformedFlags)
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
pk.WillFlag = 1&(flags>>2) > 0
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
pk.WillRetain = 1&(flags>>5) > 0
pk.PasswordFlag = 1&(flags>>6) > 0
pk.UsernameFlag = 1&(flags>>7) > 0
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return errors.New(ErrMalformedKeepalive)
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return errors.New(ErrMalformedClientID)
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return errors.New(ErrMalformedWillTopic)
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return errors.New(ErrMalformedWillMessage)
}
}
// Get username and password if applicable.
if pk.UsernameFlag {
pk.Username, offset, err = decodeString(buf, offset)
if err != nil {
return errors.New(ErrMalformedUsername)
}
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeString(buf, offset)
if err != nil {
return errors.New(ErrMalformedPassword)
}
}
return nil
}
// Validate ensures the packet is compliant.
func (pk *ConnectPacket) Validate() (b byte, err error) {
// End if protocol name is bad.
if pk.ProtocolName != "MQIsdp" && pk.ProtocolName != "MQTT" {
return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
}
// End if protocol version is bad.
if (pk.ProtocolName == "MQIsdp" && pk.ProtocolVersion != 3) ||
(pk.ProtocolName == "MQTT" && pk.ProtocolVersion != 4) {
return ErrConnectBadProtocolVersion, errors.New(ErrProtocolViolation)
}
// End if reserved bit is not 0.
if pk.ReservedBit != 0 {
return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
}
// End if ClientID is too long.
if len(pk.ClientIdentifier) > 65535 {
return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
}
// End if password flag is set without a username.
if pk.PasswordFlag && !pk.UsernameFlag {
return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
}
// End if Username or Password is too long.
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
}
// End if client id isn't set and clean session is false.
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
return ErrConnectBadClientID, errors.New(ErrProtocolViolation)
}
return Accepted, nil
}