diff --git a/engine/engine.go b/engine/engine.go index a0d439f..49b4d8f 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,10 +2,10 @@ package engine import ( "context" - "github.com/YueLWish/mqtt-bridge/pkg/xmqtt" mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/panjf2000/ants" "github.com/pkg/errors" + "github.com/yuelwish/mqtt-bridge/pkg/xmqtt" "log" ) @@ -82,9 +82,9 @@ func (e *Engine) handlerMessage(ctx context.Context) { if err = gPool.Submit(func() { err := xmqtt.Send(client, msg.Topic, msg.Payload) if err != nil { - log.Printf("[send message] -- %s ==> %v t:%s failed: %v", msg.FromTag, tTag, msg.Topic, err) + log.Printf("[send message] %s ==> %v t:%s failed: %v", msg.FromTag, tTag, msg.Topic, err) } else { - log.Printf("[send message] -- %s ==> %v t:%s p:%s", msg.FromTag, tTag, msg.Topic, msg.Payload) + log.Printf("[send message] %s ==> %v t:%s p:%s", msg.FromTag, tTag, msg.Topic, msg.Payload) } }); err != nil { log.Printf("[submit message] failed: %v", err) diff --git a/main.go b/main.go index 176e837..7ce025a 100644 --- a/main.go +++ b/main.go @@ -3,9 +3,9 @@ package main import ( "context" "flag" - "github.com/YueLWish/mqtt-bridge/engine" - "github.com/YueLWish/mqtt-bridge/pkg/setting" "github.com/pkg/errors" + "github.com/yuelwish/mqtt-bridge/engine" + "github.com/yuelwish/mqtt-bridge/pkg/setting" "log" "os" "os/signal" diff --git a/pkg/xmqtt/client.go b/pkg/xmqtt/client.go new file mode 100644 index 0000000..7f87210 --- /dev/null +++ b/pkg/xmqtt/client.go @@ -0,0 +1,124 @@ +package xmqtt + +import ( + mqtt "github.com/eclipse/paho.mqtt.golang" + "sync" + "sync/atomic" +) + +type subscribed struct { + topic string + qos byte + callback mqtt.MessageHandler +} + +type Client struct { + mqtt.Client + subMap sync.Map + status int32 // 1 连接成功 2 断开连接 +} + +const ( + stConned = iota + 1 + stLost +) + +func NewClient(o *mqtt.ClientOptions) mqtt.Client { + var ( + onLostFn mqtt.ConnectionLostHandler + onConnFn mqtt.OnConnectHandler + ) + if o.OnConnectionLost != nil { + onLostFn = o.OnConnectionLost + } + if o.OnConnect != nil { + onConnFn = o.OnConnect + } + + mClient := &Client{} + o.SetConnectionLostHandler(func(client mqtt.Client, err error) { + atomic.StoreInt32(&mClient.status, stLost) + if onLostFn != nil { + onLostFn(client, err) + } + }) + + o.SetOnConnectHandler(func(client mqtt.Client) { + defer atomic.StoreInt32(&mClient.status, stConned) + if onConnFn != nil { + onConnFn(client) + } + + if atomic.LoadInt32(&mClient.status) == stLost { + var n int + mClient.subMap.Range(func(_, value interface{}) bool { + sub := value.(*subscribed) + mClient.Client.Subscribe(sub.topic, sub.qos, sub.callback) + n++ + return true + }) + mqtt.DEBUG.Printf("[CONN resubscribe %d topic]", n) + } + }) + mClient.Client = mqtt.NewClient(o) + + return mClient +} + +func (c *Client) SubscribeMultiple(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + token := c.Client.SubscribeMultiple(filters, callback) + + if token.Error() == nil { + for topic, qos := range filters { + sub := subscribed{ + topic: topic, + qos: qos, + callback: callback, + } + c.subMap.Store(sub.topic, &sub) + } + } + return token +} +func (c *Client) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token { + token := c.Client.Subscribe(topic, qos, callback) + + if token.Error() == nil { + sub := subscribed{ + topic: topic, + qos: qos, + callback: callback, + } + c.subMap.Store(sub.topic, &sub) + } + return token +} + +func (c *Client) Unsubscribe(topics ...string) mqtt.Token { + token := c.Client.Unsubscribe(topics...) + + if token.Error() == nil { + for _, topic := range topics { + c.subMap.Delete(topic) + } + } + return token +} + +func (c *Client) AddRoute(topic string, callback mqtt.MessageHandler) { + c.Client.AddRoute(topic, callback) + + c.subMap.Store(topic, &subscribed{ + topic: topic, + qos: 0, + callback: callback, + }) +} +func (c *Client) Disconnect(quiesce uint) { + c.Client.Disconnect(quiesce) + + c.subMap.Range(func(key, _ interface{}) bool { + c.subMap.Delete(key) + return true + }) +} diff --git a/pkg/xmqtt/mqtt.go b/pkg/xmqtt/mqtt.go index 5d3c336..a8db0a4 100644 --- a/pkg/xmqtt/mqtt.go +++ b/pkg/xmqtt/mqtt.go @@ -17,7 +17,8 @@ func Init(clientIdPrefix, addr string, optFns ...func(opt *mqtt.ClientOptions)) opts := mqtt.NewClientOptions() opts.AddBroker(addr) opts.SetClientID(clientIdPrefix + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)) - opts.SetKeepAlive(time.Second * time.Duration(10)) + opts.SetKeepAlive(10 * time.Second) + opts.SetMaxReconnectInterval(10 * time.Second) opts.SetOnConnectHandler(func(client mqtt.Client) { r := client.OptionsReader() @@ -51,12 +52,12 @@ func Init(clientIdPrefix, addr string, optFns ...func(opt *mqtt.ClientOptions)) fn(opts) } - client := mqtt.NewClient(opts) + client := NewClient(opts) - token := client.Connect() - if token.WaitTimeout(timeoutDuration) && token.Error() != nil { + if token := client.Connect(); token.WaitTimeout(timeoutDuration) && token.Error() != nil { return nil, token.Error() } + return client, nil }