diff --git a/.gitignore b/.gitignore index 66fd13c..b71045f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ # Dependency directories (remove the comment below to include it) # vendor/ +/.idea diff --git a/bee.go b/bee.go new file mode 100644 index 0000000..c1b6278 --- /dev/null +++ b/bee.go @@ -0,0 +1,289 @@ +package beeq + +import ( + "encoding/binary" + "github.com/zgwit/beeq/packet" + "log" + "net" + "time" +) + +type Bee struct { + clientId string + session *Session + + conn net.Conn + + hive *Hive + + quit chan struct{} + events chan *Event + + timeout time.Duration + + alive bool +} + +func NewBee(conn net.Conn, hive *Hive) *Bee { + return &Bee{ + conn: conn, + hive: hive, + quit: make(chan struct{}), + events: make(chan *Event, 1024), + timeout: time.Hour * 24, + } +} + +func (bee *Bee) Event(event *Event) { + bee.events <- event +} + +func (bee *Bee) Active() { + bee.alive = true + + go bee.receiver() + go bee.messenger() +} + +func (bee *Bee) Shutdown() { + bee.alive = false + + // Tell hive delete me + bee.hive.Event(NewEvent(E_LOST_CONN, nil, bee)) + + // Stop messenger + close(bee.events) + close(bee.quit) +} + +func (bee *Bee) recv(b []byte) (int, error) { + err := bee.conn.SetReadDeadline(time.Now().Add(bee.timeout)) + if err != nil { + return 0, err + } + return bee.conn.Read(b) +} + +func (bee *Bee) send(b []byte) (int, error) { + err := bee.conn.SetWriteDeadline(time.Now().Add(bee.timeout)) + if err != nil { + return 0, err + } + return bee.conn.Write(b) +} + +func (bee *Bee) receiver() { + + //Abort error + defer func() { + if r := recover(); r!=nil { + log.Print("bee receiver panic ", r) + bee.Shutdown() + } + }() + + readHead := true + buf := Alloc(6) + offset := 0 + total := 0 + for { + if l, err := bee.recv(buf[offset:]); err != nil { + log.Print("Receive Failed: ", err) + break + } else { + offset += l + } + + //Parse head + if readHead && offset >= 2 { + for i := 1; i <= offset; i++ { + if buf[i] < 0x80 { + rl, rll := binary.Uvarint(buf[1:]) + remainLength := int(rl) + total = remainLength + rll + 1 + if total > 6 { + buf = ReAlloc(buf, total) + } + readHead = false + break + } + } + } + + //Parse Message + if !readHead && offset >= total { + readHead = true + b := Alloc(6) + if msg, l, err := packet.Decode(buf); err != nil { + //TODO log err + log.Print(err) + offset = 0 //clear data + } else { + bee.Event(NewEvent(E_MESSAGE, msg, bee)) + //Only message less than 6 bytes + if offset > l { + copy(b, buf[l:]) + offset -= l + } else { + offset = 0 //clear data + } + } + buf = b + } + } + + //Shutdown bee + if bee.alive { + bee.Shutdown() + } +} + +func (bee *Bee) messenger() { + + //Abort error + defer func() { + if r := recover(); r!=nil { + log.Print("bee messenger panic ", r) + bee.Shutdown() + } + }() + + for { + //Blocking + var event *Event + + select { + case <-bee.quit: + bee.conn.Close() + return + case event = <-bee.events: + switch event.event { + case E_MESSAGE: + bee.handleMessage(event.data.(packet.Message)) + case E_DISPATCH: + bee.dispatchMessage(event.data.(packet.Message)) + case E_CLOSE: + bee.Shutdown() + } + } + } +} + +func (bee *Bee) handleMessage(msg packet.Message) { + log.Printf("Received message from %s: %s QOS(%d) DUP(%t) RETAIN(%t)", bee.clientId, msg.Type().Name(), msg.Qos(), msg.Dup(), msg.Retain()) + //log.Print("recv Message:", msg.Type().Name()) + switch msg.Type() { + case packet.CONNECT: + bee.handleConnect(msg.(*packet.Connect)) + case packet.PUBLISH: + bee.handlePublish(msg.(*packet.Publish)) + case packet.PUBACK: + bee.handlePubAck(msg.(*packet.PubAck)) + case packet.PUBREC: + bee.handlePubRec(msg.(*packet.PubRec)) + case packet.PUBREL: + bee.handlePubRel(msg.(*packet.PubRel)) + case packet.PUBCOMP: + bee.handlePubComp(msg.(*packet.PubComp)) + case packet.SUBSCRIBE: + bee.handleSubscribe(msg.(*packet.Subscribe)) + case packet.UNSUBSCRIBE: + bee.handleUnSubscribe(msg.(*packet.UnSubscribe)) + case packet.PINGREQ: + bee.handlePingReq(msg.(*packet.PingReq)) + case packet.DISCONNECT: + bee.handleDisconnect(msg.(*packet.DisConnect)) + } +} + +func (bee *Bee) handleConnect(msg *packet.Connect) { + bee.clientId = string(msg.ClientId()) + + //if msg.WillFlag() { + // bee.session.will = new(packet.Publish) + //} + bee.hive.Event(NewEvent(E_CONNECT, msg, bee)) +} + +func (bee *Bee) handlePublish(msg *packet.Publish) { + qos := msg.Qos() + if qos == packet.Qos0 { + bee.hive.Event(NewEvent(E_PUBLISH, msg, bee)) + } else if qos == packet.Qos1 { + bee.hive.Event(NewEvent(E_PUBLISH, msg, bee)) + //Reply PUBACK + puback := packet.PUBACK.NewMessage().(*packet.PubAck) + puback.SetPacketId(msg.PacketId()) + bee.Event(NewEvent(E_DISPATCH, puback, bee)) + } else if qos == packet.Qos2 { + //Save & Send PUBREC + bee.session.recvPub2[msg.PacketId()] = msg + pubrec := packet.PUBREC.NewMessage().(*packet.PubRec) + pubrec.SetPacketId(msg.PacketId()) + bee.Event(NewEvent(E_DISPATCH, pubrec, bee)) + } else { + //error + } +} + +func (bee *Bee) handlePubAck(msg *packet.PubAck) { + if _, ok := bee.session.pub1[msg.PacketId()]; ok { + delete(bee.session.pub1, msg.PacketId()) + } +} + +func (bee *Bee) handlePubRec(msg *packet.PubRec) { + msg.SetType(packet.PUBREL) + bee.Event(NewEvent(E_DISPATCH, msg, bee)) +} + +func (bee *Bee) handlePubRel(msg *packet.PubRel) { + msg.SetType(packet.PUBCOMP) + bee.Event(NewEvent(E_DISPATCH, msg, bee)) +} + +func (bee *Bee) handlePubComp(msg *packet.PubComp) { + if _, ok := bee.session.pub2[msg.PacketId()]; ok { + delete(bee.session.pub2, msg.PacketId()) + } +} + +func (bee *Bee) handleSubscribe(msg *packet.Subscribe) { + bee.hive.Event(NewEvent(E_SUBSCRIBE, msg, bee)) +} + +func (bee *Bee) handleUnSubscribe(msg *packet.UnSubscribe) { + bee.hive.Event(NewEvent(E_UNSUBSCRIBE, msg, bee)) +} + +func (bee *Bee) handlePingReq(msg *packet.PingReq) { + msg.SetType(packet.PINGRESP) + bee.Event(NewEvent(E_DISPATCH, msg, bee)) +} + +func (bee *Bee) handleDisconnect(msg *packet.DisConnect) { + bee.hive.Event(NewEvent(E_DISCONNECT, msg, bee)) +} + +func (bee *Bee) dispatchMessage(msg packet.Message) { + log.Printf("Send message to %s: %s QOS(%d) DUP(%t) RETAIN(%t)", bee.clientId, msg.Type().Name(), msg.Qos(), msg.Dup(), msg.Retain()) + if head, payload, err := msg.Encode(); err != nil { + //TODO log + log.Print("Message encode error: ", err) + } else { + bee.send(head) + if payload != nil && len(payload) > 0 { + bee.send(payload) + } + } + + if msg.Type() == packet.PUBLISH { + pub := msg.(*packet.Publish) + //Publish Qos1 Qos2 Need store + if msg.Qos() == packet.Qos1 { + bee.session.pub1[pub.PacketId()] = pub + } else if msg.Qos() == packet.Qos2 { + bee.session.pub2[pub.PacketId()] = pub + } + } +} diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..f10b471 --- /dev/null +++ b/buffer.go @@ -0,0 +1,11 @@ +package beeq + +func Alloc(l int) []byte { + return make([]byte, l) +} + +func ReAlloc(buf []byte, l int) []byte { + b := make([]byte, l) + copy(b, buf) + return b +} diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 0000000..e69de29 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f1c3d88 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/zgwit/beeq + +go 1.13 diff --git a/hive.go b/hive.go new file mode 100644 index 0000000..7da362d --- /dev/null +++ b/hive.go @@ -0,0 +1,239 @@ +package beeq + +import ( + + "log" + "github.com/zgwit/beeq/packet" + "time" +) + +type Hive struct { + //Subscribe tree + subTree *SubTree + + //Retain tree + retainTree *RetainTree + + //ClientId->Session + sessions map[string]*Session + + //ClientId->Bee + //bees map[string]*Bee + + //Message received channel. Waiting for handling + events chan *Event + quit chan struct{} +} + +func NewHive() *Hive { + return &Hive{ + subTree: NewSubTree(), + retainTree: NewRetainTree(), + sessions: make(map[string]*Session), + events: make(chan *Event, 100), + quit: make(chan struct{}), + } +} + +func (hive *Hive) messenger() { + //Abort error + defer func() { + if r := recover(); r!=nil { + log.Print("hive messenger panic ", r) + + //Recovery main routine + hive.Active() + } + }() + + for { + select { + case <-hive.quit: + break + case event := <-hive.events: + switch event.event { + case E_CLOSE: + hive.Shutdown() + case E_LOST_CONN: + hive.handleLostConn(event.from.(*Bee)) + case E_CONNECT: + hive.handleConnect(event.data.(*packet.Connect), event.from.(*Bee)) + case E_PUBLISH: + hive.handlePublish(event.data.(*packet.Publish), event.from.(*Bee)) + case E_SUBSCRIBE: + hive.handleSubscribe(event.data.(*packet.Subscribe), event.from.(*Bee)) + case E_UNSUBSCRIBE: + hive.handleUnSubscribe(event.data.(*packet.UnSubscribe), event.from.(*Bee)) + case E_DISCONNECT: + hive.handleDisconnect(event.data.(*packet.DisConnect), event.from.(*Bee)) + } + } + } +} + +func (hive *Hive) Active() { + //Single go Routine, no lock + //Only one processing all message. Performance? + //TODO Benchmark + go hive.messenger() +} + +func (hive *Hive) Shutdown() { + close(hive.quit) +} + +func (hive *Hive) Event(event *Event) { + //Blocking + hive.events <- event +} + +func (hive *Hive) handleLostConn(bee *Bee) { + log.Print("lost ", bee.clientId) + if session, ok := hive.sessions[bee.clientId]; ok { + session.DeActive() + } +} + +func (hive *Hive) handleConnect(msg *packet.Connect, bee *Bee) { + + connack := packet.CONNACK.NewMessage().(*packet.Connack) + + var clientId string + if len(msg.ClientId()) == 0 { + if !msg.CleanSession() { + //error + bee.Event(NewEvent(E_CLOSE, nil, hive)) + return + } + + // Generate unique clientId (uuid random) + for { + clientId = "xxx" + if _, ok := hive.sessions[clientId]; !ok { + break + } + } + + } else { + clientId = string(msg.ClientId()) + + if session, ok := hive.sessions[clientId]; ok { + // ClientId is already used + if session.Alive() { + //error reject + connack.SetCode(packet.CONNACK_UNAVAILABLE) + bee.Event(NewEvent(E_DISPATCH, connack, hive)) + return + } else { + if msg.CleanSession() { + delete(hive.sessions, clientId) + } else { + session.Active(bee) + } + } + } + } + + log.Print(clientId, " Connected") + + // Generate session + if _, ok := hive.sessions[clientId]; !ok { + session := NewSession() + hive.sessions[clientId] = session + } else { + connack.SetSessionPresent(true) + } + + hive.sessions[clientId].bee = bee + hive.sessions[clientId].clientId = clientId + bee.clientId = clientId + bee.session = hive.sessions[clientId] + if msg.KeepAlive() > 0 { + bee.timeout = time.Second * time.Duration(msg.KeepAlive()) * 3 / 2 + } + + connack.SetCode(packet.CONNACK_ACCEPTED) + bee.Event(NewEvent(E_DISPATCH, connack, hive)) +} + +func (hive *Hive) handlePublish(msg *packet.Publish, bee *Bee) { + if err := ValidTopic(msg.Topic()); err != nil { + //TODO log + log.Print("Topic invalid ", err) + return + } + + if msg.Retain() { + if len(msg.Payload()) == 0 { + hive.retainTree.UnRetain(bee.clientId) + } else { + hive.retainTree.Retain(msg.Topic(), bee.clientId, msg) + } + } + + //Fetch subscribers + subs := make(map[string]packet.MsgQos) + hive.subTree.Publish(msg.Topic(), subs) + + //Send publish message + for clientId, qos := range subs { + if session, ok := hive.sessions[clientId]; ok && session.alive { + bee := session.bee + //clone new pub + pub := *msg + pub.SetRetain(false) + if msg.Qos() <= qos { + bee.Event(NewEvent(E_DISPATCH, &pub, hive)) + } else { + pub.SetQos(qos) + bee.Event(NewEvent(E_DISPATCH, &pub, hive)) + } + } + } +} + +func (hive *Hive) handleSubscribe(msg *packet.Subscribe, bee *Bee) { + suback := packet.SUBACK.NewMessage().(*packet.SubAck) + suback.SetPacketId(msg.PacketId()) + for _, st := range msg.Topics() { + log.Print("Subscribe ", string(st.Topic())) + if err := ValidSubscribe(st.Topic()); err != nil { + log.Print("Invalid topic ", err) + //log error + suback.AddCode(packet.SUB_CODE_ERR) + } else { + hive.subTree.Subscribe(st.Topic(), bee.clientId, st.Qos()) + + suback.AddCode(packet.SubCode(st.Qos())) + hive.retainTree.Fetch(st.Topic(), func(clientId string, pub *packet.Publish) { + //clone new pub + p := *pub + p.SetRetain(true) + if msg.Qos() <= st.Qos() { + bee.Event(NewEvent(E_DISPATCH, &p, hive)) + } else { + p.SetQos(st.Qos()) + bee.Event(NewEvent(E_DISPATCH, &p, hive)) + } + }) + } + } + bee.Event(NewEvent(E_DISPATCH, suback, hive)) +} + +func (hive *Hive) handleUnSubscribe(msg *packet.UnSubscribe, bee *Bee) { + unsuback := packet.UNSUBACK.NewMessage().(*packet.UnSubAck) + for _, t := range msg.Topics() { + log.Print("UnSubscribe ", string(t)) + if err := ValidSubscribe(t); err != nil { + //TODO log + } else { + hive.subTree.UnSubscribe(t, bee.clientId) + } + } + bee.Event(NewEvent(E_DISPATCH, unsuback, hive)) +} + +func (hive *Hive) handleDisconnect(msg *packet.DisConnect, bee *Bee) { + bee.Event(NewEvent(E_CLOSE, nil, hive)) +} diff --git a/packet/connack.go b/packet/connack.go new file mode 100644 index 0000000..bb63134 --- /dev/null +++ b/packet/connack.go @@ -0,0 +1,127 @@ +package packet + +import ( + "fmt" +) + +type ConnackCode byte + +const ( + CONNACK_ACCEPTED ConnackCode = iota + CONNACK_ERROR_VERSION + CONNACK_INVALID_CLIENT_ID + CONNACK_UNAVAILABLE + CONNACK_INVALID_USERNAME_PASSWORD + CONNACK_NOT_AUTHORIZED +) + +type Connack struct { + Header + + confirm byte + code byte +} + +func (msg *Connack) SessionPresent() bool { + return msg.confirm&0x01 == 0x01 //0000 0001 +} + +func (msg *Connack) SetSessionPresent(b bool) { + msg.dirty = true + if b { + msg.confirm |= 0x01 //0000 0001 + } else { + msg.confirm &= 0xFE //1111 1110 + } +} + +func (msg *Connack) Code() ConnackCode { + return ConnackCode(msg.code) +} + +func (msg *Connack) SetCode(c ConnackCode) { + msg.dirty = true + msg.code = byte(c) +} + +func (msg *Connack) Decode(buf []byte) (int, error) { + msg.dirty = false + + //Tips. remain length is fixed 2 & total is fixed 4 + total := len(buf) + if total < 4 { + return 0, fmt.Errorf("Connack expect fixed 4 bytes (%d)", total) + } + + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else if l != 2 { + return 0, fmt.Errorf("Remain length must be 2, got %d", l) + } else { + msg.remainLength = l + offset += n + } + + //1 Confirm + msg.confirm = buf[offset] + offset++ + + //2 Code + msg.code = buf[offset] + offset++ + + // FixHead & VarHead + msg.head = buf[0:offset] + + return offset, nil +} + +func (msg *Connack) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, nil, nil + } + + //Tips. remain length is fixed 2 & total is fixed 4 + //Remain Length + msg.remainLength = 0 + //Confirm + msg.remainLength += 1 + //Code + msg.remainLength += 1 + + //FixHead & VarHead + hl := msg.remainLength + + hl += 1 + LenLen(msg.remainLength) + //Alloc buffer + msg.head = make([]byte, hl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //1 Confirm + msg.head[ho] = msg.confirm + ho++ + + //2 Code + msg.head[ho] = msg.code + ho++ + + return msg.head, nil, nil +} diff --git a/packet/connect.go b/packet/connect.go new file mode 100644 index 0000000..cf3e14a --- /dev/null +++ b/packet/connect.go @@ -0,0 +1,441 @@ +package packet + +import ( + "encoding/binary" + "fmt" + "regexp" +) + +var clientIdRegex *regexp.Regexp + +func init() { + clientIdRegex = regexp.MustCompile("^[0-9a-zA-Z _]*$") +} + +var SupportedVersions map[byte]string = map[byte]string{ + 0x3: "MQIsdp", + 0x4: "MQTT", +} + +type Connect struct { + Header + + protoName []byte + + protoLevel byte + + flag byte + + keepAlive uint16 + + clientId []byte + willTopic []byte + willMessage []byte + + userName []byte + password []byte +} + +func (msg *Connect) ProtoName() []byte { + return msg.protoName +} + +func (msg *Connect) SetProtoName(b []byte) { + msg.dirty = true + msg.protoName = b +} + +func (msg *Connect) ProtoLevel() byte { + return msg.protoLevel +} + +func (msg *Connect) SetProtoLevel(b byte) { + msg.dirty = true + msg.protoLevel = b + msg.protoName = []byte(SupportedVersions[b]) +} + +func (msg *Connect) UserNameFlag() bool { + return msg.flag&0x80 == 0x80 //1000 0000 +} + +func (msg *Connect) SetUserNameFlag(b bool) { + msg.dirty = true + if b { + msg.flag |= 0x80 //1000 0000 + } else { + msg.flag &= 0x7F //0111 1111 + } +} + +func (msg *Connect) PasswordFlag() bool { + return msg.flag&0x40 == 0x40 //0100 0000 +} + +func (msg *Connect) SetPasswordFlag(b bool) { + msg.dirty = true + if b { + msg.flag |= 0x40 //0100 0000 + } else { + msg.flag &= 0xBF //1011 1111 + } +} + +func (msg *Connect) WillRetain() bool { + return msg.flag&0x20 == 0x40 //0010 0000 +} + +func (msg *Connect) SetWillRetain(b bool) { + msg.dirty = true + if b { + msg.flag |= 0x20 //0010 0000 + } else { + msg.flag &= 0xDF //1101 1111 + } + msg.SetWillFlag(true) +} + +func (msg *Connect) WillQos() MsgQos { + return MsgQos((msg.flag & 0x18) >> 2) // 0001 1000 +} + +func (msg *Connect) SetWillQos(qos MsgQos) { + msg.dirty = true + msg.flag &= 0xE7 // 1110 0111 + msg.flag |= byte(qos << 2) + msg.SetWillFlag(true) +} + +func (msg *Connect) WillFlag() bool { + return msg.flag&0x04 == 0x04 //0000 0100 +} + +func (msg *Connect) SetWillFlag(b bool) { + msg.dirty = true + if b { + msg.flag |= 0x04 //0000 0100 + } else { + msg.flag &= 0xFB //1111 1011 + + msg.SetWillQos(Qos0) + msg.SetWillRetain(false) + } +} + +func (msg *Connect) KeepAlive() uint16 { + return msg.keepAlive +} + +func (msg *Connect) SetKeepAlive(k uint16) { + msg.dirty = true + msg.keepAlive = k +} + +func (msg *Connect) ClientId() []byte { + return msg.clientId +} + +func (msg *Connect) SetClientId(b []byte) { + msg.dirty = true + msg.clientId = b + //msg.ValidClientId() +} + +func (msg *Connect) WillTopic() []byte { + return msg.willTopic +} + +func (msg *Connect) SetWillTopic(b []byte) { + msg.dirty = true + msg.willTopic = b + msg.SetWillFlag(true) +} + +func (msg *Connect) WillMessage() []byte { + return msg.willMessage +} + +func (msg *Connect) SetWillMessage(b []byte) { + msg.dirty = true + msg.willMessage = b + msg.SetWillFlag(true) +} + +func (msg *Connect) UserName() []byte { + return msg.userName +} + +func (msg *Connect) SetUserName(b []byte) { + msg.dirty = true + msg.userName = b + msg.SetUserNameFlag(true) +} + +func (msg *Connect) Password() []byte { + return msg.password +} + +func (msg *Connect) SetPassword(b []byte) { + msg.dirty = true + msg.password = b + msg.SetPasswordFlag(true) +} + +func (msg *Connect) CleanSession() bool { + return msg.flag&0x02 == 0x02 //0000 0010 +} + +func (msg *Connect) SetCleanSession(b bool) { + msg.dirty = true + if b { + msg.flag |= 0x02 //0000 0010 + } else { + msg.flag &= 0xFD //1111 1101 + } +} + +func (msg *Connect) ValidClientId() bool { + + if msg.ProtoLevel() == 0x3 { + return true + } + + return clientIdRegex.Match(msg.clientId) +} + +func (msg *Connect) Decode(buf []byte) (int, error) { + msg.dirty = false + + //total := len(buf) + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else { + msg.remainLength = l + offset += n + } + + //1 Protocol Name + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.protoName = b + offset += n + } + + //2 Protocol Level + msg.protoLevel = buf[offset] + if version, ok := SupportedVersions[msg.ProtoLevel()]; !ok { + return offset, fmt.Errorf("Protocol level (%d) is not support", msg.ProtoLevel()) + } else if ver := string(msg.ProtoName()); ver != version { + return offset, fmt.Errorf("Protocol name (%s) invalid", ver) + } + offset++ + + //3 Connect flag + msg.flag = buf[offset] + offset++ + + if msg.flag&0x1 != 0 { + return offset, fmt.Errorf("Connect Flags (%x) reserved bit 0", msg.flag) + } + + if msg.WillQos() > Qos2 { + return offset, fmt.Errorf("Invalid WillQoS (%d)", msg.WillQos()) + } + + if !msg.WillFlag() && (msg.WillRetain() || msg.WillQos() != Qos0) { + return offset, fmt.Errorf("Invalid WillFlag (%x)", msg.flag) + } + + if msg.UserNameFlag() != msg.PasswordFlag() { + return offset, fmt.Errorf("UserName Password must be both exists or not (%x)", msg.flag) + } + + //4 Keep Alive + msg.keepAlive = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + + // FixHead & VarHead + msg.head = buf[0:offset] + plo := offset + + //5 ClientId + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.clientId = b + offset += n + + // None ClientId, Must Clean Session + if n == 2 && !msg.CleanSession() { + return offset, fmt.Errorf("None ClientId, Must Clean Session (%x)", msg.flag) + } + + // ClientId at most 23 characters + if n > 128+2 { + return offset, fmt.Errorf("Too long ClientId (%s)", string(msg.ClientId())) + } + + // ClientId 0-9, a-z, A-Z + if n > 0 && !msg.ValidClientId() { + return offset, fmt.Errorf("Invalid ClientId (%s)", string(msg.ClientId())) + } + } + + //6 Will Topic & Message + if msg.WillFlag() { + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.willTopic = b + offset += n + } + + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.willMessage = b + offset += n + } + } + + //7 UserName & Password + if msg.UserNameFlag() { + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.userName = b + offset += n + } + + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.password = b + offset += n + } + } + + //Payload + msg.payload = buf[plo:offset] + + return offset, nil +} + +func (msg *Connect) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, msg.payload, nil + } + + //Remain Length + msg.remainLength = 0 + //Protocol Name + msg.remainLength += 2 + len(msg.protoName) + //Protocol Level + msg.remainLength += 1 + //Connect Flags + msg.remainLength += 1 + //Keep Alive + msg.remainLength += 1 + + //FixHead & VarHead + hl := msg.remainLength + + //ClientId + msg.remainLength += 2 + len(msg.clientId) + //Will Topic & Message + if msg.WillFlag() { + msg.remainLength += 2 + len(msg.willTopic) + msg.remainLength += 2 + len(msg.willMessage) + } + //UserName & Password + if msg.UserNameFlag() { + msg.remainLength += 2 + len(msg.userName) + msg.remainLength += 2 + len(msg.password) + } + + pl := msg.remainLength - hl + hl += 1 + LenLen(msg.remainLength) + + //Alloc buffer + msg.head = make([]byte, hl) + msg.payload = make([]byte, pl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //1 Protocol Name + if n, err := WriteBytes(msg.head[ho:], msg.protoName); err != nil { + return nil, nil, err + } else { + ho += n + } + + //2 Protocol Level + msg.head[ho] = msg.protoLevel + ho++ + + //3 Connect Flags + msg.head[ho] = msg.flag + ho++ + + //4 Keep Alive + binary.BigEndian.PutUint16(msg.head[ho:], msg.keepAlive) + ho += 2 + + plo := 0 + //5 ClientId + if n, err := WriteBytes(msg.payload[plo:], msg.clientId); err != nil { + return msg.head, nil, err + } else { + plo += n + } + + //6 Will Topic & Message + if msg.WillFlag() { + if n, err := WriteBytes(msg.payload[plo:], msg.willTopic); err != nil { + return msg.head, nil, err + } else { + plo += n + } + + if n, err := WriteBytes(msg.payload[plo:], msg.willMessage); err != nil { + return msg.head, nil, err + } else { + plo += n + } + } + + //7 UserName & Password + if msg.UserNameFlag() { + if n, err := WriteBytes(msg.payload[plo:], msg.userName); err != nil { + return msg.head, nil, err + } else { + plo += n + } + + if n, err := WriteBytes(msg.payload[plo:], msg.password); err != nil { + return msg.head, nil, err + } else { + plo += n + } + } + + return msg.head, msg.payload, nil +} diff --git a/packet/disconnect.go b/packet/disconnect.go new file mode 100644 index 0000000..20253d0 --- /dev/null +++ b/packet/disconnect.go @@ -0,0 +1,71 @@ +package packet + +import ( + "fmt" +) + +type DisConnect struct { + Header +} + +func (msg *DisConnect) Decode(buf []byte) (int, error) { + msg.dirty = false + + //Tips. remain length is fixed 0 & total is fixed 2 + total := len(buf) + if total < 2 { + return 0, fmt.Errorf("DisConnect expect fixed 2 bytes, got %d", total) + } + + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else if l != 0 { + return 0, fmt.Errorf("Remain length must be 0, got %d", l) + } else { + msg.remainLength = l + offset += n + } + + // FixHead & VarHead + msg.head = buf[0:offset] + + return offset, nil +} + +func (msg *DisConnect) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, nil, nil + } + + //Tips. remain length is fixed 0 & total is fixed 2 + //Remain Length + msg.remainLength = 0 + + //FixHead & VarHead + hl := msg.remainLength + + hl += 1 + LenLen(msg.remainLength) + //Alloc buffer + msg.head = make([]byte, hl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + return msg.head, nil, nil +} diff --git a/packet/header.go b/packet/header.go new file mode 100644 index 0000000..62f76cd --- /dev/null +++ b/packet/header.go @@ -0,0 +1,61 @@ +package packet + +type Header struct { + header byte + remainLength int + + dirty bool + + head []byte + payload []byte +} + +func (hdr *Header) Type() MsgType { + return MsgType((hdr.header & 0xF0) >> 4) +} + +func (hdr *Header) SetType(t MsgType) { + hdr.dirty = true + hdr.header &= 0x0F // 0000 1111 + hdr.header |= byte(t << 4) +} + +func (hdr *Header) Dup() bool { + return hdr.header&0x08 == 0x08 //0000 1000 +} + +func (hdr *Header) SetDup(b bool) { + hdr.dirty = true + if b { + hdr.header |= 0x08 //0000 1000 + } else { + hdr.header &= 0xF7 //1111 0111 + } +} + +func (hdr *Header) Qos() MsgQos { + return MsgQos((hdr.header & 0x06) >> 1) //0000 0110 +} + +func (hdr *Header) SetQos(qos MsgQos) { + hdr.dirty = true + hdr.header &= 0xF9 //1111 1001 + hdr.header |= byte(qos << 1) //0000 0110 +} + +func (hdr *Header) Retain() bool { + return hdr.header&0x01 == 0x01 +} + +func (hdr *Header) SetRetain(b bool) { + hdr.dirty = true + if b { + hdr.header |= 0x01 //0000 0001 + } else { + hdr.header &= 0xFE //1111 1110 + } +} + +func (hdr *Header) RemainLength() int { + return hdr.remainLength +} diff --git a/packet/message.go b/packet/message.go new file mode 100644 index 0000000..dc0bd57 --- /dev/null +++ b/packet/message.go @@ -0,0 +1,112 @@ +package packet + +type MsgType byte + +const ( + RESERVED MsgType = iota + CONNECT + CONNACK + PUBLISH + PUBACK + PUBREC + PUBREL + PUBCOMP + SUBSCRIBE + SUBACK + UNSUBSCRIBE + UNSUBACK + PINGREQ + PINGRESP + DISCONNECT + RESERVED2 +) + +var msgNames = []string{ + "RESERVED", "CONNECT", "CONNACK", "PUBLISH", + "PUBACK", "PUBREC", "PUBREL", "PUBCOMP", + "SUBSCRIBE", "SUBACK", "UNSUBSCRIBE", "UNSUBACK", + "PINGREQ", "PINGRESP", "DISCONNECT", "RESERVED", +} + +func (mt MsgType) Name() string { + return msgNames[mt&0x0F] +} + +func (mt MsgType) NewMessage() Message { + var msg Message + switch mt { + case CONNECT: + msg = new(Connect) + case CONNACK: + msg = new(Connack) + case PUBLISH: + msg = new(Publish) + case PUBACK: + msg = new(PubAck) + case PUBREC: + msg = new(PubRec) + case PUBREL: + msg = new(PubRel) + case PUBCOMP: + msg = new(PubComp) + case SUBSCRIBE: + msg = new(Subscribe) + case SUBACK: + msg = new(SubAck) + case UNSUBSCRIBE: + msg = new(UnSubscribe) + case UNSUBACK: + msg = new(UnSubAck) + case PINGREQ: + msg = new(PingReq) + case PINGRESP: + msg = new(PingResp) + case DISCONNECT: + msg = new(DisConnect) + default: + //error + return nil + } + msg.SetType(mt) + return msg +} + +type MsgQos byte + +var qosNames = []string{ + "AtMostOnce", "AtLastOnce", "ExactlyOnce", "QosError", +} + +func (qos MsgQos) Name() string { + // 0000 0011 + return qosNames[qos&0x03] +} + +func (qos MsgQos) Level() uint8 { + return uint8(qos & 0x03) +} + +const ( + //At most once + Qos0 MsgQos = iota + //At least once + Qos1 + //Exactly once + Qos2 +) + +type Message interface { + Type() MsgType + SetType(t MsgType) + Dup() bool + SetDup(b bool) + Qos() MsgQos + SetQos(qos MsgQos) + Retain() bool + SetRetain(b bool) + RemainLength() int + + Decode([]byte) (int, error) + + Encode() ([]byte, []byte, error) +} diff --git a/packet/packet.go b/packet/packet.go new file mode 100644 index 0000000..a6d7c9c --- /dev/null +++ b/packet/packet.go @@ -0,0 +1,102 @@ +package packet + +import ( + "encoding/binary" + "fmt" +) + +func LenLen(rl int) int { + if rl <= 127 { //0x7F + return 1 + } else if rl <= 16383 { //0x7F 7F + return 2 + } else if rl <= 2097151 { //0x7F 7F 7F + return 3 + } else { + return 4 + } +} + +func ReadRemainLength(b []byte) (int, int, error) { + length := len(b) + size := 1 + for { + if length < size { + return 0, size, fmt.Errorf("[ReadRemainLength] Expect at leat %d bytes", 1) + } + + if b[size-1] > 0x80 { + size += 1 + } else { + break + } + + if size > 4 { + return 0, size, fmt.Errorf("[ReadRemainLength] Expect at most 4 bytes, got %d", size) + } + } + rl, size := binary.Uvarint(b) + return int(rl), size, nil +} + +func WriteRemainLength(b []byte, rl int) (int, error) { + length := len(b) + ll := LenLen(rl) + if ll > length { + return 0, fmt.Errorf("[ReadRemainLength] Expect at most %d bytes for remain length", ll) + } + return binary.PutUvarint(b, uint64(rl)), nil +} + +func ReadBytes(buf []byte) ([]byte, int, error) { + if len(buf) < 2 { + return nil, 0, fmt.Errorf("[readLPBytes] Expect at least %d bytes for prefix", 2) + } + length := int(binary.BigEndian.Uint16(buf)) + total := length + 2 + if len(buf) < total { + return nil, 0, fmt.Errorf("[readLPBytes] Expect at least %d bytes", length+2) + } + b := buf[2 : total] + return b, total, nil +} + +func WriteBytes(buf []byte, b []byte) (int, error) { + length, size := len(b), len(buf) + + if length > 65535 { + return 0, fmt.Errorf("[writeLPBytes] Too much bytes(%d) to write", length) + } + + total := length + 2 + if size < total { + return 0, fmt.Errorf("[writeLPBytes] Expect at least %d bytes", total) + } + + binary.BigEndian.PutUint16(buf, uint16(length)) + + copy(buf[2:], b) + + return total, nil +} + +func BytesDup(buf []byte) []byte { + b := make([]byte, len(buf)) + copy(b, buf) + return b +} + +func Decode(buf []byte) (Message, int, error) { + mt := MsgType(buf[0] >> 4) + msg := mt.NewMessage() + if msg != nil { + l, err := msg.Decode(buf) + return msg, l, err + } else { + return nil, 0, fmt.Errorf("Unknown messege type") + } +} + +func Encode(msg Message) ([]byte, []byte, error) { + return msg.Encode() +} diff --git a/packet/pingreq.go b/packet/pingreq.go new file mode 100644 index 0000000..251f25d --- /dev/null +++ b/packet/pingreq.go @@ -0,0 +1,71 @@ +package packet + +import ( + "fmt" +) + +type PingReq struct { + Header +} + +func (msg *PingReq) Decode(buf []byte) (int, error) { + msg.dirty = false + + //Tips. remain length is fixed 0 & total is fixed 2 + total := len(buf) + if total < 2 { + return 0, fmt.Errorf("Ping expect fixed 2 bytes, got %d", total) + } + + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else if l != 0 { + return 0, fmt.Errorf("Remain length must be 0, got %d", l) + } else { + msg.remainLength = l + offset += n + } + + // FixHead & VarHead + msg.head = buf[0:offset] + + return offset, nil +} + +func (msg *PingReq) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, nil, nil + } + + //Tips. remain length is fixed 0 & total is fixed 2 + //Remain Length + msg.remainLength = 0 + + //FixHead & VarHead + hl := msg.remainLength + + hl += 1 + LenLen(msg.remainLength) + //Alloc buffer + msg.head = make([]byte, hl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + return msg.head, nil, nil +} diff --git a/packet/pingresp.go b/packet/pingresp.go new file mode 100644 index 0000000..c1fce1b --- /dev/null +++ b/packet/pingresp.go @@ -0,0 +1,5 @@ +package packet + +type PingResp struct { + PingReq +} diff --git a/packet/puback.go b/packet/puback.go new file mode 100644 index 0000000..0c295a9 --- /dev/null +++ b/packet/puback.go @@ -0,0 +1,93 @@ +package packet + +import ( + "encoding/binary" + "fmt" +) + +type PubAck struct { + Header + + packetId uint16 +} + +func (msg *PubAck) PacketId() uint16 { + return msg.packetId +} + +func (msg *PubAck) SetPacketId(p uint16) { + msg.dirty = true + msg.packetId = p +} + +func (msg *PubAck) Decode(buf []byte) (int, error) { + msg.dirty = false + + //Tips. remain length is fixed 2 & total is fixed 4 + total := len(buf) + if total < 4 { + return 0, fmt.Errorf("Connack expect fixed 4 bytes (%d)", total) + } + + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else if l != 2 { + return 0, fmt.Errorf("Remain length must be 2, got %d", l) + } else { + msg.remainLength = l + offset += n + } + + // PacketId + msg.packetId = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + + // FixHead & VarHead + msg.head = buf[0:offset] + + return offset, nil +} + +func (msg *PubAck) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, nil, nil + } + + //Tips. remain length is fixed 2 & total is fixed 4 + //Remain Length + msg.remainLength = 0 + //Packet Id + msg.remainLength += 2 + + //FixHead & VarHead + hl := msg.remainLength + + hl += 1 + LenLen(msg.remainLength) + //Alloc buffer + msg.head = make([]byte, hl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //Packet Id + binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId) + ho += 2 + + return msg.head, nil, nil +} diff --git a/packet/pubcomp.go b/packet/pubcomp.go new file mode 100644 index 0000000..eb07186 --- /dev/null +++ b/packet/pubcomp.go @@ -0,0 +1,5 @@ +package packet + +type PubComp struct { + PubAck +} diff --git a/packet/publish.go b/packet/publish.go new file mode 100644 index 0000000..0938071 --- /dev/null +++ b/packet/publish.go @@ -0,0 +1,152 @@ +package packet + +import ( + "encoding/binary" + "fmt" +) + +type Publish struct { + Header + + topic []byte + + packetId uint16 + + //payload []byte +} + +func (msg *Publish) Topic() []byte { + return msg.topic +} + +func (msg *Publish) SetTopic(b []byte) { + msg.dirty = true + msg.topic = b +} + +func (msg *Publish) PacketId() uint16 { + return msg.packetId +} + +func (msg *Publish) SetPacketId(p uint16) { + msg.dirty = true + msg.packetId = p +} + +func (msg *Publish) Payload() []byte { + return msg.payload +} + +func (msg *Publish) SetPayload(p []byte) { + msg.dirty = true + msg.payload = p +} + +func (msg *Publish) Decode(buf []byte) (int, error) { + msg.dirty = false + + total := len(buf) + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else { + msg.remainLength = l + offset += n + } + headerLen := offset + + if total < msg.remainLength+headerLen { + fmt.Errorf("Payload is not enough expect %d, got %d", msg.remainLength+headerLen, total) + } + + //1 Topic + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.topic = b + offset += n + } + + //2 PacketId //Only Qos1 Qos2 has packet id + if msg.Qos() > Qos0 { + msg.packetId = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + } + + // FixHead & VarHead + msg.head = buf[0:offset] + //plo := offset + + //3 Payload + l := msg.remainLength + headerLen + b := buf[offset:l] + msg.payload = b + + offset += len(b) + + return offset, nil +} + +func (msg *Publish) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, msg.payload, nil + } + + //Remain Length + msg.remainLength = 0 + //Topic + msg.remainLength += 2 + len(msg.topic) + //PacketId + if msg.Qos() > Qos0 { + msg.remainLength += 2 + } + + //FixHead & VarHead + hl := msg.remainLength + + //Payload + msg.remainLength += len(msg.payload) + + //pl := msg.remainLength - hl + hl += 1 + LenLen(msg.remainLength) + + //Alloc buffer + msg.head = make([]byte, hl) + //msg.payload = make([]byte, pl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //1 Topic + if n, err := WriteBytes(msg.head[ho:], msg.topic); err != nil { + return nil, nil, err + } else { + ho += n + } + + //2 PacketId + if msg.Qos() > Qos0 { + binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId) + ho += 2 + } + + //3 Payload + //msg.payload = payload + + return msg.head, msg.payload, nil +} diff --git a/packet/pubrec.go b/packet/pubrec.go new file mode 100644 index 0000000..ae572d2 --- /dev/null +++ b/packet/pubrec.go @@ -0,0 +1,5 @@ +package packet + +type PubRec struct { + PubAck +} diff --git a/packet/pubrel.go b/packet/pubrel.go new file mode 100644 index 0000000..98702e4 --- /dev/null +++ b/packet/pubrel.go @@ -0,0 +1,5 @@ +package packet + +type PubRel struct { + PubAck +} diff --git a/packet/suback.go b/packet/suback.go new file mode 100644 index 0000000..d950c19 --- /dev/null +++ b/packet/suback.go @@ -0,0 +1,127 @@ +package packet + +import ( + "encoding/binary" +) + +type SubCode byte + +const ( + SUB_CODE_QOS0 SubCode = iota + SUB_CODE_QOS1 + SUB_CODE_QOS2 + SUB_CODE_ERR = 128 +) + +type SubAck struct { + Header + + packetId uint16 + + codes []byte +} + +func (msg *SubAck) PacketId() uint16 { + return msg.packetId +} + +func (msg *SubAck) SetPacketId(p uint16) { + msg.dirty = true + msg.packetId = p +} + +func (msg *SubAck) Codes() []byte { + return msg.codes +} + +func (msg *SubAck) AddCode(c SubCode) { + msg.dirty = true + msg.codes = append(msg.codes, byte(c)) +} + +func (msg *SubAck) ClearCode() { + msg.dirty = true + msg.codes = msg.codes[0:0] +} + +func (msg *SubAck) Decode(buf []byte) (int, error) { + msg.dirty = false + + //total := len(buf) + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else { + msg.remainLength = l + offset += n + } + headerLen := offset + + // PacketId + msg.packetId = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + + // FixHead & VarHead + msg.head = buf[0:offset] + //plo := offset + + // Parse Codes + l := msg.remainLength + headerLen + msg.codes = buf[offset:l] + + //Payload + msg.payload = msg.codes + + return offset, nil +} + +func (msg *SubAck) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, msg.payload, nil + } + + //Remain Length + msg.remainLength = 0 + //Packet Id + msg.remainLength += 2 + + //FixHead & VarHead + hl := msg.remainLength + + //Codes + msg.remainLength += len(msg.codes) + + //pl := msg.remainLength - hl + hl += 1 + LenLen(msg.remainLength) + + //Alloc buffer + msg.head = make([]byte, hl) + //msg.payload = make([]byte, pl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //Packet Id + binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId) + ho += 2 + + //Codes + msg.payload = msg.codes + + return msg.head, msg.payload, nil +} diff --git a/packet/subscribe.go b/packet/subscribe.go new file mode 100644 index 0000000..e886411 --- /dev/null +++ b/packet/subscribe.go @@ -0,0 +1,174 @@ +package packet + +import ( + "encoding/binary" + "fmt" +) + +type SubTopic struct { + topic []byte + flag byte +} + +func (msg *SubTopic) Topic() []byte { + return msg.topic +} + +func (msg *SubTopic) SetTopic(b []byte) { + msg.topic = b +} + +func (msg *SubTopic) Qos() MsgQos { + return MsgQos(msg.flag & 0x03) //0000 0011 +} + +func (msg *SubTopic) SetQos(qos MsgQos) { + msg.flag &= 0xFC + msg.flag |= byte(qos) +} + +type Subscribe struct { + Header + + packetId uint16 + + topics []*SubTopic +} + +func (msg *Subscribe) PacketId() uint16 { + return msg.packetId +} + +func (msg *Subscribe) SetPacketId(p uint16) { + msg.dirty = true + msg.packetId = p +} + +func (msg *Subscribe) Topics() []*SubTopic { + return msg.topics +} + +func (msg *Subscribe) AddTopic(topic []byte, qos MsgQos) { + msg.dirty = true + st := &SubTopic{} + st.SetTopic(topic) + st.SetQos(qos) + msg.topics = append(msg.topics, st) +} + +func (msg *Subscribe) ClearTopic() { + msg.dirty = true + msg.topics = msg.topics[0:0] +} + +func (msg *Subscribe) Decode(buf []byte) (int, error) { + msg.dirty = false + + //total := len(buf) + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else { + msg.remainLength = l + offset += n + } + headerLen := offset + + // PacketId + msg.packetId = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + + // FixHead & VarHead + msg.head = buf[0:offset] + plo := offset + + // Parse Topics + for offset-headerLen < msg.remainLength { + st := &SubTopic{} + //Topic + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + st.SetTopic(b) + offset += n + } + //Qos + qos := buf[offset] + if (qos & 0x03) != qos { + return offset, fmt.Errorf("Topic Qos %x", qos) + } + st.SetQos(MsgQos(qos)) + offset++ + msg.topics = append(msg.topics, st) + } + + //Payload + msg.payload = buf[plo:offset] + + return offset, nil +} + +func (msg *Subscribe) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, msg.payload, nil + } + + //Remain Length + msg.remainLength = 0 + //Packet Id + msg.remainLength += 2 + + //FixHead & VarHead + hl := msg.remainLength + + //Topics + for _, t := range msg.topics { + msg.remainLength += 2 + len(t.Topic()) + msg.remainLength += 1 + } + + pl := msg.remainLength - hl + hl += 1 + LenLen(msg.remainLength) + + //Alloc buffer + msg.head = make([]byte, hl) + msg.payload = make([]byte, pl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //Packet Id + binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId) + ho += 2 + + plo := 0 + //Topics + for _, t := range msg.topics { + //Topic + if n, err := WriteBytes(msg.payload[plo:], t.topic); err != nil { + return msg.head, nil, err + } else { + plo += n + } + //Qos + msg.payload[plo] = t.flag + plo++ + } + + return msg.head, msg.payload, nil +} diff --git a/packet/unsuback.go b/packet/unsuback.go new file mode 100644 index 0000000..195d962 --- /dev/null +++ b/packet/unsuback.go @@ -0,0 +1,5 @@ +package packet + +type UnSubAck struct { + PubAck +} diff --git a/packet/unsubscribe.go b/packet/unsubscribe.go new file mode 100644 index 0000000..116e0d1 --- /dev/null +++ b/packet/unsubscribe.go @@ -0,0 +1,135 @@ +package packet + +import ( + "encoding/binary" +) + +type UnSubscribe struct { + Header + + packetId uint16 + + topics [][]byte +} + +func (msg *UnSubscribe) PacketId() uint16 { + return msg.packetId +} + +func (msg *UnSubscribe) SetPacketId(p uint16) { + msg.dirty = true + msg.packetId = p +} + +func (msg *UnSubscribe) Topics() [][]byte { + return msg.topics +} + +func (msg *UnSubscribe) AddTopic(topic []byte) { + msg.dirty = true + msg.topics = append(msg.topics, topic) +} + +func (msg *UnSubscribe) ClearTopic() { + msg.dirty = true + msg.topics = msg.topics[0:0] +} + +func (msg *UnSubscribe) Decode(buf []byte) (int, error) { + msg.dirty = false + + //total := len(buf) + offset := 0 + + //Header + msg.header = buf[0] + offset++ + + //Remain Length + if l, n, err := ReadRemainLength(buf[offset:]); err != nil { + return offset, err + } else { + msg.remainLength = l + offset += n + } + headerLen := offset + + // PacketId + msg.packetId = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 + + // FixHead & VarHead + msg.head = buf[0:offset] + plo := offset + + // Parse Topics + for offset-headerLen < msg.remainLength { + //Topic + if b, n, err := ReadBytes(buf[offset:]); err != nil { + return offset, err + } else { + msg.AddTopic(b) + offset += n + } + } + + //Payload + msg.payload = buf[plo:offset] + + return offset, nil +} + +func (msg *UnSubscribe) Encode() ([]byte, []byte, error) { + if !msg.dirty { + return msg.head, msg.payload, nil + } + + //Remain Length + msg.remainLength = 0 + //Packet Id + msg.remainLength += 2 + + //FixHead & VarHead + hl := msg.remainLength + + //Topics + for _, t := range msg.topics { + msg.remainLength += 2 + len(t) + } + + pl := msg.remainLength - hl + hl += 1 + LenLen(msg.remainLength) + + //Alloc buffer + msg.head = make([]byte, hl) + msg.payload = make([]byte, pl) + + //Header + ho := 0 + msg.head[ho] = msg.header + ho++ + + //Remain Length + if n, err := WriteRemainLength(msg.head[ho:], msg.remainLength); err != nil { + return nil, nil, err + } else { + ho += n + } + + //Packet Id + binary.BigEndian.PutUint16(msg.head[ho:], msg.packetId) + ho += 2 + + plo := 0 + //Topics + for _, t := range msg.topics { + //Topic + if n, err := WriteBytes(msg.payload[plo:], t); err != nil { + return msg.head, nil, err + } else { + plo += n + } + } + + return msg.head, msg.payload, nil +} diff --git a/retaintree.go b/retaintree.go new file mode 100644 index 0000000..d71b117 --- /dev/null +++ b/retaintree.go @@ -0,0 +1,116 @@ +package beeq + +import ( + "github.com/zgwit/beeq/packet" + "strings" +) + +type RetainNode struct { + + //Subscribed retains + //clientId + retains map[string]*packet.Publish + + //Sub level + //topic->children + children map[string]*RetainNode +} + +func NewRetainNode() *RetainNode { + return &RetainNode{ + retains: make(map[string]*packet.Publish), + children: make(map[string]*RetainNode), + } +} + +func (rn *RetainNode) Fetch(topics []string, cb func(clientId string, pub *packet.Publish)) { + if len(topics) == 0 { + // Publish all matched retains + for clientId, pub := range rn.retains { + cb(clientId, pub) + } + } else { + name := topics[0] + + if name == "#" { + //All retains + for clientId, pub := range rn.retains { + cb(clientId, pub) + } + //And all children + for _, sub := range rn.children { + sub.Fetch(topics, cb) + } + } else if name == "+" { + //Children + for _, sub := range rn.children { + sub.Fetch(topics[1:], cb) + } + } else { + // Sub-Level + if sub, ok := rn.children[name]; ok { + sub.Fetch(topics[1:], cb) + } + } + } +} + +func (rn *RetainNode) Retain(topics []string, clientId string, pub *packet.Publish) *RetainNode { + if len(topics) == 0 { + // Publish to specific client + rn.retains[clientId] = pub + return rn + } else { + name := topics[0] + + // Sub-Level + if _, ok := rn.children[name]; !ok { + rn.children[name] = NewRetainNode() + } + return rn.children[name].Retain(topics[1:], clientId, pub) + } +} + +type RetainTree struct { + //root + root *RetainNode + + //tree index + //ClientId -> Node (hold Publish message) + retains map[string]*RetainNode +} + +func NewRetainTree() *RetainTree { + return &RetainTree{ + root: NewRetainNode(), + } +} + +func (rt *RetainTree) Fetch(topic []byte, cb func(clientId string, pub *packet.Publish)) { + topics := strings.Split(string(topic), "/") + if topics[0] == "" { + topics[0] = "/" + } + rt.root.Fetch(topics, cb) +} + +func (rt *RetainTree) Retain(topic []byte, clientId string, pub *packet.Publish) { + // Remove last retain publish, firstly + rt.UnRetain(clientId) + + topics := strings.Split(string(topic), "/") + if topics[0] == "" { + topics[0] = "/" + } + node := rt.root.Retain(topics, clientId, pub) + + //indexed node + rt.retains[clientId] = node +} + +func (rt *RetainTree) UnRetain(clientId string) { + if node, ok := rt.retains[clientId]; ok { + delete(node.retains, clientId) + delete(rt.retains, clientId) + } +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..057cf9c --- /dev/null +++ b/session.go @@ -0,0 +1,63 @@ +package beeq + +import ( + "github.com/zgwit/beeq/packet" + "time" +) + +type Session struct { + //Bee + bee *Bee + + //Client ID (from CONNECT) + clientId string + + //Keep Alive (from CONNECT) + keepAlive int + + //will topic (from CONNECT) + will *packet.Publish + + //Qos1 Qos2 + pub1 map[uint16]*packet.Publish + pub2 map[uint16]*packet.Publish + + //Received Qos2 Publish + recvPub2 map[uint16]*packet.Publish + + //Increment 0~65535 + packetId uint16 + + alive bool + + create_time time.Time + active_time time.Time + lost_time time.Time +} + +func NewSession() *Session { + return &Session{ + pub1: make(map[uint16]*packet.Publish), + pub2: make(map[uint16]*packet.Publish), + recvPub2: make(map[uint16]*packet.Publish), + alive: true, + create_time: time.Now(), + active_time: time.Now(), + } +} + +func (session *Session) Alive() bool { + return session.alive +} + +func (session *Session) Active(bee *Bee) { + session.alive = true + session.bee = bee + session.active_time = time.Now() +} + +func (session *Session) DeActive() { + session.alive = false + session.bee = nil + session.lost_time = time.Now() +} diff --git a/socket.go b/socket.go new file mode 100644 index 0000000..726b400 --- /dev/null +++ b/socket.go @@ -0,0 +1,65 @@ +package beeq + +import ( + "crypto/tls" + "golang.org/x/net/websocket" + "net" + "net/http" + "time" +) + + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + + +func AcceptSocket(ln net.Listener, cb func(conn net.Conn)) { + go (func() { + //defer ln.Close() + for { + if conn, err := ln.Accept(); err != nil { + //TODO log + break + } else { + cb(conn) + } + } + })() +} + +func AcceptWebSocket(ln net.Listener, pattern string, cb func(conn net.Conn)) { + h := func(ws *websocket.Conn) { + conn := NewWSConn(ws) + cb(conn) + //block here + conn.Wait() + } + mux := http.NewServeMux() + mux.Handle(pattern, websocket.Handler(h)) + svr := http.Server{Handler:mux} + go svr.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) +} + +func ListenSocket(laddr string) (net.Listener, error) { + return net.Listen("tcp", laddr) +} + +func ListenSocketTLS(laddr string, cert string, key string) (net.Listener, error) { + if ce, err := tls.LoadX509KeyPair(cert, key); err != nil { + return nil, err + } else { + config := &tls.Config{Certificates: []tls.Certificate{ce}} + return tls.Listen("tcp", laddr, config) + } +} diff --git a/subtree.go b/subtree.go new file mode 100644 index 0000000..fa637cd --- /dev/null +++ b/subtree.go @@ -0,0 +1,163 @@ +package beeq + +import ( + "github.com/zgwit/beeq/packet" + "strings" +) + +type SubNode struct { + //Subscribed clients + //clientId + clients map[string]packet.MsgQos + + //Sub level + //topic->children + children map[string]*SubNode + + //Multi Wildcard # + mw *SubNode + + //Single Wildcard + + sw *SubNode +} + +func NewSubNode() *SubNode { + return &SubNode{ + clients: make(map[string]packet.MsgQos), + children: make(map[string]*SubNode), + //mw: NewSubNode(), + //sw: NewSubNode(), + } +} + +func (sn *SubNode) Publish(topics []string, subs map[string]packet.MsgQos) { + if len(topics) == 0 { + // Publish all matched clients + for clientId, qos := range sn.clients { + if sub, ok := subs[clientId]; ok { + //rewrite by larger Qos + if sub < qos { + subs[clientId] = qos + } + } else { + subs[clientId] = qos + } + } + } else { + name := topics[0] + // Sub-Level + if sub, ok := sn.children[name]; ok { + sub.Publish(topics[1:], subs) + } + // Multi wildcard + if sn.mw != nil { + sn.mw.Publish(topics[1:1], subs) + } + // Single wildcard + if sn.sw != nil { + sn.sw.Publish(topics[1:], subs) + } + } +} + +func (sn *SubNode) Subscribe(topics []string, clientId string, qos packet.MsgQos) { + if len(topics) == 0 { + sn.clients[clientId] = qos + return + } + + name := topics[0] + if name == "#" { + if sn.mw == nil { + sn.mw = NewSubNode() + } + sn.mw.Subscribe(topics[1:1], clientId, qos) + } else if name == "+" { + if sn.sw == nil { + sn.sw = NewSubNode() + } + sn.sw.Subscribe(topics[1:], clientId, qos) + } else { + if _, ok := sn.children[name]; !ok { + sn.children[name] = NewSubNode() + } + sn.children[name].Subscribe(topics[1:], clientId, qos) + } +} + +func (sn *SubNode) UnSubscribe(topics []string, clientId string) { + if len(topics) == 0 { + delete(sn.clients, clientId) + } else { + name := topics[0] + if name == "#" { + if sn.mw != nil { + sn.mw.UnSubscribe(topics[1:1], clientId) + } + } else if name == "+" { + if sn.sw != nil { + sn.sw.UnSubscribe(topics[1:], clientId) + } + } else { + if sub, ok := sn.children[name]; ok { + sub.UnSubscribe(topics[1:], clientId) + } + } + } +} + +func (sn *SubNode) ClearClient(clientId string) { + if _, ok := sn.clients[clientId]; ok { + delete(sn.clients, clientId) + } + + if sn.mw != nil { + sn.mw.ClearClient(clientId) + } + if sn.sw != nil { + sn.sw.ClearClient(clientId) + } + + for _, sub := range sn.children { + sub.ClearClient(clientId) + } +} + +type SubTree struct { + //tree root + root *SubNode +} + +func NewSubTree() *SubTree { + return &SubTree{ + root: NewSubNode(), + } +} + +func (st *SubTree) Publish(topic []byte, subs map[string]packet.MsgQos) { + topics := strings.Split(string(topic), "/") + if topics[0] == "" { + topics[0] = "/" + } + st.root.Publish(topics, subs) +} + +func (st *SubTree) Subscribe(topic []byte, clientId string, qos packet.MsgQos) { + topics := strings.Split(string(topic), "/") + if topics[0] == "" { + topics[0] = "/" + } + st.root.Subscribe(topics, clientId, qos) +} + +func (st *SubTree) UnSubscribe(topic []byte, clientId string) { + topics := strings.Split(string(topic), "/") + if topics[0] == "" { + topics[0] = "/" + } + st.root.UnSubscribe(topics, clientId) +} + +func (st *SubTree) ClearClient(clientId string) { + st.root.ClearClient(clientId) +} diff --git a/topic.go b/topic.go new file mode 100644 index 0000000..a92bdd4 --- /dev/null +++ b/topic.go @@ -0,0 +1,70 @@ +package beeq + +import ( + "bytes" + "errors" + "fmt" +) + +func ValidTopic(topic []byte) error { + //no + # + if bytes.ContainsAny(topic, "+#") { + return fmt.Errorf("+ # is not valid (%s)", string(topic)) + } + //no // + if bytes.Contains(topic, []byte("//")) { + return fmt.Errorf("// is not valid (%s)", string(topic)) + } + return nil +} + +func ValidSubscribe(topic []byte) error { + if len(topic) == 0 { + return errors.New("Blank topic") + } + topics := bytes.Split(topic, []byte("/")) + if len(topics[0]) == 0 { + topics[0] = []byte("/") + } + for i, tt := range topics { + l := len(tt) + if l == 0 { + return errors.New("inner blank") + } + if l == 1 && tt[0] == '#' && i < len(topics)-1 { + return errors.New("# must be the last one") + } + if l > 1 && bytes.ContainsAny(tt, "+#") { + return errors.New("+ # is alone") + } + } + return nil +} + +func MatchSubscribe(topic []byte, sub []byte) bool { + i, j := 0, 0 + for i < len(topic) && j < len(sub) { + t, s := topic[i], sub[j] + if s == '#' { + return true + } else if s == '+' { + for t != '/' { + i++ + t = topic[i] + } + j++ // skip / + } else if s != t { + break + } + // else s==t + i++ + j++ + } + + //Just match + if i == len(topic) && j == len(sub) { + return true + } + + return false +} diff --git a/wsconn.go b/wsconn.go new file mode 100644 index 0000000..66749ef --- /dev/null +++ b/wsconn.go @@ -0,0 +1,76 @@ +package beeq + +import ( + "golang.org/x/net/websocket" + "net" + "time" +) + +//inherit net.Conn + +type WSConn struct { + ws *websocket.Conn + buf []byte + quit chan struct{} +} + +func NewWSConn(ws *websocket.Conn) *WSConn { + return &WSConn{ + ws: ws, + quit: make(chan struct{}), + } +} + +func (conn *WSConn) Read(b []byte) (n int, err error) { + if conn.buf == nil { + if err := websocket.Message.Receive(conn.ws, &conn.buf); err != nil { + return 0, err + } + } + //if rn.buf + l := copy(b, conn.buf) + if l < len(conn.buf) { + conn.buf = conn.buf[l:] + } else { + // Receive next packet + conn.buf = nil + } + return l, nil +} + +func (conn *WSConn) Write(b []byte) (n int, err error) { + if err := websocket.Message.Send(conn.ws, b); err != nil { + return 0, err + } + return len(b), nil +} + +func (conn *WSConn) Close() error { + close(conn.quit) + return conn.ws.Close() +} + +func (conn *WSConn) LocalAddr() net.Addr { + return conn.ws.LocalAddr() +} + +func (conn *WSConn) RemoteAddr() net.Addr { + return conn.ws.RemoteAddr() +} + +func (conn *WSConn) SetDeadline(t time.Time) error { + return conn.ws.SetDeadline(t) +} + +func (conn *WSConn) SetReadDeadline(t time.Time) error { + return conn.ws.SetReadDeadline(t) +} + +func (conn *WSConn) SetWriteDeadline(t time.Time) error { + return conn.ws.SetWriteDeadline(t) +} + +func (conn *WSConn) Wait() { + <- conn.quit +} +