From d5e2312aa71878c29b045d1dfc7edf75ae3d0293 Mon Sep 17 00:00:00 2001 From: Mochi Date: Tue, 8 Oct 2019 22:40:55 +0100 Subject: [PATCH] Cleanup --- mqtt.go | 107 ++++++++++++++++++++++++++++++++++++++------------- mqtt_test.go | 1 + 2 files changed, 82 insertions(+), 26 deletions(-) diff --git a/mqtt.go b/mqtt.go index eef520b..c6dc735 100644 --- a/mqtt.go +++ b/mqtt.go @@ -6,6 +6,7 @@ import ( "errors" "log" "net" + "runtime" "sync" "time" @@ -57,8 +58,8 @@ type Server struct { // inbound is a small worker pool which processes incoming packets. inbound chan transitMsg - // outbount is a small worker pool which processes incoming packets. - outbount chan transitMsg + // outbound is a small worker pool which processes incoming packets. + outbound chan transitMsg //inboundPool is a waitgroup for the inbound workers. inboundPool sync.WaitGroup @@ -81,26 +82,61 @@ func New() *Server { clients: newClients(), topics: trie.New(), inbound: make(chan transitMsg), + outbound: make(chan transitMsg), } } func (s *Server) StartProcessing() { - var workers = 8 + var workers = runtime.NumCPU() s.inboundPool = sync.WaitGroup{} - s.inboundPool.Add(workers) + s.inboundPool.Add(workers) for i := 0; i < workers; i++ { log.Println("spawning worker", i) go func(wid int) { defer s.inboundPool.Done() for { select { - case p, ok := <-s.inbound: + case v, ok := <-s.inbound: if !ok { - log.Println("worker closed", wid) + log.Println("worker inbound closed", wid) return } - s.processPacket(p.client, p.packet) + s.processPacket(v.client, v.packet) + case v, ok := <-s.outbound: + if !ok { + log.Println("worker outbound closed", wid) + return + } + + err := s.writeClient(v.client, v.packet) + if err != nil { + log.Println("outbound closing client", v.client.id) + s.closeClient(v.client, true) + } + } + } + }(i) + } + + s.inboundPool.Add(workers) + for i := 0; i < workers; i++ { + log.Println("spawning worker", i) + go func(wid int) { + defer s.inboundPool.Done() + for { + select { + case v, ok := <-s.outbound: + if !ok { + log.Println("worker outbound closed", wid) + return + } + + err := s.writeClient(v.client, v.packet) + if err != nil { + log.Println("outbound closing client", v.client.id) + s.closeClient(v.client, true) + } } } }(i) @@ -449,11 +485,14 @@ func (s *Server) processPublish(cl *client, pk *packets.PublishPacket) error { } // Write the publish packet out to the receiving client. - err := s.writeClient(client, out) - if err != nil { - s.closeClient(client, true) - return err - } + s.outbound <- transitMsg{client: client, packet: out} + /* + err := s.writeClient(client, out) + if err != nil { + s.closeClient(client, true) + return err + } + */ } } @@ -533,27 +572,42 @@ func (s *Server) processSubscribe(cl *client, pk *packets.SubscribePacket) error } } - err := s.writeClient(cl, &packets.SubackPacket{ - FixedHeader: packets.FixedHeader{ - Type: packets.Suback, + s.outbound <- transitMsg{ + client: cl, + packet: &packets.SubackPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Suback, + }, + PacketID: pk.PacketID, + ReturnCodes: retCodes, }, - PacketID: pk.PacketID, - ReturnCodes: retCodes, - }) - if err != nil { - s.closeClient(cl, true) - return err } + /* + err := s.writeClient(cl, &packets.SubackPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Suback, + }, + PacketID: pk.PacketID, + ReturnCodes: retCodes, + }) + if err != nil { + s.closeClient(cl, true) + return err + } + */ // Publish out any retained messages matching the subscription filter. for i := 0; i < len(pk.Topics); i++ { messages := s.topics.Messages(pk.Topics[i]) for _, pkv := range messages { - err := s.writeClient(cl, pkv) - if err != nil { - s.closeClient(cl, true) - return err - } + s.outbound <- transitMsg{client: cl, packet: pkv} + /* + err := s.writeClient(cl, pkv) + if err != nil { + s.closeClient(cl, true) + return err + } + */ } } @@ -625,6 +679,7 @@ func (s *Server) Close() error { // Close down waitgroups and pools. close(s.inbound) + close(s.outbound) s.inboundPool.Wait() return nil diff --git a/mqtt_test.go b/mqtt_test.go index c0a9736..aced9af 100644 --- a/mqtt_test.go +++ b/mqtt_test.go @@ -78,6 +78,7 @@ func TestNew(t *testing.T) { require.NotNil(t, s.listeners) require.NotNil(t, s.clients) require.NotNil(t, s.inbound) + require.NotNil(t, s.outbound) // log.Println(s) }