翻出历史项目,先上传,未测试(MQTT3.1.1已经支持)

This commit is contained in:
Jason
2020-09-11 06:48:42 +08:00
parent ad870c9948
commit 19e8e303ec
29 changed files with 2787 additions and 0 deletions

1
.gitignore vendored
View File

@@ -13,3 +13,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
/.idea

289
bee.go Normal file
View File

@@ -0,0 +1,289 @@
package beeq
import (
"encoding/binary"
"github.com/zgwit/beeq/packet"
"log"
"net"
"time"
)
type Bee struct {
clientId string
session *Session
conn net.Conn
hive *Hive
quit chan struct{}
events chan *Event
timeout time.Duration
alive bool
}
func NewBee(conn net.Conn, hive *Hive) *Bee {
return &Bee{
conn: conn,
hive: hive,
quit: make(chan struct{}),
events: make(chan *Event, 1024),
timeout: time.Hour * 24,
}
}
func (bee *Bee) Event(event *Event) {
bee.events <- event
}
func (bee *Bee) Active() {
bee.alive = true
go bee.receiver()
go bee.messenger()
}
func (bee *Bee) Shutdown() {
bee.alive = false
// Tell hive delete me
bee.hive.Event(NewEvent(E_LOST_CONN, nil, bee))
// Stop messenger
close(bee.events)
close(bee.quit)
}
func (bee *Bee) recv(b []byte) (int, error) {
err := bee.conn.SetReadDeadline(time.Now().Add(bee.timeout))
if err != nil {
return 0, err
}
return bee.conn.Read(b)
}
func (bee *Bee) send(b []byte) (int, error) {
err := bee.conn.SetWriteDeadline(time.Now().Add(bee.timeout))
if err != nil {
return 0, err
}
return bee.conn.Write(b)
}
func (bee *Bee) receiver() {
//Abort error
defer func() {
if r := recover(); r!=nil {
log.Print("bee receiver panic ", r)
bee.Shutdown()
}
}()
readHead := true
buf := Alloc(6)
offset := 0
total := 0
for {
if l, err := bee.recv(buf[offset:]); err != nil {
log.Print("Receive Failed: ", err)
break
} else {
offset += l
}
//Parse head
if readHead && offset >= 2 {
for i := 1; i <= offset; i++ {
if buf[i] < 0x80 {
rl, rll := binary.Uvarint(buf[1:])
remainLength := int(rl)
total = remainLength + rll + 1
if total > 6 {
buf = ReAlloc(buf, total)
}
readHead = false
break
}
}
}
//Parse Message
if !readHead && offset >= total {
readHead = true
b := Alloc(6)
if msg, l, err := packet.Decode(buf); err != nil {
//TODO log err
log.Print(err)
offset = 0 //clear data
} else {
bee.Event(NewEvent(E_MESSAGE, msg, bee))
//Only message less than 6 bytes
if offset > l {
copy(b, buf[l:])
offset -= l
} else {
offset = 0 //clear data
}
}
buf = b
}
}
//Shutdown bee
if bee.alive {
bee.Shutdown()
}
}
func (bee *Bee) messenger() {
//Abort error
defer func() {
if r := recover(); r!=nil {
log.Print("bee messenger panic ", r)
bee.Shutdown()
}
}()
for {
//Blocking
var event *Event
select {
case <-bee.quit:
bee.conn.Close()
return
case event = <-bee.events:
switch event.event {
case E_MESSAGE:
bee.handleMessage(event.data.(packet.Message))
case E_DISPATCH:
bee.dispatchMessage(event.data.(packet.Message))
case E_CLOSE:
bee.Shutdown()
}
}
}
}
func (bee *Bee) handleMessage(msg packet.Message) {
log.Printf("Received message from %s: %s QOS(%d) DUP(%t) RETAIN(%t)", bee.clientId, msg.Type().Name(), msg.Qos(), msg.Dup(), msg.Retain())
//log.Print("recv Message:", msg.Type().Name())
switch msg.Type() {
case packet.CONNECT:
bee.handleConnect(msg.(*packet.Connect))
case packet.PUBLISH:
bee.handlePublish(msg.(*packet.Publish))
case packet.PUBACK:
bee.handlePubAck(msg.(*packet.PubAck))
case packet.PUBREC:
bee.handlePubRec(msg.(*packet.PubRec))
case packet.PUBREL:
bee.handlePubRel(msg.(*packet.PubRel))
case packet.PUBCOMP:
bee.handlePubComp(msg.(*packet.PubComp))
case packet.SUBSCRIBE:
bee.handleSubscribe(msg.(*packet.Subscribe))
case packet.UNSUBSCRIBE:
bee.handleUnSubscribe(msg.(*packet.UnSubscribe))
case packet.PINGREQ:
bee.handlePingReq(msg.(*packet.PingReq))
case packet.DISCONNECT:
bee.handleDisconnect(msg.(*packet.DisConnect))
}
}
func (bee *Bee) handleConnect(msg *packet.Connect) {
bee.clientId = string(msg.ClientId())
//if msg.WillFlag() {
// bee.session.will = new(packet.Publish)
//}
bee.hive.Event(NewEvent(E_CONNECT, msg, bee))
}
func (bee *Bee) handlePublish(msg *packet.Publish) {
qos := msg.Qos()
if qos == packet.Qos0 {
bee.hive.Event(NewEvent(E_PUBLISH, msg, bee))
} else if qos == packet.Qos1 {
bee.hive.Event(NewEvent(E_PUBLISH, msg, bee))
//Reply PUBACK
puback := packet.PUBACK.NewMessage().(*packet.PubAck)
puback.SetPacketId(msg.PacketId())
bee.Event(NewEvent(E_DISPATCH, puback, bee))
} else if qos == packet.Qos2 {
//Save & Send PUBREC
bee.session.recvPub2[msg.PacketId()] = msg
pubrec := packet.PUBREC.NewMessage().(*packet.PubRec)
pubrec.SetPacketId(msg.PacketId())
bee.Event(NewEvent(E_DISPATCH, pubrec, bee))
} else {
//error
}
}
func (bee *Bee) handlePubAck(msg *packet.PubAck) {
if _, ok := bee.session.pub1[msg.PacketId()]; ok {
delete(bee.session.pub1, msg.PacketId())
}
}
func (bee *Bee) handlePubRec(msg *packet.PubRec) {
msg.SetType(packet.PUBREL)
bee.Event(NewEvent(E_DISPATCH, msg, bee))
}
func (bee *Bee) handlePubRel(msg *packet.PubRel) {
msg.SetType(packet.PUBCOMP)
bee.Event(NewEvent(E_DISPATCH, msg, bee))
}
func (bee *Bee) handlePubComp(msg *packet.PubComp) {
if _, ok := bee.session.pub2[msg.PacketId()]; ok {
delete(bee.session.pub2, msg.PacketId())
}
}
func (bee *Bee) handleSubscribe(msg *packet.Subscribe) {
bee.hive.Event(NewEvent(E_SUBSCRIBE, msg, bee))
}
func (bee *Bee) handleUnSubscribe(msg *packet.UnSubscribe) {
bee.hive.Event(NewEvent(E_UNSUBSCRIBE, msg, bee))
}
func (bee *Bee) handlePingReq(msg *packet.PingReq) {
msg.SetType(packet.PINGRESP)
bee.Event(NewEvent(E_DISPATCH, msg, bee))
}
func (bee *Bee) handleDisconnect(msg *packet.DisConnect) {
bee.hive.Event(NewEvent(E_DISCONNECT, msg, bee))
}
func (bee *Bee) dispatchMessage(msg packet.Message) {
log.Printf("Send message to %s: %s QOS(%d) DUP(%t) RETAIN(%t)", bee.clientId, msg.Type().Name(), msg.Qos(), msg.Dup(), msg.Retain())
if head, payload, err := msg.Encode(); err != nil {
//TODO log
log.Print("Message encode error: ", err)
} else {
bee.send(head)
if payload != nil && len(payload) > 0 {
bee.send(payload)
}
}
if msg.Type() == packet.PUBLISH {
pub := msg.(*packet.Publish)
//Publish Qos1 Qos2 Need store
if msg.Qos() == packet.Qos1 {
bee.session.pub1[pub.PacketId()] = pub
} else if msg.Qos() == packet.Qos2 {
bee.session.pub2[pub.PacketId()] = pub
}
}
}

11
buffer.go Normal file
View File

@@ -0,0 +1,11 @@
package beeq
func Alloc(l int) []byte {
return make([]byte, l)
}
func ReAlloc(buf []byte, l int) []byte {
b := make([]byte, l)
copy(b, buf)
return b
}

0
cmd/main.go Normal file
View File

3
go.mod Normal file
View File

@@ -0,0 +1,3 @@
module github.com/zgwit/beeq
go 1.13

239
hive.go Normal file
View File

@@ -0,0 +1,239 @@
package beeq
import (
"log"
"github.com/zgwit/beeq/packet"
"time"
)
type Hive struct {
//Subscribe tree
subTree *SubTree
//Retain tree
retainTree *RetainTree
//ClientId->Session
sessions map[string]*Session
//ClientId->Bee
//bees map[string]*Bee
//Message received channel. Waiting for handling
events chan *Event
quit chan struct{}
}
func NewHive() *Hive {
return &Hive{
subTree: NewSubTree(),
retainTree: NewRetainTree(),
sessions: make(map[string]*Session),
events: make(chan *Event, 100),
quit: make(chan struct{}),
}
}
func (hive *Hive) messenger() {
//Abort error
defer func() {
if r := recover(); r!=nil {
log.Print("hive messenger panic ", r)
//Recovery main routine
hive.Active()
}
}()
for {
select {
case <-hive.quit:
break
case event := <-hive.events:
switch event.event {
case E_CLOSE:
hive.Shutdown()
case E_LOST_CONN:
hive.handleLostConn(event.from.(*Bee))
case E_CONNECT:
hive.handleConnect(event.data.(*packet.Connect), event.from.(*Bee))
case E_PUBLISH:
hive.handlePublish(event.data.(*packet.Publish), event.from.(*Bee))
case E_SUBSCRIBE:
hive.handleSubscribe(event.data.(*packet.Subscribe), event.from.(*Bee))
case E_UNSUBSCRIBE:
hive.handleUnSubscribe(event.data.(*packet.UnSubscribe), event.from.(*Bee))
case E_DISCONNECT:
hive.handleDisconnect(event.data.(*packet.DisConnect), event.from.(*Bee))
}
}
}
}
func (hive *Hive) Active() {
//Single go Routine, no lock
//Only one processing all message. Performance?
//TODO Benchmark
go hive.messenger()
}
func (hive *Hive) Shutdown() {
close(hive.quit)
}
func (hive *Hive) Event(event *Event) {
//Blocking
hive.events <- event
}
func (hive *Hive) handleLostConn(bee *Bee) {
log.Print("lost ", bee.clientId)
if session, ok := hive.sessions[bee.clientId]; ok {
session.DeActive()
}
}
func (hive *Hive) handleConnect(msg *packet.Connect, bee *Bee) {
connack := packet.CONNACK.NewMessage().(*packet.Connack)
var clientId string
if len(msg.ClientId()) == 0 {
if !msg.CleanSession() {
//error
bee.Event(NewEvent(E_CLOSE, nil, hive))
return
}
// Generate unique clientId (uuid random)
for {
clientId = "xxx"
if _, ok := hive.sessions[clientId]; !ok {
break
}
}
} else {
clientId = string(msg.ClientId())
if session, ok := hive.sessions[clientId]; ok {
// ClientId is already used
if session.Alive() {
//error reject
connack.SetCode(packet.CONNACK_UNAVAILABLE)
bee.Event(NewEvent(E_DISPATCH, connack, hive))
return
} else {
if msg.CleanSession() {
delete(hive.sessions, clientId)
} else {
session.Active(bee)
}
}
}
}
log.Print(clientId, " Connected")
// Generate session
if _, ok := hive.sessions[clientId]; !ok {
session := NewSession()
hive.sessions[clientId] = session
} else {
connack.SetSessionPresent(true)
}
hive.sessions[clientId].bee = bee
hive.sessions[clientId].clientId = clientId
bee.clientId = clientId
bee.session = hive.sessions[clientId]
if msg.KeepAlive() > 0 {
bee.timeout = time.Second * time.Duration(msg.KeepAlive()) * 3 / 2
}
connack.SetCode(packet.CONNACK_ACCEPTED)
bee.Event(NewEvent(E_DISPATCH, connack, hive))
}
func (hive *Hive) handlePublish(msg *packet.Publish, bee *Bee) {
if err := ValidTopic(msg.Topic()); err != nil {
//TODO log
log.Print("Topic invalid ", err)
return
}
if msg.Retain() {
if len(msg.Payload()) == 0 {
hive.retainTree.UnRetain(bee.clientId)
} else {
hive.retainTree.Retain(msg.Topic(), bee.clientId, msg)
}
}
//Fetch subscribers
subs := make(map[string]packet.MsgQos)
hive.subTree.Publish(msg.Topic(), subs)
//Send publish message
for clientId, qos := range subs {
if session, ok := hive.sessions[clientId]; ok && session.alive {
bee := session.bee
//clone new pub
pub := *msg
pub.SetRetain(false)
if msg.Qos() <= qos {
bee.Event(NewEvent(E_DISPATCH, &pub, hive))
} else {
pub.SetQos(qos)
bee.Event(NewEvent(E_DISPATCH, &pub, hive))
}
}
}
}
func (hive *Hive) handleSubscribe(msg *packet.Subscribe, bee *Bee) {
suback := packet.SUBACK.NewMessage().(*packet.SubAck)
suback.SetPacketId(msg.PacketId())
for _, st := range msg.Topics() {
log.Print("Subscribe ", string(st.Topic()))
if err := ValidSubscribe(st.Topic()); err != nil {
log.Print("Invalid topic ", err)
//log error
suback.AddCode(packet.SUB_CODE_ERR)
} else {
hive.subTree.Subscribe(st.Topic(), bee.clientId, st.Qos())
suback.AddCode(packet.SubCode(st.Qos()))
hive.retainTree.Fetch(st.Topic(), func(clientId string, pub *packet.Publish) {
//clone new pub
p := *pub
p.SetRetain(true)
if msg.Qos() <= st.Qos() {
bee.Event(NewEvent(E_DISPATCH, &p, hive))
} else {
p.SetQos(st.Qos())
bee.Event(NewEvent(E_DISPATCH, &p, hive))
}
})
}
}
bee.Event(NewEvent(E_DISPATCH, suback, hive))
}
func (hive *Hive) handleUnSubscribe(msg *packet.UnSubscribe, bee *Bee) {
unsuback := packet.UNSUBACK.NewMessage().(*packet.UnSubAck)
for _, t := range msg.Topics() {
log.Print("UnSubscribe ", string(t))
if err := ValidSubscribe(t); err != nil {
//TODO log
} else {
hive.subTree.UnSubscribe(t, bee.clientId)
}
}
bee.Event(NewEvent(E_DISPATCH, unsuback, hive))
}
func (hive *Hive) handleDisconnect(msg *packet.DisConnect, bee *Bee) {
bee.Event(NewEvent(E_CLOSE, nil, hive))
}

127
packet/connack.go Normal file
View File

@@ -0,0 +1,127 @@
package packet
import (
"fmt"
)
type ConnackCode byte
const (
CONNACK_ACCEPTED ConnackCode = iota
CONNACK_ERROR_VERSION
CONNACK_INVALID_CLIENT_ID
CONNACK_UNAVAILABLE
CONNACK_INVALID_USERNAME_PASSWORD
CONNACK_NOT_AUTHORIZED
)
type Connack struct {
Header
confirm byte
code byte
}
func (msg *Connack) SessionPresent() bool {
return msg.confirm&0x01 == 0x01 //0000 0001
}
func (msg *Connack) SetSessionPresent(b bool) {
msg.dirty = true
if b {
msg.confirm |= 0x01 //0000 0001
} else {
msg.confirm &= 0xFE //1111 1110
}
}
func (msg *Connack) Code() ConnackCode {
return ConnackCode(msg.code)
}
func (msg *Connack) SetCode(c ConnackCode) {
msg.dirty = true
msg.code = byte(c)
}
func (msg *Connack) Decode(buf []byte) (int, error) {
msg.dirty = false
//Tips. remain length is fixed 2 & total is fixed 4
total := len(buf)
if total < 4 {
return 0, fmt.Errorf("Connack expect fixed 4 bytes (%d)", total)
}
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else if l != 2 {
return 0, fmt.Errorf("Remain length must be 2, got %d", l)
} else {
msg.remainLength = l
offset += n
}
//1 Confirm
msg.confirm = buf[offset]
offset++
//2 Code
msg.code = buf[offset]
offset++
// FixHead & VarHead
msg.head = buf[0:offset]
return offset, nil
}
func (msg *Connack) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, nil, nil
}
//Tips. remain length is fixed 2 & total is fixed 4
//Remain Length
msg.remainLength = 0
//Confirm
msg.remainLength += 1
//Code
msg.remainLength += 1
//FixHead & VarHead
hl := msg.remainLength
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//1 Confirm
msg.head[ho] = msg.confirm
ho++
//2 Code
msg.head[ho] = msg.code
ho++
return msg.head, nil, nil
}

441
packet/connect.go Normal file
View File

@@ -0,0 +1,441 @@
package packet
import (
"encoding/binary"
"fmt"
"regexp"
)
var clientIdRegex *regexp.Regexp
func init() {
clientIdRegex = regexp.MustCompile("^[0-9a-zA-Z _]*$")
}
var SupportedVersions map[byte]string = map[byte]string{
0x3: "MQIsdp",
0x4: "MQTT",
}
type Connect struct {
Header
protoName []byte
protoLevel byte
flag byte
keepAlive uint16
clientId []byte
willTopic []byte
willMessage []byte
userName []byte
password []byte
}
func (msg *Connect) ProtoName() []byte {
return msg.protoName
}
func (msg *Connect) SetProtoName(b []byte) {
msg.dirty = true
msg.protoName = b
}
func (msg *Connect) ProtoLevel() byte {
return msg.protoLevel
}
func (msg *Connect) SetProtoLevel(b byte) {
msg.dirty = true
msg.protoLevel = b
msg.protoName = []byte(SupportedVersions[b])
}
func (msg *Connect) UserNameFlag() bool {
return msg.flag&0x80 == 0x80 //1000 0000
}
func (msg *Connect) SetUserNameFlag(b bool) {
msg.dirty = true
if b {
msg.flag |= 0x80 //1000 0000
} else {
msg.flag &= 0x7F //0111 1111
}
}
func (msg *Connect) PasswordFlag() bool {
return msg.flag&0x40 == 0x40 //0100 0000
}
func (msg *Connect) SetPasswordFlag(b bool) {
msg.dirty = true
if b {
msg.flag |= 0x40 //0100 0000
} else {
msg.flag &= 0xBF //1011 1111
}
}
func (msg *Connect) WillRetain() bool {
return msg.flag&0x20 == 0x40 //0010 0000
}
func (msg *Connect) SetWillRetain(b bool) {
msg.dirty = true
if b {
msg.flag |= 0x20 //0010 0000
} else {
msg.flag &= 0xDF //1101 1111
}
msg.SetWillFlag(true)
}
func (msg *Connect) WillQos() MsgQos {
return MsgQos((msg.flag & 0x18) >> 2) // 0001 1000
}
func (msg *Connect) SetWillQos(qos MsgQos) {
msg.dirty = true
msg.flag &= 0xE7 // 1110 0111
msg.flag |= byte(qos << 2)
msg.SetWillFlag(true)
}
func (msg *Connect) WillFlag() bool {
return msg.flag&0x04 == 0x04 //0000 0100
}
func (msg *Connect) SetWillFlag(b bool) {
msg.dirty = true
if b {
msg.flag |= 0x04 //0000 0100
} else {
msg.flag &= 0xFB //1111 1011
msg.SetWillQos(Qos0)
msg.SetWillRetain(false)
}
}
func (msg *Connect) KeepAlive() uint16 {
return msg.keepAlive
}
func (msg *Connect) SetKeepAlive(k uint16) {
msg.dirty = true
msg.keepAlive = k
}
func (msg *Connect) ClientId() []byte {
return msg.clientId
}
func (msg *Connect) SetClientId(b []byte) {
msg.dirty = true
msg.clientId = b
//msg.ValidClientId()
}
func (msg *Connect) WillTopic() []byte {
return msg.willTopic
}
func (msg *Connect) SetWillTopic(b []byte) {
msg.dirty = true
msg.willTopic = b
msg.SetWillFlag(true)
}
func (msg *Connect) WillMessage() []byte {
return msg.willMessage
}
func (msg *Connect) SetWillMessage(b []byte) {
msg.dirty = true
msg.willMessage = b
msg.SetWillFlag(true)
}
func (msg *Connect) UserName() []byte {
return msg.userName
}
func (msg *Connect) SetUserName(b []byte) {
msg.dirty = true
msg.userName = b
msg.SetUserNameFlag(true)
}
func (msg *Connect) Password() []byte {
return msg.password
}
func (msg *Connect) SetPassword(b []byte) {
msg.dirty = true
msg.password = b
msg.SetPasswordFlag(true)
}
func (msg *Connect) CleanSession() bool {
return msg.flag&0x02 == 0x02 //0000 0010
}
func (msg *Connect) SetCleanSession(b bool) {
msg.dirty = true
if b {
msg.flag |= 0x02 //0000 0010
} else {
msg.flag &= 0xFD //1111 1101
}
}
func (msg *Connect) ValidClientId() bool {
if msg.ProtoLevel() == 0x3 {
return true
}
return clientIdRegex.Match(msg.clientId)
}
func (msg *Connect) Decode(buf []byte) (int, error) {
msg.dirty = false
//total := len(buf)
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else {
msg.remainLength = l
offset += n
}
//1 Protocol Name
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.protoName = b
offset += n
}
//2 Protocol Level
msg.protoLevel = buf[offset]
if version, ok := SupportedVersions[msg.ProtoLevel()]; !ok {
return offset, fmt.Errorf("Protocol level (%d) is not support", msg.ProtoLevel())
} else if ver := string(msg.ProtoName()); ver != version {
return offset, fmt.Errorf("Protocol name (%s) invalid", ver)
}
offset++
//3 Connect flag
msg.flag = buf[offset]
offset++
if msg.flag&0x1 != 0 {
return offset, fmt.Errorf("Connect Flags (%x) reserved bit 0", msg.flag)
}
if msg.WillQos() > Qos2 {
return offset, fmt.Errorf("Invalid WillQoS (%d)", msg.WillQos())
}
if !msg.WillFlag() && (msg.WillRetain() || msg.WillQos() != Qos0) {
return offset, fmt.Errorf("Invalid WillFlag (%x)", msg.flag)
}
if msg.UserNameFlag() != msg.PasswordFlag() {
return offset, fmt.Errorf("UserName Password must be both exists or not (%x)", msg.flag)
}
//4 Keep Alive
msg.keepAlive = binary.BigEndian.Uint16(buf[offset:])
offset += 2
// FixHead & VarHead
msg.head = buf[0:offset]
plo := offset
//5 ClientId
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.clientId = b
offset += n
// None ClientId, Must Clean Session
if n == 2 && !msg.CleanSession() {
return offset, fmt.Errorf("None ClientId, Must Clean Session (%x)", msg.flag)
}
// ClientId at most 23 characters
if n > 128+2 {
return offset, fmt.Errorf("Too long ClientId (%s)", string(msg.ClientId()))
}
// ClientId 0-9, a-z, A-Z
if n > 0 && !msg.ValidClientId() {
return offset, fmt.Errorf("Invalid ClientId (%s)", string(msg.ClientId()))
}
}
//6 Will Topic & Message
if msg.WillFlag() {
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.willTopic = b
offset += n
}
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.willMessage = b
offset += n
}
}
//7 UserName & Password
if msg.UserNameFlag() {
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.userName = b
offset += n
}
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.password = b
offset += n
}
}
//Payload
msg.payload = buf[plo:offset]
return offset, nil
}
func (msg *Connect) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, msg.payload, nil
}
//Remain Length
msg.remainLength = 0
//Protocol Name
msg.remainLength += 2 + len(msg.protoName)
//Protocol Level
msg.remainLength += 1
//Connect Flags
msg.remainLength += 1
//Keep Alive
msg.remainLength += 1
//FixHead & VarHead
hl := msg.remainLength
//ClientId
msg.remainLength += 2 + len(msg.clientId)
//Will Topic & Message
if msg.WillFlag() {
msg.remainLength += 2 + len(msg.willTopic)
msg.remainLength += 2 + len(msg.willMessage)
}
//UserName & Password
if msg.UserNameFlag() {
msg.remainLength += 2 + len(msg.userName)
msg.remainLength += 2 + len(msg.password)
}
pl := msg.remainLength - hl
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
msg.payload = make([]byte, pl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//1 Protocol Name
if n, err := WriteBytes(msg.head[ho:], msg.protoName); err != nil {
return nil, nil, err
} else {
ho += n
}
//2 Protocol Level
msg.head[ho] = msg.protoLevel
ho++
//3 Connect Flags
msg.head[ho] = msg.flag
ho++
//4 Keep Alive
binary.BigEndian.PutUint16(msg.head[ho:], msg.keepAlive)
ho += 2
plo := 0
//5 ClientId
if n, err := WriteBytes(msg.payload[plo:], msg.clientId); err != nil {
return msg.head, nil, err
} else {
plo += n
}
//6 Will Topic & Message
if msg.WillFlag() {
if n, err := WriteBytes(msg.payload[plo:], msg.willTopic); err != nil {
return msg.head, nil, err
} else {
plo += n
}
if n, err := WriteBytes(msg.payload[plo:], msg.willMessage); err != nil {
return msg.head, nil, err
} else {
plo += n
}
}
//7 UserName & Password
if msg.UserNameFlag() {
if n, err := WriteBytes(msg.payload[plo:], msg.userName); err != nil {
return msg.head, nil, err
} else {
plo += n
}
if n, err := WriteBytes(msg.payload[plo:], msg.password); err != nil {
return msg.head, nil, err
} else {
plo += n
}
}
return msg.head, msg.payload, nil
}

71
packet/disconnect.go Normal file
View File

@@ -0,0 +1,71 @@
package packet
import (
"fmt"
)
type DisConnect struct {
Header
}
func (msg *DisConnect) Decode(buf []byte) (int, error) {
msg.dirty = false
//Tips. remain length is fixed 0 & total is fixed 2
total := len(buf)
if total < 2 {
return 0, fmt.Errorf("DisConnect expect fixed 2 bytes, got %d", total)
}
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else if l != 0 {
return 0, fmt.Errorf("Remain length must be 0, got %d", l)
} else {
msg.remainLength = l
offset += n
}
// FixHead & VarHead
msg.head = buf[0:offset]
return offset, nil
}
func (msg *DisConnect) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, nil, nil
}
//Tips. remain length is fixed 0 & total is fixed 2
//Remain Length
msg.remainLength = 0
//FixHead & VarHead
hl := msg.remainLength
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
return msg.head, nil, nil
}

61
packet/header.go Normal file
View File

@@ -0,0 +1,61 @@
package packet
type Header struct {
header byte
remainLength int
dirty bool
head []byte
payload []byte
}
func (hdr *Header) Type() MsgType {
return MsgType((hdr.header & 0xF0) >> 4)
}
func (hdr *Header) SetType(t MsgType) {
hdr.dirty = true
hdr.header &= 0x0F // 0000 1111
hdr.header |= byte(t << 4)
}
func (hdr *Header) Dup() bool {
return hdr.header&0x08 == 0x08 //0000 1000
}
func (hdr *Header) SetDup(b bool) {
hdr.dirty = true
if b {
hdr.header |= 0x08 //0000 1000
} else {
hdr.header &= 0xF7 //1111 0111
}
}
func (hdr *Header) Qos() MsgQos {
return MsgQos((hdr.header & 0x06) >> 1) //0000 0110
}
func (hdr *Header) SetQos(qos MsgQos) {
hdr.dirty = true
hdr.header &= 0xF9 //1111 1001
hdr.header |= byte(qos << 1) //0000 0110
}
func (hdr *Header) Retain() bool {
return hdr.header&0x01 == 0x01
}
func (hdr *Header) SetRetain(b bool) {
hdr.dirty = true
if b {
hdr.header |= 0x01 //0000 0001
} else {
hdr.header &= 0xFE //1111 1110
}
}
func (hdr *Header) RemainLength() int {
return hdr.remainLength
}

112
packet/message.go Normal file
View File

@@ -0,0 +1,112 @@
package packet
type MsgType byte
const (
RESERVED MsgType = iota
CONNECT
CONNACK
PUBLISH
PUBACK
PUBREC
PUBREL
PUBCOMP
SUBSCRIBE
SUBACK
UNSUBSCRIBE
UNSUBACK
PINGREQ
PINGRESP
DISCONNECT
RESERVED2
)
var msgNames = []string{
"RESERVED", "CONNECT", "CONNACK", "PUBLISH",
"PUBACK", "PUBREC", "PUBREL", "PUBCOMP",
"SUBSCRIBE", "SUBACK", "UNSUBSCRIBE", "UNSUBACK",
"PINGREQ", "PINGRESP", "DISCONNECT", "RESERVED",
}
func (mt MsgType) Name() string {
return msgNames[mt&0x0F]
}
func (mt MsgType) NewMessage() Message {
var msg Message
switch mt {
case CONNECT:
msg = new(Connect)
case CONNACK:
msg = new(Connack)
case PUBLISH:
msg = new(Publish)
case PUBACK:
msg = new(PubAck)
case PUBREC:
msg = new(PubRec)
case PUBREL:
msg = new(PubRel)
case PUBCOMP:
msg = new(PubComp)
case SUBSCRIBE:
msg = new(Subscribe)
case SUBACK:
msg = new(SubAck)
case UNSUBSCRIBE:
msg = new(UnSubscribe)
case UNSUBACK:
msg = new(UnSubAck)
case PINGREQ:
msg = new(PingReq)
case PINGRESP:
msg = new(PingResp)
case DISCONNECT:
msg = new(DisConnect)
default:
//error
return nil
}
msg.SetType(mt)
return msg
}
type MsgQos byte
var qosNames = []string{
"AtMostOnce", "AtLastOnce", "ExactlyOnce", "QosError",
}
func (qos MsgQos) Name() string {
// 0000 0011
return qosNames[qos&0x03]
}
func (qos MsgQos) Level() uint8 {
return uint8(qos & 0x03)
}
const (
//At most once
Qos0 MsgQos = iota
//At least once
Qos1
//Exactly once
Qos2
)
type Message interface {
Type() MsgType
SetType(t MsgType)
Dup() bool
SetDup(b bool)
Qos() MsgQos
SetQos(qos MsgQos)
Retain() bool
SetRetain(b bool)
RemainLength() int
Decode([]byte) (int, error)
Encode() ([]byte, []byte, error)
}

102
packet/packet.go Normal file
View File

@@ -0,0 +1,102 @@
package packet
import (
"encoding/binary"
"fmt"
)
func LenLen(rl int) int {
if rl <= 127 { //0x7F
return 1
} else if rl <= 16383 { //0x7F 7F
return 2
} else if rl <= 2097151 { //0x7F 7F 7F
return 3
} else {
return 4
}
}
func ReadRemainLength(b []byte) (int, int, error) {
length := len(b)
size := 1
for {
if length < size {
return 0, size, fmt.Errorf("[ReadRemainLength] Expect at leat %d bytes", 1)
}
if b[size-1] > 0x80 {
size += 1
} else {
break
}
if size > 4 {
return 0, size, fmt.Errorf("[ReadRemainLength] Expect at most 4 bytes, got %d", size)
}
}
rl, size := binary.Uvarint(b)
return int(rl), size, nil
}
func WriteRemainLength(b []byte, rl int) (int, error) {
length := len(b)
ll := LenLen(rl)
if ll > length {
return 0, fmt.Errorf("[ReadRemainLength] Expect at most %d bytes for remain length", ll)
}
return binary.PutUvarint(b, uint64(rl)), nil
}
func ReadBytes(buf []byte) ([]byte, int, error) {
if len(buf) < 2 {
return nil, 0, fmt.Errorf("[readLPBytes] Expect at least %d bytes for prefix", 2)
}
length := int(binary.BigEndian.Uint16(buf))
total := length + 2
if len(buf) < total {
return nil, 0, fmt.Errorf("[readLPBytes] Expect at least %d bytes", length+2)
}
b := buf[2 : total]
return b, total, nil
}
func WriteBytes(buf []byte, b []byte) (int, error) {
length, size := len(b), len(buf)
if length > 65535 {
return 0, fmt.Errorf("[writeLPBytes] Too much bytes(%d) to write", length)
}
total := length + 2
if size < total {
return 0, fmt.Errorf("[writeLPBytes] Expect at least %d bytes", total)
}
binary.BigEndian.PutUint16(buf, uint16(length))
copy(buf[2:], b)
return total, nil
}
func BytesDup(buf []byte) []byte {
b := make([]byte, len(buf))
copy(b, buf)
return b
}
func Decode(buf []byte) (Message, int, error) {
mt := MsgType(buf[0] >> 4)
msg := mt.NewMessage()
if msg != nil {
l, err := msg.Decode(buf)
return msg, l, err
} else {
return nil, 0, fmt.Errorf("Unknown messege type")
}
}
func Encode(msg Message) ([]byte, []byte, error) {
return msg.Encode()
}

71
packet/pingreq.go Normal file
View File

@@ -0,0 +1,71 @@
package packet
import (
"fmt"
)
type PingReq struct {
Header
}
func (msg *PingReq) Decode(buf []byte) (int, error) {
msg.dirty = false
//Tips. remain length is fixed 0 & total is fixed 2
total := len(buf)
if total < 2 {
return 0, fmt.Errorf("Ping expect fixed 2 bytes, got %d", total)
}
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else if l != 0 {
return 0, fmt.Errorf("Remain length must be 0, got %d", l)
} else {
msg.remainLength = l
offset += n
}
// FixHead & VarHead
msg.head = buf[0:offset]
return offset, nil
}
func (msg *PingReq) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, nil, nil
}
//Tips. remain length is fixed 0 & total is fixed 2
//Remain Length
msg.remainLength = 0
//FixHead & VarHead
hl := msg.remainLength
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
return msg.head, nil, nil
}

5
packet/pingresp.go Normal file
View File

@@ -0,0 +1,5 @@
package packet
type PingResp struct {
PingReq
}

93
packet/puback.go Normal file
View File

@@ -0,0 +1,93 @@
package packet
import (
"encoding/binary"
"fmt"
)
type PubAck struct {
Header
packetId uint16
}
func (msg *PubAck) PacketId() uint16 {
return msg.packetId
}
func (msg *PubAck) SetPacketId(p uint16) {
msg.dirty = true
msg.packetId = p
}
func (msg *PubAck) Decode(buf []byte) (int, error) {
msg.dirty = false
//Tips. remain length is fixed 2 & total is fixed 4
total := len(buf)
if total < 4 {
return 0, fmt.Errorf("Connack expect fixed 4 bytes (%d)", total)
}
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else if l != 2 {
return 0, fmt.Errorf("Remain length must be 2, got %d", l)
} else {
msg.remainLength = l
offset += n
}
// PacketId
msg.packetId = binary.BigEndian.Uint16(buf[offset:])
offset += 2
// FixHead & VarHead
msg.head = buf[0:offset]
return offset, nil
}
func (msg *PubAck) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, nil, nil
}
//Tips. remain length is fixed 2 & total is fixed 4
//Remain Length
msg.remainLength = 0
//Packet Id
msg.remainLength += 2
//FixHead & VarHead
hl := msg.remainLength
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//Packet Id
binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId)
ho += 2
return msg.head, nil, nil
}

5
packet/pubcomp.go Normal file
View File

@@ -0,0 +1,5 @@
package packet
type PubComp struct {
PubAck
}

152
packet/publish.go Normal file
View File

@@ -0,0 +1,152 @@
package packet
import (
"encoding/binary"
"fmt"
)
type Publish struct {
Header
topic []byte
packetId uint16
//payload []byte
}
func (msg *Publish) Topic() []byte {
return msg.topic
}
func (msg *Publish) SetTopic(b []byte) {
msg.dirty = true
msg.topic = b
}
func (msg *Publish) PacketId() uint16 {
return msg.packetId
}
func (msg *Publish) SetPacketId(p uint16) {
msg.dirty = true
msg.packetId = p
}
func (msg *Publish) Payload() []byte {
return msg.payload
}
func (msg *Publish) SetPayload(p []byte) {
msg.dirty = true
msg.payload = p
}
func (msg *Publish) Decode(buf []byte) (int, error) {
msg.dirty = false
total := len(buf)
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else {
msg.remainLength = l
offset += n
}
headerLen := offset
if total < msg.remainLength+headerLen {
fmt.Errorf("Payload is not enough expect %d, got %d", msg.remainLength+headerLen, total)
}
//1 Topic
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.topic = b
offset += n
}
//2 PacketId //Only Qos1 Qos2 has packet id
if msg.Qos() > Qos0 {
msg.packetId = binary.BigEndian.Uint16(buf[offset:])
offset += 2
}
// FixHead & VarHead
msg.head = buf[0:offset]
//plo := offset
//3 Payload
l := msg.remainLength + headerLen
b := buf[offset:l]
msg.payload = b
offset += len(b)
return offset, nil
}
func (msg *Publish) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, msg.payload, nil
}
//Remain Length
msg.remainLength = 0
//Topic
msg.remainLength += 2 + len(msg.topic)
//PacketId
if msg.Qos() > Qos0 {
msg.remainLength += 2
}
//FixHead & VarHead
hl := msg.remainLength
//Payload
msg.remainLength += len(msg.payload)
//pl := msg.remainLength - hl
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//msg.payload = make([]byte, pl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//1 Topic
if n, err := WriteBytes(msg.head[ho:], msg.topic); err != nil {
return nil, nil, err
} else {
ho += n
}
//2 PacketId
if msg.Qos() > Qos0 {
binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId)
ho += 2
}
//3 Payload
//msg.payload = payload
return msg.head, msg.payload, nil
}

5
packet/pubrec.go Normal file
View File

@@ -0,0 +1,5 @@
package packet
type PubRec struct {
PubAck
}

5
packet/pubrel.go Normal file
View File

@@ -0,0 +1,5 @@
package packet
type PubRel struct {
PubAck
}

127
packet/suback.go Normal file
View File

@@ -0,0 +1,127 @@
package packet
import (
"encoding/binary"
)
type SubCode byte
const (
SUB_CODE_QOS0 SubCode = iota
SUB_CODE_QOS1
SUB_CODE_QOS2
SUB_CODE_ERR = 128
)
type SubAck struct {
Header
packetId uint16
codes []byte
}
func (msg *SubAck) PacketId() uint16 {
return msg.packetId
}
func (msg *SubAck) SetPacketId(p uint16) {
msg.dirty = true
msg.packetId = p
}
func (msg *SubAck) Codes() []byte {
return msg.codes
}
func (msg *SubAck) AddCode(c SubCode) {
msg.dirty = true
msg.codes = append(msg.codes, byte(c))
}
func (msg *SubAck) ClearCode() {
msg.dirty = true
msg.codes = msg.codes[0:0]
}
func (msg *SubAck) Decode(buf []byte) (int, error) {
msg.dirty = false
//total := len(buf)
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else {
msg.remainLength = l
offset += n
}
headerLen := offset
// PacketId
msg.packetId = binary.BigEndian.Uint16(buf[offset:])
offset += 2
// FixHead & VarHead
msg.head = buf[0:offset]
//plo := offset
// Parse Codes
l := msg.remainLength + headerLen
msg.codes = buf[offset:l]
//Payload
msg.payload = msg.codes
return offset, nil
}
func (msg *SubAck) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, msg.payload, nil
}
//Remain Length
msg.remainLength = 0
//Packet Id
msg.remainLength += 2
//FixHead & VarHead
hl := msg.remainLength
//Codes
msg.remainLength += len(msg.codes)
//pl := msg.remainLength - hl
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
//msg.payload = make([]byte, pl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//Packet Id
binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId)
ho += 2
//Codes
msg.payload = msg.codes
return msg.head, msg.payload, nil
}

174
packet/subscribe.go Normal file
View File

@@ -0,0 +1,174 @@
package packet
import (
"encoding/binary"
"fmt"
)
type SubTopic struct {
topic []byte
flag byte
}
func (msg *SubTopic) Topic() []byte {
return msg.topic
}
func (msg *SubTopic) SetTopic(b []byte) {
msg.topic = b
}
func (msg *SubTopic) Qos() MsgQos {
return MsgQos(msg.flag & 0x03) //0000 0011
}
func (msg *SubTopic) SetQos(qos MsgQos) {
msg.flag &= 0xFC
msg.flag |= byte(qos)
}
type Subscribe struct {
Header
packetId uint16
topics []*SubTopic
}
func (msg *Subscribe) PacketId() uint16 {
return msg.packetId
}
func (msg *Subscribe) SetPacketId(p uint16) {
msg.dirty = true
msg.packetId = p
}
func (msg *Subscribe) Topics() []*SubTopic {
return msg.topics
}
func (msg *Subscribe) AddTopic(topic []byte, qos MsgQos) {
msg.dirty = true
st := &SubTopic{}
st.SetTopic(topic)
st.SetQos(qos)
msg.topics = append(msg.topics, st)
}
func (msg *Subscribe) ClearTopic() {
msg.dirty = true
msg.topics = msg.topics[0:0]
}
func (msg *Subscribe) Decode(buf []byte) (int, error) {
msg.dirty = false
//total := len(buf)
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else {
msg.remainLength = l
offset += n
}
headerLen := offset
// PacketId
msg.packetId = binary.BigEndian.Uint16(buf[offset:])
offset += 2
// FixHead & VarHead
msg.head = buf[0:offset]
plo := offset
// Parse Topics
for offset-headerLen < msg.remainLength {
st := &SubTopic{}
//Topic
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
st.SetTopic(b)
offset += n
}
//Qos
qos := buf[offset]
if (qos & 0x03) != qos {
return offset, fmt.Errorf("Topic Qos %x", qos)
}
st.SetQos(MsgQos(qos))
offset++
msg.topics = append(msg.topics, st)
}
//Payload
msg.payload = buf[plo:offset]
return offset, nil
}
func (msg *Subscribe) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, msg.payload, nil
}
//Remain Length
msg.remainLength = 0
//Packet Id
msg.remainLength += 2
//FixHead & VarHead
hl := msg.remainLength
//Topics
for _, t := range msg.topics {
msg.remainLength += 2 + len(t.Topic())
msg.remainLength += 1
}
pl := msg.remainLength - hl
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
msg.payload = make([]byte, pl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//Packet Id
binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId)
ho += 2
plo := 0
//Topics
for _, t := range msg.topics {
//Topic
if n, err := WriteBytes(msg.payload[plo:], t.topic); err != nil {
return msg.head, nil, err
} else {
plo += n
}
//Qos
msg.payload[plo] = t.flag
plo++
}
return msg.head, msg.payload, nil
}

5
packet/unsuback.go Normal file
View File

@@ -0,0 +1,5 @@
package packet
type UnSubAck struct {
PubAck
}

135
packet/unsubscribe.go Normal file
View File

@@ -0,0 +1,135 @@
package packet
import (
"encoding/binary"
)
type UnSubscribe struct {
Header
packetId uint16
topics [][]byte
}
func (msg *UnSubscribe) PacketId() uint16 {
return msg.packetId
}
func (msg *UnSubscribe) SetPacketId(p uint16) {
msg.dirty = true
msg.packetId = p
}
func (msg *UnSubscribe) Topics() [][]byte {
return msg.topics
}
func (msg *UnSubscribe) AddTopic(topic []byte) {
msg.dirty = true
msg.topics = append(msg.topics, topic)
}
func (msg *UnSubscribe) ClearTopic() {
msg.dirty = true
msg.topics = msg.topics[0:0]
}
func (msg *UnSubscribe) Decode(buf []byte) (int, error) {
msg.dirty = false
//total := len(buf)
offset := 0
//Header
msg.header = buf[0]
offset++
//Remain Length
if l, n, err := ReadRemainLength(buf[offset:]); err != nil {
return offset, err
} else {
msg.remainLength = l
offset += n
}
headerLen := offset
// PacketId
msg.packetId = binary.BigEndian.Uint16(buf[offset:])
offset += 2
// FixHead & VarHead
msg.head = buf[0:offset]
plo := offset
// Parse Topics
for offset-headerLen < msg.remainLength {
//Topic
if b, n, err := ReadBytes(buf[offset:]); err != nil {
return offset, err
} else {
msg.AddTopic(b)
offset += n
}
}
//Payload
msg.payload = buf[plo:offset]
return offset, nil
}
func (msg *UnSubscribe) Encode() ([]byte, []byte, error) {
if !msg.dirty {
return msg.head, msg.payload, nil
}
//Remain Length
msg.remainLength = 0
//Packet Id
msg.remainLength += 2
//FixHead & VarHead
hl := msg.remainLength
//Topics
for _, t := range msg.topics {
msg.remainLength += 2 + len(t)
}
pl := msg.remainLength - hl
hl += 1 + LenLen(msg.remainLength)
//Alloc buffer
msg.head = make([]byte, hl)
msg.payload = make([]byte, pl)
//Header
ho := 0
msg.head[ho] = msg.header
ho++
//Remain Length
if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil {
return nil, nil, err
} else {
ho += n
}
//Packet Id
binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId)
ho += 2
plo := 0
//Topics
for _, t := range msg.topics {
//Topic
if n, err := WriteBytes(msg.payload[plo:], t); err != nil {
return msg.head, nil, err
} else {
plo += n
}
}
return msg.head, msg.payload, nil
}

116
retaintree.go Normal file
View File

@@ -0,0 +1,116 @@
package beeq
import (
"github.com/zgwit/beeq/packet"
"strings"
)
type RetainNode struct {
//Subscribed retains
//clientId
retains map[string]*packet.Publish
//Sub level
//topic->children
children map[string]*RetainNode
}
func NewRetainNode() *RetainNode {
return &RetainNode{
retains: make(map[string]*packet.Publish),
children: make(map[string]*RetainNode),
}
}
func (rn *RetainNode) Fetch(topics []string, cb func(clientId string, pub *packet.Publish)) {
if len(topics) == 0 {
// Publish all matched retains
for clientId, pub := range rn.retains {
cb(clientId, pub)
}
} else {
name := topics[0]
if name == "#" {
//All retains
for clientId, pub := range rn.retains {
cb(clientId, pub)
}
//And all children
for _, sub := range rn.children {
sub.Fetch(topics, cb)
}
} else if name == "+" {
//Children
for _, sub := range rn.children {
sub.Fetch(topics[1:], cb)
}
} else {
// Sub-Level
if sub, ok := rn.children[name]; ok {
sub.Fetch(topics[1:], cb)
}
}
}
}
func (rn *RetainNode) Retain(topics []string, clientId string, pub *packet.Publish) *RetainNode {
if len(topics) == 0 {
// Publish to specific client
rn.retains[clientId] = pub
return rn
} else {
name := topics[0]
// Sub-Level
if _, ok := rn.children[name]; !ok {
rn.children[name] = NewRetainNode()
}
return rn.children[name].Retain(topics[1:], clientId, pub)
}
}
type RetainTree struct {
//root
root *RetainNode
//tree index
//ClientId -> Node (hold Publish message)
retains map[string]*RetainNode
}
func NewRetainTree() *RetainTree {
return &RetainTree{
root: NewRetainNode(),
}
}
func (rt *RetainTree) Fetch(topic []byte, cb func(clientId string, pub *packet.Publish)) {
topics := strings.Split(string(topic), "/")
if topics[0] == "" {
topics[0] = "/"
}
rt.root.Fetch(topics, cb)
}
func (rt *RetainTree) Retain(topic []byte, clientId string, pub *packet.Publish) {
// Remove last retain publish, firstly
rt.UnRetain(clientId)
topics := strings.Split(string(topic), "/")
if topics[0] == "" {
topics[0] = "/"
}
node := rt.root.Retain(topics, clientId, pub)
//indexed node
rt.retains[clientId] = node
}
func (rt *RetainTree) UnRetain(clientId string) {
if node, ok := rt.retains[clientId]; ok {
delete(node.retains, clientId)
delete(rt.retains, clientId)
}
}

63
session.go Normal file
View File

@@ -0,0 +1,63 @@
package beeq
import (
"github.com/zgwit/beeq/packet"
"time"
)
type Session struct {
//Bee
bee *Bee
//Client ID (from CONNECT)
clientId string
//Keep Alive (from CONNECT)
keepAlive int
//will topic (from CONNECT)
will *packet.Publish
//Qos1 Qos2
pub1 map[uint16]*packet.Publish
pub2 map[uint16]*packet.Publish
//Received Qos2 Publish
recvPub2 map[uint16]*packet.Publish
//Increment 0~65535
packetId uint16
alive bool
create_time time.Time
active_time time.Time
lost_time time.Time
}
func NewSession() *Session {
return &Session{
pub1: make(map[uint16]*packet.Publish),
pub2: make(map[uint16]*packet.Publish),
recvPub2: make(map[uint16]*packet.Publish),
alive: true,
create_time: time.Now(),
active_time: time.Now(),
}
}
func (session *Session) Alive() bool {
return session.alive
}
func (session *Session) Active(bee *Bee) {
session.alive = true
session.bee = bee
session.active_time = time.Now()
}
func (session *Session) DeActive() {
session.alive = false
session.bee = nil
session.lost_time = time.Now()
}

65
socket.go Normal file
View File

@@ -0,0 +1,65 @@
package beeq
import (
"crypto/tls"
"golang.org/x/net/websocket"
"net"
"net/http"
"time"
)
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
func AcceptSocket(ln net.Listener, cb func(conn net.Conn)) {
go (func() {
//defer ln.Close()
for {
if conn, err := ln.Accept(); err != nil {
//TODO log
break
} else {
cb(conn)
}
}
})()
}
func AcceptWebSocket(ln net.Listener, pattern string, cb func(conn net.Conn)) {
h := func(ws *websocket.Conn) {
conn := NewWSConn(ws)
cb(conn)
//block here
conn.Wait()
}
mux := http.NewServeMux()
mux.Handle(pattern, websocket.Handler(h))
svr := http.Server{Handler:mux}
go svr.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
}
func ListenSocket(laddr string) (net.Listener, error) {
return net.Listen("tcp", laddr)
}
func ListenSocketTLS(laddr string, cert string, key string) (net.Listener, error) {
if ce, err := tls.LoadX509KeyPair(cert, key); err != nil {
return nil, err
} else {
config := &tls.Config{Certificates: []tls.Certificate{ce}}
return tls.Listen("tcp", laddr, config)
}
}

163
subtree.go Normal file
View File

@@ -0,0 +1,163 @@
package beeq
import (
"github.com/zgwit/beeq/packet"
"strings"
)
type SubNode struct {
//Subscribed clients
//clientId
clients map[string]packet.MsgQos
//Sub level
//topic->children
children map[string]*SubNode
//Multi Wildcard #
mw *SubNode
//Single Wildcard +
sw *SubNode
}
func NewSubNode() *SubNode {
return &SubNode{
clients: make(map[string]packet.MsgQos),
children: make(map[string]*SubNode),
//mw: NewSubNode(),
//sw: NewSubNode(),
}
}
func (sn *SubNode) Publish(topics []string, subs map[string]packet.MsgQos) {
if len(topics) == 0 {
// Publish all matched clients
for clientId, qos := range sn.clients {
if sub, ok := subs[clientId]; ok {
//rewrite by larger Qos
if sub < qos {
subs[clientId] = qos
}
} else {
subs[clientId] = qos
}
}
} else {
name := topics[0]
// Sub-Level
if sub, ok := sn.children[name]; ok {
sub.Publish(topics[1:], subs)
}
// Multi wildcard
if sn.mw != nil {
sn.mw.Publish(topics[1:1], subs)
}
// Single wildcard
if sn.sw != nil {
sn.sw.Publish(topics[1:], subs)
}
}
}
func (sn *SubNode) Subscribe(topics []string, clientId string, qos packet.MsgQos) {
if len(topics) == 0 {
sn.clients[clientId] = qos
return
}
name := topics[0]
if name == "#" {
if sn.mw == nil {
sn.mw = NewSubNode()
}
sn.mw.Subscribe(topics[1:1], clientId, qos)
} else if name == "+" {
if sn.sw == nil {
sn.sw = NewSubNode()
}
sn.sw.Subscribe(topics[1:], clientId, qos)
} else {
if _, ok := sn.children[name]; !ok {
sn.children[name] = NewSubNode()
}
sn.children[name].Subscribe(topics[1:], clientId, qos)
}
}
func (sn *SubNode) UnSubscribe(topics []string, clientId string) {
if len(topics) == 0 {
delete(sn.clients, clientId)
} else {
name := topics[0]
if name == "#" {
if sn.mw != nil {
sn.mw.UnSubscribe(topics[1:1], clientId)
}
} else if name == "+" {
if sn.sw != nil {
sn.sw.UnSubscribe(topics[1:], clientId)
}
} else {
if sub, ok := sn.children[name]; ok {
sub.UnSubscribe(topics[1:], clientId)
}
}
}
}
func (sn *SubNode) ClearClient(clientId string) {
if _, ok := sn.clients[clientId]; ok {
delete(sn.clients, clientId)
}
if sn.mw != nil {
sn.mw.ClearClient(clientId)
}
if sn.sw != nil {
sn.sw.ClearClient(clientId)
}
for _, sub := range sn.children {
sub.ClearClient(clientId)
}
}
type SubTree struct {
//tree root
root *SubNode
}
func NewSubTree() *SubTree {
return &SubTree{
root: NewSubNode(),
}
}
func (st *SubTree) Publish(topic []byte, subs map[string]packet.MsgQos) {
topics := strings.Split(string(topic), "/")
if topics[0] == "" {
topics[0] = "/"
}
st.root.Publish(topics, subs)
}
func (st *SubTree) Subscribe(topic []byte, clientId string, qos packet.MsgQos) {
topics := strings.Split(string(topic), "/")
if topics[0] == "" {
topics[0] = "/"
}
st.root.Subscribe(topics, clientId, qos)
}
func (st *SubTree) UnSubscribe(topic []byte, clientId string) {
topics := strings.Split(string(topic), "/")
if topics[0] == "" {
topics[0] = "/"
}
st.root.UnSubscribe(topics, clientId)
}
func (st *SubTree) ClearClient(clientId string) {
st.root.ClearClient(clientId)
}

70
topic.go Normal file
View File

@@ -0,0 +1,70 @@
package beeq
import (
"bytes"
"errors"
"fmt"
)
func ValidTopic(topic []byte) error {
//no + #
if bytes.ContainsAny(topic, "+#") {
return fmt.Errorf("+ # is not valid (%s)", string(topic))
}
//no //
if bytes.Contains(topic, []byte("//")) {
return fmt.Errorf("// is not valid (%s)", string(topic))
}
return nil
}
func ValidSubscribe(topic []byte) error {
if len(topic) == 0 {
return errors.New("Blank topic")
}
topics := bytes.Split(topic, []byte("/"))
if len(topics[0]) == 0 {
topics[0] = []byte("/")
}
for i, tt := range topics {
l := len(tt)
if l == 0 {
return errors.New("inner blank")
}
if l == 1 && tt[0] == '#' && i < len(topics)-1 {
return errors.New("# must be the last one")
}
if l > 1 && bytes.ContainsAny(tt, "+#") {
return errors.New("+ # is alone")
}
}
return nil
}
func MatchSubscribe(topic []byte, sub []byte) bool {
i, j := 0, 0
for i < len(topic) && j < len(sub) {
t, s := topic[i], sub[j]
if s == '#' {
return true
} else if s == '+' {
for t != '/' {
i++
t = topic[i]
}
j++ // skip /
} else if s != t {
break
}
// else s==t
i++
j++
}
//Just match
if i == len(topic) && j == len(sub) {
return true
}
return false
}

76
wsconn.go Normal file
View File

@@ -0,0 +1,76 @@
package beeq
import (
"golang.org/x/net/websocket"
"net"
"time"
)
//inherit net.Conn
type WSConn struct {
ws *websocket.Conn
buf []byte
quit chan struct{}
}
func NewWSConn(ws *websocket.Conn) *WSConn {
return &WSConn{
ws: ws,
quit: make(chan struct{}),
}
}
func (conn *WSConn) Read(b []byte) (n int, err error) {
if conn.buf == nil {
if err := websocket.Message.Receive(conn.ws, &conn.buf); err != nil {
return 0, err
}
}
//if rn.buf
l := copy(b, conn.buf)
if l < len(conn.buf) {
conn.buf = conn.buf[l:]
} else {
// Receive next packet
conn.buf = nil
}
return l, nil
}
func (conn *WSConn) Write(b []byte) (n int, err error) {
if err := websocket.Message.Send(conn.ws, b); err != nil {
return 0, err
}
return len(b), nil
}
func (conn *WSConn) Close() error {
close(conn.quit)
return conn.ws.Close()
}
func (conn *WSConn) LocalAddr() net.Addr {
return conn.ws.LocalAddr()
}
func (conn *WSConn) RemoteAddr() net.Addr {
return conn.ws.RemoteAddr()
}
func (conn *WSConn) SetDeadline(t time.Time) error {
return conn.ws.SetDeadline(t)
}
func (conn *WSConn) SetReadDeadline(t time.Time) error {
return conn.ws.SetReadDeadline(t)
}
func (conn *WSConn) SetWriteDeadline(t time.Time) error {
return conn.ws.SetWriteDeadline(t)
}
func (conn *WSConn) Wait() {
<- conn.quit
}