mirror of
				https://github.com/mochi-mqtt/server.git
				synced 2025-10-31 03:26:28 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			184 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package packets
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"io"
 | |
| )
 | |
| 
 | |
// ConnectPacket contains the values of an MQTT CONNECT packet.
//
// The boolean flags, WillQos and ReservedBit mirror the bits of the
// connect-flags byte read by Decode and written by Encode. The string and
// byte-slice fields hold the payload; WillTopic/WillMessage, Username and
// Password are only carried on the wire when their corresponding flag is set.
type ConnectPacket struct {
	FixedHeader

	ProtocolName     string // Validate accepts only "MQTT" or "MQIsdp".
	ProtocolVersion  byte   // Validate requires 4 for "MQTT" and 3 for "MQIsdp".
	CleanSession     bool   // Connect flags bit 1.
	WillFlag         bool   // Connect flags bit 2; when set, WillTopic and WillMessage are in the payload.
	WillQos          byte   // Connect flags bits 3-4 (a two-bit value, not a bool).
	WillRetain       bool   // Connect flags bit 5.
	UsernameFlag     bool   // Connect flags bit 7; when set, Username is in the payload.
	PasswordFlag     bool   // Connect flags bit 6; when set, Password is in the payload.
	ReservedBit      byte   // Connect flags bit 0; Validate rejects any non-zero value.
	Keepalive        uint16 // Keep alive interval read from the variable header.
	ClientIdentifier string // Always present; Validate allows it to be empty only with CleanSession.
	WillTopic        string // Present when WillFlag is set.
	WillMessage      []byte // WillMessage is a payload, so store as byte array.
	Username         string // Present when UsernameFlag is set; Validate limits it to 65535 bytes.
	Password         string // Present when PasswordFlag is set; Validate limits it to 65535 bytes.
}
 | |
| 
 | |
| // Encode encodes and writes the packet data values to the buffer.
 | |
| func (pk *ConnectPacket) Encode(w io.Writer) error {
 | |
| 
 | |
| 	var body bytes.Buffer
 | |
| 
 | |
| 	// Write flags to packet body.
 | |
| 	body.Write(encodeString(pk.ProtocolName))
 | |
| 	body.WriteByte(pk.ProtocolVersion)
 | |
| 	body.WriteByte(encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7)
 | |
| 	body.Write(encodeUint16(pk.Keepalive))
 | |
| 	body.Write(encodeString(pk.ClientIdentifier))
 | |
| 
 | |
| 	// If will flag is set, add topic and message.
 | |
| 	if pk.WillFlag {
 | |
| 		body.Write(encodeString(pk.WillTopic))
 | |
| 		body.Write(encodeBytes(pk.WillMessage))
 | |
| 	}
 | |
| 
 | |
| 	// If username flag is set, add username.
 | |
| 	if pk.UsernameFlag {
 | |
| 		body.Write(encodeString(pk.Username))
 | |
| 	}
 | |
| 
 | |
| 	// If password flag is set, add password.
 | |
| 	if pk.PasswordFlag {
 | |
| 		body.Write(encodeString(pk.Password))
 | |
| 	}
 | |
| 
 | |
| 	// Set remaining length.
 | |
| 	pk.FixedHeader.Remaining = body.Len()
 | |
| 
 | |
| 	// Write header and packet to output.
 | |
| 	out := pk.FixedHeader.encode()
 | |
| 	out.Write(body.Bytes())
 | |
| 	_, err := out.WriteTo(w)
 | |
| 
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // Decode extracts the data values from the packet.
 | |
| func (pk *ConnectPacket) Decode(buf []byte) error {
 | |
| 
 | |
| 	var offset int
 | |
| 	var err error
 | |
| 
 | |
| 	// Unpack protocol name and version.
 | |
| 	pk.ProtocolName, offset, err = decodeString(buf, 0)
 | |
| 	if err != nil {
 | |
| 		return errors.New(ErrMalformedProtocolName)
 | |
| 	}
 | |
| 
 | |
| 	pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
 | |
| 	if err != nil {
 | |
| 		return errors.New(ErrMalformedProtocolVersion)
 | |
| 	}
 | |
| 	// Unpack flags byte.
 | |
| 	flags, offset, err := decodeByte(buf, offset)
 | |
| 	if err != nil {
 | |
| 		return errors.New(ErrMalformedFlags)
 | |
| 	}
 | |
| 	pk.ReservedBit = 1 & flags
 | |
| 	pk.CleanSession = 1&(flags>>1) > 0
 | |
| 	pk.WillFlag = 1&(flags>>2) > 0
 | |
| 	pk.WillQos = 3 & (flags >> 3) // this one is not a bool
 | |
| 	pk.WillRetain = 1&(flags>>5) > 0
 | |
| 	pk.PasswordFlag = 1&(flags>>6) > 0
 | |
| 	pk.UsernameFlag = 1&(flags>>7) > 0
 | |
| 
 | |
| 	// Get keepalive interval.
 | |
| 	pk.Keepalive, offset, err = decodeUint16(buf, offset)
 | |
| 	if err != nil {
 | |
| 		return errors.New(ErrMalformedKeepalive)
 | |
| 	}
 | |
| 
 | |
| 	// Get client ID.
 | |
| 	pk.ClientIdentifier, offset, err = decodeString(buf, offset)
 | |
| 	if err != nil {
 | |
| 		return errors.New(ErrMalformedClientID)
 | |
| 	}
 | |
| 
 | |
| 	// Get Last Will and Testament topic and message if applicable.
 | |
| 	if pk.WillFlag {
 | |
| 
 | |
| 		pk.WillTopic, offset, err = decodeString(buf, offset)
 | |
| 		if err != nil {
 | |
| 			return errors.New(ErrMalformedWillTopic)
 | |
| 		}
 | |
| 
 | |
| 		pk.WillMessage, offset, err = decodeBytes(buf, offset)
 | |
| 		if err != nil {
 | |
| 			return errors.New(ErrMalformedWillMessage)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Get username and password if applicable.
 | |
| 	if pk.UsernameFlag {
 | |
| 		pk.Username, offset, err = decodeString(buf, offset)
 | |
| 		if err != nil {
 | |
| 			return errors.New(ErrMalformedUsername)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if pk.PasswordFlag {
 | |
| 		pk.Password, offset, err = decodeString(buf, offset)
 | |
| 		if err != nil {
 | |
| 			return errors.New(ErrMalformedPassword)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| 
 | |
| }
 | |
| 
 | |
| // Validate ensures the packet is compliant.
 | |
| func (pk *ConnectPacket) Validate() (b byte, err error) {
 | |
| 
 | |
| 	// End if protocol name is bad.
 | |
| 	if pk.ProtocolName != "MQIsdp" && pk.ProtocolName != "MQTT" {
 | |
| 		return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if protocol version is bad.
 | |
| 	if (pk.ProtocolName == "MQIsdp" && pk.ProtocolVersion != 3) ||
 | |
| 		(pk.ProtocolName == "MQTT" && pk.ProtocolVersion != 4) {
 | |
| 		return ErrConnectBadProtocolVersion, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if reserved bit is not 0.
 | |
| 	if pk.ReservedBit != 0 {
 | |
| 		return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if ClientID is too long.
 | |
| 	if len(pk.ClientIdentifier) > 65535 {
 | |
| 		return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if password flag is set without a username.
 | |
| 	if pk.PasswordFlag && !pk.UsernameFlag {
 | |
| 		return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if Username or Password is too long.
 | |
| 	if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
 | |
| 		return ErrConnectProtocolViolation, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	// End if client id isn't set and clean session is false.
 | |
| 	if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
 | |
| 		return ErrConnectBadClientID, errors.New(ErrProtocolViolation)
 | |
| 	}
 | |
| 
 | |
| 	return Accepted, nil
 | |
| }
 | 
