Files
beeq/hive.go

405 lines
8.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package beeq
import (
"encoding/binary"
uuid "github.com/google/uuid"
"git.zgwit.com/iot/beeq/packet"
"log"
"net"
"sync"
"time"
)
// reAlloc returns a newly allocated slice of length l whose leading bytes are
// copied from buf. When l is smaller than len(buf) the copy is truncated to l
// bytes; when it is larger the tail is zero-filled. The result never shares
// memory with buf.
func reAlloc(buf []byte, l int) []byte {
	resized := make([]byte, l)
	copy(resized, buf)
	return resized
}
// Hive is the broker core. It owns the subscription tree, the retained-message
// tree, the table of known clients and the optional hook callbacks that the
// packet handlers consult.
type Hive struct {
// subTree records subscriptions; handlePublish queries it to collect the
// subscribers (client id -> QoS) of a topic.
subTree SubTree
// retainTree stores retained PUBLISH messages; handleSubscribe fetches from
// it when a new subscription is added.
retainTree RetainTree
// bees maps a client id to its *Bee.
bees sync.Map // map[string]*Bee
// onConnect, when set, is asked to accept (true) or reject (false) a CONNECT.
onConnect func(*packet.Connect, *Bee) bool
// onPublish, when set, is asked to accept (true) or drop (false) a PUBLISH.
onPublish func(*packet.Publish, *Bee) bool
// onSubscribe, when set, is asked to accept (true) or reject (false) a SUBSCRIBE.
onSubscribe func(*packet.Subscribe, *Bee) bool
// onUnSubscribe, when set, is notified of every UNSUBSCRIBE.
onUnSubscribe func(*packet.UnSubscribe, *Bee)
// onDisconnect, when set, is notified of every DISCONNECT.
onDisconnect func(*packet.DisConnect, *Bee)
}
// NewHive returns a Hive in its zero state: no subscriptions, no retained
// messages, no clients and no hooks installed.
//
// TODO: accept configuration parameters.
func NewHive() *Hive {
	return new(Hive)
}
// ListenAndServe opens a TCP listener on addr and runs the accept loop (Serve)
// in a background goroutine. It does not block: the listen error is returned
// if the listener could not be created, otherwise nil is returned right away.
func (h *Hive) ListenAndServe(addr string) error {
	listener, err := net.Listen("tcp", addr)
	if err == nil {
		go h.Serve(listener)
	}
	return err
}
// Serve accepts connections from ln and hands each one to Receive on its own
// goroutine. The loop ends, after logging the error, the first time Accept
// fails; ln itself is not closed here.
func (h *Hive) Serve(ln net.Listener) {
	for {
		client, acceptErr := ln.Accept()
		if acceptErr != nil {
			log.Println(acceptErr)
			return
		}
		// One goroutine per connection.
		go h.Receive(client)
	}
}
// Receive serves a single client connection. It wraps conn in a Bee, reads
// raw bytes in 1 KiB chunks, lets packet.Parser split them into messages and
// passes every message to handle. The Bee is closed when the read loop ends.
//
// TODO: parse the first packet up front; it must be a CONNECT.
func (h *Hive) Receive(conn net.Conn) {
	bee := NewBee(conn)
	// Best-effort close: there is nobody left to report a close error to.
	defer func() { _ = bee.Close() }()

	var parser packet.Parser
	buf := make([]byte, 1024)
	for {
		n, err := conn.Read(buf)
		// A Read may return data together with an error; the io.Reader
		// contract asks callers to consume the n bytes before looking at err,
		// otherwise the final bytes of a connection can be lost.
		if n > 0 {
			// TODO: messages could be queued instead of handled inline.
			for _, msg := range parser.Parse(buf[:n]) {
				h.handle(msg, bee)
			}
		}
		if err != nil {
			log.Println(err)
			return
		}
	}
}
// Receive2 is an alternative to Receive that frames MQTT packets itself
// instead of delegating to packet.Parser. It accumulates bytes from conn,
// decodes every complete packet found in the buffer with packet.Decode and
// passes it to handle. The Bee is closed when the loop ends.
//
// Framing: byte 0 is the fixed-header type/flags byte, followed by the
// Remaining Length as a base-128 varint of one to four bytes, followed by
// that many bytes of variable header and payload.
//
// TODO: parse the first packet up front; it must be a CONNECT.
func (h *Hive) Receive2(conn net.Conn) {
	bee := NewBee(conn)
	// Best-effort close: there is nobody left to report a close error to.
	defer func() { _ = bee.Close() }()

	const (
		initialSize = 6 // starting buffer size
		maxLenBytes = 4 // MQTT caps the Remaining Length field at four bytes
	)
	buf := make([]byte, initialSize)
	filled := 0 // buf[:filled] holds bytes received but not yet decoded

	for {
		// buf[filled:] is never empty here: the inner loop below either
		// leaves fewer than five pending bytes or grows buf past filled.
		n, readErr := conn.Read(buf[filled:])
		filled += n

		// Decode everything that is already complete before reading again,
		// so several packets arriving in one read are all handled.
		for filled >= 2 {
			rl, rll := binary.Uvarint(buf[1:filled])
			if rll < 0 || rll > maxLenBytes || (rll == 0 && filled-1 >= maxLenBytes) {
				log.Println("malformed remaining length")
				return
			}
			if rll == 0 {
				// The length prefix itself is still incomplete.
				break
			}

			packLen := 1 + rll + int(rl)
			if packLen > filled {
				// Packet body not fully received yet; make room for it.
				if packLen > len(buf) {
					buf = reAlloc(buf, packLen)
				}
				break
			}

			// Decode from a private copy: buf is reused for the following
			// packets and the decoded message may keep referencing its input
			// (presumably it does — verify against packet.Decode).
			msg, err := packet.Decode(reAlloc(buf[:packLen], packLen))
			if err != nil {
				log.Println(err)
				return
			}
			// TODO: messages could be queued instead of handled inline.
			h.handle(msg, bee)

			// Move any bytes that follow this packet to the front.
			filled = copy(buf, buf[packLen:filled])
		}

		if readErr != nil {
			log.Println(readErr)
			return
		}
	}
}
// handle routes one decoded message from bee to the matching handler, keyed
// on msg.Type(). Message types without a case are ignored.
func (h *Hive) handle(msg packet.Message, bee *Bee) {
	switch msg.Type() {
	// Session lifecycle.
	case packet.CONNECT:
		connect := msg.(*packet.Connect)
		h.handleConnect(connect, bee)
	case packet.DISCONNECT:
		disconnect := msg.(*packet.DisConnect)
		h.handleDisconnect(disconnect, bee)

	// Publishing and subscriptions.
	case packet.PUBLISH:
		publish := msg.(*packet.Publish)
		h.handlePublish(publish, bee)
	case packet.SUBSCRIBE:
		subscribe := msg.(*packet.Subscribe)
		h.handleSubscribe(subscribe, bee)
	case packet.UNSUBSCRIBE:
		unsubscribe := msg.(*packet.UnSubscribe)
		h.handleUnSubscribe(unsubscribe, bee)

	// Acknowledgements that end a flow: drop the pending entry.
	case packet.PUBACK:
		acked := msg.(*packet.PubAck)
		bee.pub1.Delete(acked.PacketId())
	case packet.PUBCOMP:
		completed := msg.(*packet.PubComp)
		bee.pub2.Delete(completed.PacketId())

	// Packets answered by rewriting the type in place and sending the same
	// message back to the client.
	case packet.PUBREC:
		msg.SetType(packet.PUBREL)
		bee.dispatch(msg)
	case packet.PUBREL:
		msg.SetType(packet.PUBCOMP)
		bee.dispatch(msg)
	case packet.PINGREQ:
		msg.SetType(packet.PINGRESP)
		bee.dispatch(msg)
	}
}
// handleConnect processes a CONNECT packet from bee and answers with a CONNACK.
//
// Steps, in order:
//  1. If an onConnect hook is installed and returns false, reply
//     CONNACK_INVALID_USERNAME_PASSWORD and close the connection.
//  2. Resolve the client id. An empty id is only allowed together with
//     CleanSession, in which case a random UUID is generated; otherwise the
//     connection is closed without a reply. A non-empty id that belongs to a
//     Bee that is still open is rejected with CONNACK_UNAVAILABLE.
//  3. When the id belongs to a closed Bee and CleanSession is not set, the
//     previous session state is carried over to the new Bee.
//  4. Register the Bee under its id, derive the read timeout from the
//     keep-alive (1.5x) and reply CONNACK_ACCEPTED.
//
// The Bee is registered for generated ids as well: handlePublish looks
// subscribers up in h.bees by client id, so an unregistered client could
// subscribe but would never be delivered a message.
func (h *Hive) handleConnect(msg *packet.Connect, bee *Bee) {
	ack := packet.CONNACK.NewMessage().(*packet.Connack)

	// Let the hook validate the credentials.
	if h.onConnect != nil && !h.onConnect(msg, bee) {
		ack.SetCode(packet.CONNACK_INVALID_USERNAME_PASSWORD)
		bee.dispatch(ack)
		// Drop the connection.
		_ = bee.Close()
		return
	}

	var clientId string
	if len(msg.ClientId()) == 0 {
		if !msg.CleanSession() {
			// TODO: an empty client id requires a clean session; report
			// the error to the client before closing.
			_ = bee.Close()
			return
		}
		// Generate a unique client id (random UUID); no collision check is
		// performed against h.bees.
		clientId = uuid.New().String()
	} else {
		clientId = string(msg.ClientId())
		if v, ok := h.bees.Load(clientId); ok {
			b := v.(*Bee)
			// The client id is already in use by a live connection.
			if !b.closed {
				ack.SetCode(packet.CONNACK_UNAVAILABLE)
				bee.dispatch(ack)
				_ = bee.Close()
				return
			}
			if !msg.CleanSession() {
				// TODO: copy the previous session's contents properly.
				bee.keepAlive = b.keepAlive
				bee.will = b.will
				// NOTE(review): a sync.Map must not be copied by value;
				// confirm the field types and share or rebuild them instead.
				bee.pub1 = b.pub1
				bee.pub2 = b.pub2
				bee.recvPub2 = b.recvPub2
				bee.packetId = b.packetId
				//ack.SetSessionPresent(true)
			}
		}
	}

	// Set the id before publishing the Bee so that anyone who finds it in
	// h.bees sees a fully identified client.
	bee.clientId = clientId
	h.bees.Store(clientId, bee)

	if msg.KeepAlive() > 0 {
		// Allow one and a half keep-alive periods of silence.
		bee.timeout = time.Second * time.Duration(msg.KeepAlive()) * 3 / 2
	}

	// TODO: disconnect the client if an error occurred.
	ack.SetCode(packet.CONNACK_ACCEPTED)
	bee.dispatch(ack)
}
// handlePublish processes a PUBLISH packet received from bee.
//
// The onPublish hook, if installed, may veto the message. Otherwise the
// packet is acknowledged according to its QoS (nothing for QoS 0, PUBACK for
// QoS 1, store + PUBREC for QoS 2), the topic is validated, the retained
// store is updated when the retain flag is set, and a copy of the message is
// dispatched to every open subscriber with the QoS capped at the
// subscription's QoS.
func (h *Hive) handlePublish(msg *packet.Publish, bee *Bee) {
	// External validation.
	if h.onPublish != nil && !h.onPublish(msg, bee) {
		return
	}

	switch msg.Qos() {
	case packet.Qos0:
		// No acknowledgement for QoS 0.
	case packet.Qos1:
		// Reply with PUBACK.
		puback := packet.PUBACK.NewMessage().(*packet.PubAck)
		puback.SetPacketId(msg.PacketId())
		bee.dispatch(puback)
	case packet.Qos2:
		// Remember the message, then reply with PUBREC.
		bee.recvPub2.Store(msg.PacketId(), msg)
		pubrec := packet.PUBREC.NewMessage().(*packet.PubRec)
		pubrec.SetPacketId(msg.PacketId())
		bee.dispatch(pubrec)
	default:
		// TODO: report the invalid QoS.
	}

	topic := msg.Topic()
	if err := ValidTopic(topic); err != nil {
		// TODO: proper logging.
		log.Println("Topic invalid ", err)
		return
	}

	if msg.Retain() {
		if len(msg.Payload()) != 0 {
			h.retainTree.Retain(topic, bee.clientId, msg)
		} else {
			// NOTE(review): UnRetain is given only the client id while Retain
			// is keyed by topic as well — confirm this clears the right entry.
			h.retainTree.UnRetain(bee.clientId)
		}
	}

	// Collect the subscribers of the topic: client id -> subscription QoS.
	subscribers := make(map[string]packet.MsgQos)
	h.subTree.Publish(topic, subscribers)

	// Fan the message out.
	for id, granted := range subscribers {
		v, known := h.bees.Load(id)
		if !known {
			continue
		}
		target := v.(*Bee)
		if target.closed {
			continue
		}
		// Each subscriber gets its own copy with the retain flag cleared.
		pub := *msg
		pub.SetRetain(false)
		if msg.Qos() > granted {
			pub.SetQos(granted)
		}
		target.dispatch(&pub)
	}
}
// handleSubscribe processes a SUBSCRIBE packet from bee and answers with a
// SUBACK carrying the same packet id and one return code per requested topic.
//
// If the onSubscribe hook is installed and returns false, every topic is
// answered with SUB_CODE_ERR. Otherwise each topic filter is validated; valid
// ones are added to the subscription tree and answered with their requested
// QoS, and any retained messages matching the filter are dispatched to bee
// with the retain flag set and the QoS capped at the subscription's QoS.
// Invalid filters are answered with SUB_CODE_ERR.
func (h *Hive) handleSubscribe(msg *packet.Subscribe, bee *Bee) {
	ack := packet.SUBACK.NewMessage().(*packet.SubAck)
	ack.SetPacketId(msg.PacketId())

	// External validation.
	if h.onSubscribe != nil && !h.onSubscribe(msg, bee) {
		// A SUBACK carries one return code per topic filter of the request,
		// so the failure is reported for each of them.
		for range msg.Topics() {
			ack.AddCode(packet.SUB_CODE_ERR)
		}
		bee.dispatch(ack)
		return
	}

	for _, st := range msg.Topics() {
		if err := ValidSubscribe(st.Topic()); err != nil {
			log.Println("Invalid topic ", err)
			ack.AddCode(packet.SUB_CODE_ERR)
			continue
		}

		h.subTree.Subscribe(st.Topic(), bee.clientId, st.Qos())
		ack.AddCode(packet.SubCode(st.Qos()))

		h.retainTree.Fetch(st.Topic(), func(clientId string, pub *packet.Publish) {
			// Send a copy so the stored retained message stays untouched.
			p := *pub
			p.SetRetain(true)
			// Cap at the subscription's QoS. The comparison uses the retained
			// message's own QoS, as handlePublish does for live messages, not
			// the QoS bits of the SUBSCRIBE packet.
			if pub.Qos() > st.Qos() {
				p.SetQos(st.Qos())
			}
			bee.dispatch(&p)
		})
	}

	bee.dispatch(ack)
}
// handleUnSubscribe processes an UNSUBSCRIBE packet from bee. The
// onUnSubscribe hook, if installed, is notified first. Every valid topic
// filter is then removed from the subscription tree (invalid ones are only
// logged) and an UNSUBACK is sent back.
func (h *Hive) handleUnSubscribe(msg *packet.UnSubscribe, bee *Bee) {
	// External notification.
	if h.onUnSubscribe != nil {
		h.onUnSubscribe(msg, bee)
	}

	ack := packet.UNSUBACK.NewMessage().(*packet.UnSubAck)
	// The acknowledgement must echo the request's packet id so the client can
	// match it, just as handleSubscribe does for SUBACK.
	// NOTE(review): assumes UnSubscribe.PacketId and UnSubAck.SetPacketId
	// exist like their Subscribe/SubAck counterparts — confirm.
	ack.SetPacketId(msg.PacketId())

	for _, t := range msg.Topics() {
		if err := ValidSubscribe(t); err != nil {
			// TODO: proper logging.
			log.Println(err)
			continue
		}
		h.subTree.UnSubscribe(t, bee.clientId)
	}

	bee.dispatch(ack)
}
// handleDisconnect processes a DISCONNECT packet from bee: the onDisconnect
// hook, if installed, is notified, the Bee is removed from the client table
// and its connection is closed.
func (h *Hive) handleDisconnect(msg *packet.DisConnect, bee *Bee) {
	if notify := h.onDisconnect; notify != nil {
		notify(msg, bee)
	}

	h.bees.Delete(bee.clientId)
	// Best-effort close; the client is going away regardless.
	_ = bee.Close()
}
// OnConnect installs fn as the connect hook. handleConnect calls it for every
// CONNECT packet; when fn returns false the client is answered with
// CONNACK_INVALID_USERNAME_PASSWORD and disconnected.
func (h *Hive) OnConnect(fn func(*packet.Connect, *Bee) bool) {
h.onConnect = fn
}
// OnPublish installs fn as the publish hook. handlePublish calls it for every
// PUBLISH packet; when fn returns false the message is dropped without any
// acknowledgement or delivery.
func (h *Hive) OnPublish(fn func(*packet.Publish, *Bee) bool) {
h.onPublish = fn
}
// OnSubscribe installs fn as the subscribe hook. handleSubscribe calls it for
// every SUBSCRIBE packet; when fn returns false the request is answered with
// an error return code and no subscription is added.
func (h *Hive) OnSubscribe(fn func(*packet.Subscribe, *Bee) bool) {
h.onSubscribe = fn
}
// OnUnSubscribe installs fn as the unsubscribe hook. handleUnSubscribe calls
// it for every UNSUBSCRIBE packet before the subscriptions are removed; fn
// has no return value and cannot veto the request.
func (h *Hive) OnUnSubscribe(fn func(*packet.UnSubscribe, *Bee)) {
h.onUnSubscribe = fn
}
// OnDisconnect installs fn as the disconnect hook. handleDisconnect calls it
// for every DISCONNECT packet before the client is removed and its connection
// closed.
func (h *Hive) OnDisconnect(fn func(*packet.DisConnect, *Bee)) {
h.onDisconnect = fn
}