diff --git a/examples/persistence/main.go b/examples/persistence/main.go index c5d1eb8..aca2887 100644 --- a/examples/persistence/main.go +++ b/examples/persistence/main.go @@ -42,7 +42,12 @@ func main() { } // Start broker... - go server.Serve() + go func() { + err := server.Serve() + if err != nil { + log.Fatal(err) + } + }() fmt.Println(aurora.BgMagenta(" Started! ")) // Wait for signals... diff --git a/server/internal/clients/clients.go b/server/internal/clients/clients.go index 95264e1..663e971 100644 --- a/server/internal/clients/clients.go +++ b/server/internal/clients/clients.go @@ -20,9 +20,9 @@ import ( ) var ( - defaultKeepalive uint16 = 10 // in seconds. - resendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours - maxResends = 6 // maximum number of times to retry sending QoS packets. + defaultKeepalive uint16 = 10 // in seconds. + //resendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours + //maxResends = 6 // maximum number of times to retry sending QoS packets. ErrConnectionClosed = errors.New("Connection not open") ) @@ -109,28 +109,26 @@ type State struct { started *sync.WaitGroup // tracks the goroutines which have been started. endedW *sync.WaitGroup // tracks when the writer has ended. endedR *sync.WaitGroup // tracks when the reader has ended. - endOnce sync.Once // + endOnce sync.Once // only end once. } // NewClient returns a new instance of Client. func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client { cl := &Client{ - conn: c, - r: r, - w: w, - + conn: c, + r: r, + w: w, + system: s, keepalive: defaultKeepalive, Inflight: Inflight{ internal: make(map[uint16]InflightMessage), }, Subscriptions: make(map[string]byte), - State: State{ started: new(sync.WaitGroup), endedW: new(sync.WaitGroup), endedR: new(sync.WaitGroup), }, - system: s, } cl.refreshDeadline(cl.keepalive) @@ -146,6 +144,9 @@ func NewClientStub(s *system.Info) *Client { internal: make(map[uint16]InflightMessage), }, Subscriptions: make(map[string]byte), + State: State{ + Done: 1, + }, } } @@ -240,6 +241,10 @@ func (cl *Client) Start() { // Stop instructs the client to shut down all processing goroutines and disconnect. func (cl *Client) Stop() { + if atomic.LoadInt64(&cl.State.Done) == 1 { + return + } + cl.State.endOnce.Do(func() { cl.r.Stop() cl.w.Stop() @@ -324,9 +329,6 @@ func (cl *Client) Read(h func(*Client, packets.Packet) error) error { if err != nil { return err } - - // Attempt periodic resend of inflight messages (where applicable). - //cl.ResendInflight(false) } } @@ -448,43 +450,6 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) { return } -// ResendInflight will attempt resend send any in-flight messages stored for a client. -func (cl *Client) ResendInflight(force bool) error { - if cl.Inflight.Len() == 0 { - return nil - } - - nt := time.Now().Unix() - for _, tk := range cl.Inflight.GetAll() { - if tk.Resends >= maxResends { // After a reasonable time, drop inflight packets. - cl.Inflight.Delete(tk.Packet.PacketID) - if tk.Packet.FixedHeader.Type == packets.Publish { - atomic.AddInt64(&cl.system.PublishDropped, 1) - } - continue - } - - // Only continue if the resend backoff time has passed and there's a backoff time. - if !force && (nt-tk.Sent < resendBackoff[tk.Resends] || len(resendBackoff) < tk.Resends) { - continue - } - - if tk.Packet.FixedHeader.Type == packets.Publish { - tk.Packet.FixedHeader.Dup = true - } - - tk.Resends++ - tk.Sent = nt - cl.Inflight.Set(tk.Packet.PacketID, tk) - _, err := cl.WritePacket(tk.Packet) - if err != nil { - return err - } - } - - return nil -} - // LWT contains the last will and testament details for a client connection. type LWT struct { Topic string // the topic the will message shall be sent to. diff --git a/server/internal/clients/clients_test.go b/server/internal/clients/clients_test.go index 08c233d..f239b19 100644 --- a/server/internal/clients/clients_test.go +++ b/server/internal/clients/clients_test.go @@ -685,157 +685,6 @@ func TestClientWritePacketInvalidPacket(t *testing.T) { require.Error(t, err) } -func TestClientResendInflight(t *testing.T) { - r, w := net.Pipe() - cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) - cl.Start() - - o := make(chan []byte) - go func() { - buf, err := ioutil.ReadAll(w) - require.NoError(t, err) - o <- buf - }() - - pk1 := packets.Packet{ - FixedHeader: packets.FixedHeader{ - Type: packets.Publish, - Qos: 1, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 11, - } - - cl.Inflight.Set(pk1.PacketID, InflightMessage{ - Packet: pk1, - Sent: time.Now().Unix(), - }) - - err := cl.ResendInflight(true) - require.NoError(t, err) - - time.Sleep(time.Millisecond) - r.Close() - - rcv := <-o - require.Equal(t, []byte{ - byte(packets.Publish<<4 | 1<<1 | 1<<3), 14, - 0, 5, - 'a', '/', 'b', '/', 'c', - 0, 11, - 'h', 'e', 'l', 'l', 'o', - }, rcv) - - m := cl.Inflight.GetAll() - require.Equal(t, 1, m[11].Resends) // index is packet id - -} - -func TestClientResendBackoff(t *testing.T) { - r, w := net.Pipe() - cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) - cl.Start() - - o := make(chan []byte) - go func() { - buf, err := ioutil.ReadAll(w) - require.NoError(t, err) - o <- buf - }() - - pk1 := packets.Packet{ - FixedHeader: packets.FixedHeader{ - Type: packets.Publish, - Qos: 1, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 11, - } - - cl.Inflight.Set(pk1.PacketID, InflightMessage{ - Packet: pk1, - Sent: time.Now().Unix(), - Resends: 0, - }) - - err := cl.ResendInflight(false) - require.NoError(t, err) - - time.Sleep(time.Millisecond) - - // Attempt to send twice, but backoff should kick in stopping second resend. - err = cl.ResendInflight(false) - require.NoError(t, err) - - r.Close() - - rcv := <-o - require.Equal(t, []byte{ - byte(packets.Publish<<4 | 1<<1 | 1<<3), 14, - 0, 5, - 'a', '/', 'b', '/', 'c', - 0, 11, - 'h', 'e', 'l', 'l', 'o', - }, rcv) - - m := cl.Inflight.GetAll() - require.Equal(t, 1, m[11].Resends) // index is packet id -} - -func TestClientResendInflightNoMessages(t *testing.T) { - r, _ := net.Pipe() - cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) - out := []packets.Packet{} - err := cl.ResendInflight(true) - require.NoError(t, err) - require.Equal(t, 0, len(out)) - r.Close() -} - -func TestClientResendInflightDropMessage(t *testing.T) { - r, _ := net.Pipe() - cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) - - pk1 := packets.Packet{ - FixedHeader: packets.FixedHeader{ - Type: packets.Publish, - Qos: 1, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 11, - } - - cl.Inflight.Set(pk1.PacketID, InflightMessage{ - Packet: pk1, - Sent: time.Now().Unix(), - Resends: maxResends, - }) - - err := cl.ResendInflight(true) - require.NoError(t, err) - r.Close() - - m := cl.Inflight.GetAll() - require.Equal(t, 0, len(m)) - require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishDropped)) -} - -func TestClientResendInflightError(t *testing.T) { - r, _ := net.Pipe() - cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) - - cl.Inflight.Set(1, InflightMessage{ - Packet: packets.Packet{}, - Sent: time.Now().Unix(), - }) - r.Close() - err := cl.ResendInflight(true) - require.Error(t, err) -} - ///// func TestInflightSet(t *testing.T) { @@ -893,6 +742,20 @@ func BenchmarkInflightGetAll(b *testing.B) { } } +func TestInflightLen(t *testing.T) { + cl := genClient() + cl.Inflight.Set(2, InflightMessage{Packet: packets.Packet{}, Sent: 0}) + require.Equal(t, 1, cl.Inflight.Len()) +} + +func BenchmarkInflightLen(b *testing.B) { + cl := genClient() + cl.Inflight.Set(2, InflightMessage{Packet: packets.Packet{}, Sent: 0}) + for n := 0; n < b.N; n++ { + cl.Inflight.Len() + } +} + func TestInflightDelete(t *testing.T) { cl := genClient() cl.Inflight.Set(3, InflightMessage{Packet: packets.Packet{}, Sent: 0}) diff --git a/server/persistence/bolt/bolt.go b/server/persistence/bolt/bolt.go index 0f1843c..07bb2e0 100644 --- a/server/persistence/bolt/bolt.go +++ b/server/persistence/bolt/bolt.go @@ -1,10 +1,11 @@ package bolt import ( - //"encoding/gob" "errors" "time" + "fmt" + sgob "github.com/asdine/storm/codec/gob" "github.com/asdine/storm/v3" "go.etcd.io/bbolt" @@ -17,6 +18,10 @@ const ( defaultTimeout = 250 * time.Millisecond ) +var ( + errNotFound = "not found" +) + // Store is a backend for writing and reading to bolt persistent storage. type Store struct { path string // the path on which to store the db file. @@ -190,74 +195,70 @@ func (s *Store) DeleteRetained(id string) error { // ReadSubscriptions loads all the subscriptions from the boltdb instance. func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) { if s.db == nil { - err = errors.New("boltdb not opened") - return + return v, errors.New("boltdb not opened") } err = s.db.Find("T", persistence.KSubscription, &v) - if err != nil { + if err != nil && err.Error() != errNotFound { return } - return + return v, nil } // ReadClients loads all the clients from the boltdb instance. func (s *Store) ReadClients() (v []persistence.Client, err error) { if s.db == nil { - err = errors.New("boltdb not opened") - return + return v, errors.New("boltdb not opened") } err = s.db.Find("T", persistence.KClient, &v) - if err != nil { + if err != nil && err.Error() != errNotFound { return } - return + return v, nil } // ReadInflight loads all the inflight messages from the boltdb instance. func (s *Store) ReadInflight() (v []persistence.Message, err error) { if s.db == nil { - err = errors.New("boltdb not opened") - return + return v, errors.New("boltdb not opened") } err = s.db.Find("T", persistence.KInflight, &v) - if err != nil { + if err != nil && err.Error() != errNotFound { return } - return + return v, nil } // ReadRetained loads all the retained messages from the boltdb instance. func (s *Store) ReadRetained() (v []persistence.Message, err error) { if s.db == nil { - err = errors.New("boltdb not opened") - return + return v, errors.New("boltdb not opened") } err = s.db.Find("T", persistence.KRetained, &v) - if err != nil { + if err != nil && err.Error() != errNotFound { return } - return + return v, nil } //ReadServerInfo loads the server info from the boltdb instance. func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) { if s.db == nil { - err = errors.New("boltdb not opened") - return + return v, errors.New("boltdb not opened") } err = s.db.One("ID", persistence.KServerInfo, &v) - if err != nil { + fmt.Println(err) + if err != nil && err.Error() != errNotFound { return } - return + return v, nil } diff --git a/server/persistence/bolt/bolt_test.go b/server/persistence/bolt/bolt_test.go index 6df5026..5a8588a 100644 --- a/server/persistence/bolt/bolt_test.go +++ b/server/persistence/bolt/bolt_test.go @@ -111,10 +111,7 @@ func TestReadServerInfoFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() _, err = s.ReadServerInfo() require.Error(t, err) } @@ -168,10 +165,7 @@ func TestWriteSubscriptionFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.WriteSubscription(persistence.Subscription{}) require.Error(t, err) } @@ -186,10 +180,7 @@ func TestReadSubscriptionFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() _, err = s.ReadSubscriptions() require.Error(t, err) } @@ -248,10 +239,7 @@ func TestWriteInflightFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.WriteInflight(persistence.Message{}) require.Error(t, err) } @@ -266,10 +254,7 @@ func TestReadInflightFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() _, err = s.ReadInflight() require.Error(t, err) } @@ -352,10 +337,7 @@ func TestReadRetainedFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() _, err = s.ReadRetained() require.Error(t, err) } @@ -367,15 +349,11 @@ func TestWriteRetrieveDeleteClients(t *testing.T) { defer teardown(s, t) v := persistence.Client{ - ID: "client1", - T: persistence.KClient, - Listener: "tcp1", - Username: []byte{'m', 'o', 'c', 'h', 'i'}, - CleanSession: true, - Subscriptions: map[string]byte{ - "a/b/c": 0, - "d/e/f": 1, - }, + ID: "cl_client1", + ClientID: "client1", + T: persistence.KClient, + Listener: "tcp1", + Username: []byte{'m', 'o', 'c', 'h', 'i'}, LWT: persistence.LWT{ Topic: "a/b/c", Message: []byte{'h', 'e', 'l', 'l', 'o'}, @@ -391,10 +369,10 @@ func TestWriteRetrieveDeleteClients(t *testing.T) { require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, clients[0].Username) require.Equal(t, "a/b/c", clients[0].LWT.Topic) - require.Equal(t, uint8(1), clients[0].Subscriptions["d/e/f"]) v2 := persistence.Client{ - ID: "client2", + ID: "cl_client2", + ClientID: "client2", T: persistence.KClient, Listener: "tcp1", } @@ -406,7 +384,7 @@ func TestWriteRetrieveDeleteClients(t *testing.T) { require.Equal(t, persistence.KClient, clients[0].T) require.Equal(t, 2, len(clients)) - err = s.DeleteClient("client2") + err = s.DeleteClient("cl_client2") require.NoError(t, err) clients, err = s.ReadClients() @@ -424,10 +402,7 @@ func TestWriteClientFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.WriteClient(persistence.Client{}) require.Error(t, err) } @@ -442,10 +417,7 @@ func TestReadClientFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() _, err = s.ReadClients() require.Error(t, err) } @@ -460,10 +432,7 @@ func TestDeleteSubscriptionFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.DeleteSubscription("a") require.Error(t, err) } @@ -478,10 +447,7 @@ func TestDeleteClientFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.DeleteClient("a") require.Error(t, err) } @@ -496,10 +462,7 @@ func TestDeleteInflightFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.DeleteInflight("a") require.Error(t, err) } @@ -514,10 +477,7 @@ func TestDeleteRetainedFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - - err = os.Remove(tmpPath) - require.NoError(t, err) - + s.Close() err = s.DeleteRetained("a") require.Error(t, err) } diff --git a/server/persistence/persistence.go b/server/persistence/persistence.go index c5e994c..25e274e 100644 --- a/server/persistence/persistence.go +++ b/server/persistence/persistence.go @@ -75,13 +75,12 @@ type FixedHeader struct { // Client contains client data that can be persistently stored. type Client struct { - ID string // the id of the client - T string // the type of the stored data. - Listener string // the last known listener id for the client - Username []byte // the username the client authenticated with. - CleanSession bool // indicates if the client connected expecting a clean-session. - Subscriptions map[string]byte // a list of the subscriptions the user has (qos keyed on filter). - LWT LWT // the last-will-and-testament message for the client. + ID string // the storage key. + ClientID string // the id of the client. + T string // the type of the stored data. + Listener string // the last known listener id for the client + Username []byte // the username the client authenticated with. + LWT LWT // the last-will-and-testament message for the client. } // LWT contains details about a clients LWT payload. @@ -216,14 +215,10 @@ func (s *MockStore) ReadClients() (v []Client, err error) { return []Client{ Client{ - ID: "client1", - T: KClient, - Listener: "tcp1", - CleanSession: true, - Subscriptions: map[string]byte{ - "a/b/c": 0, - "d/e/f": 1, - }, + ID: "cl_client1", + ClientID: "client1", + T: KClient, + Listener: "tcp1", }, }, nil } diff --git a/server/server.go b/server/server.go index baa2aaf..dccc32f 100644 --- a/server/server.go +++ b/server/server.go @@ -22,6 +22,7 @@ const ( Version = "0.1.0" // the server version. maxPacketID = 65535 // the maximum value of a 16-bit packet ID. + ) var ( @@ -31,6 +32,9 @@ var ( ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics") SysTopicInterval time.Duration = 30000 // the default number of milliseconds between $SYS topic publishes. + + inflightResendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours + inflightMaxResends = 6 // maximum number of times to retry sending QoS packets. ) // Server is an MQTT broker server. @@ -113,100 +117,6 @@ func (s *Server) Serve() error { return nil } -// readStore reads in any data from the persistent datastore (if applicable). -func (s *Server) readStore() error { - info, err := s.Store.ReadServerInfo() - if err != nil { - return err - } - s.loadServerInfo(info) - - subs, err := s.Store.ReadSubscriptions() - if err != nil { - return err - } - s.loadSubscriptions(subs) - - clients, err := s.Store.ReadClients() - if err != nil { - return err - } - s.loadClients(clients) - - inflight, err := s.Store.ReadInflight() - if err != nil { - return err - } - s.loadInflight(inflight) - - retained, err := s.Store.ReadRetained() - if err != nil { - return err - } - s.loadRetained(retained) - - return nil -} - -// loadServerInfo restores server info from the datastore. -func (s *Server) loadServerInfo(v persistence.ServerInfo) { - version := s.System.Version - s.System = &v.Info - s.System.Version = version -} - -// loadSubscriptions restores subscriptions from the datastore. -func (s *Server) loadSubscriptions(v []persistence.Subscription) { - for _, sub := range v { - s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) - } -} - -// loadClients restores clients from the datastore. -func (s *Server) loadClients(v []persistence.Client) { - for _, cl := range v { - - c := clients.NewClientStub(s.System) - c.ID = cl.ID - c.Listener = cl.Listener - c.Username = cl.Username - c.LWT = clients.LWT(cl.LWT) - c.Subscriptions = cl.Subscriptions - s.Clients.Add(c) - } -} - -// loadInflight restores inflight messages from the datastore. -func (s *Server) loadInflight(v []persistence.Message) { - for _, msg := range v { - if client, ok := s.Clients.Get(msg.Client); ok { - client.Inflight.Set(msg.PacketID, clients.InflightMessage{ - Packet: packets.Packet{ - FixedHeader: packets.FixedHeader(msg.FixedHeader), - PacketID: msg.PacketID, - TopicName: msg.TopicName, - Payload: msg.Payload, - }, - Sent: msg.Sent, - Resends: msg.Resends, - }) - } - } -} - -// loadRetained restores retained messages from the datastore. -func (s *Server) loadRetained(v []persistence.Message) { - for _, msg := range v { - s.Topics.RetainMessage( - packets.Packet{ - FixedHeader: packets.FixedHeader(msg.FixedHeader), - TopicName: msg.TopicName, - Payload: msg.Payload, - }, - ) - } -} - // eventLoop runs server processes at intervals. func (s *Server) eventLoop() { for { @@ -264,6 +174,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) atomic.AddInt64(&s.System.ClientsDisconnected, -1) } existing.Stop() + fmt.Printf("%+v\n", existing.Subscriptions) if pk.CleanSession { for k := range existing.Subscriptions { delete(existing.Subscriptions, k) @@ -298,7 +209,18 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) return err } - cl.ResendInflight(true) + s.ResendClientInflight(cl, true) + + if s.Store != nil { + s.Store.WriteClient(persistence.Client{ + ID: "cl_" + cl.ID, + ClientID: cl.ID, + T: persistence.KClient, + Listener: cl.Listener, + Username: cl.Username, + LWT: persistence.LWT(cl.LWT), + }) + } err = cl.Read(s.processPacket) if err != nil { @@ -404,8 +326,22 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error { } if pk.FixedHeader.Retain { - q := s.Topics.RetainMessage(pk.PublishCopy()) + out := pk.PublishCopy() + q := s.Topics.RetainMessage(out) atomic.AddInt64(&s.System.Retained, q) + if s.Store != nil { + if q == 1 { + s.Store.WriteRetained(persistence.Message{ + ID: "ret_" + out.TopicName, + T: persistence.KRetained, + FixedHeader: persistence.FixedHeader(out.FixedHeader), + TopicName: out.TopicName, + Payload: out.Payload, + }) + } else { + s.Store.DeleteRetained("ret_" + out.TopicName) + } + } } if pk.FixedHeader.Qos > 0 { @@ -450,13 +386,25 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { // the client at some point, one way or another. Store the publish // packet in the client's inflight queue and attempt to redeliver // if an appropriate ack is not received (or if the client is offline). + sent := time.Now().Unix() q := client.Inflight.Set(out.PacketID, clients.InflightMessage{ Packet: out, - Sent: time.Now().Unix(), + Sent: sent, }) if q { atomic.AddInt64(&s.System.Inflight, 1) } + + if s.Store != nil { + s.Store.WriteInflight(persistence.Message{ + ID: "if_" + client.ID + "_" + strconv.Itoa(int(out.PacketID)), + T: persistence.KRetained, + FixedHeader: persistence.FixedHeader(out.FixedHeader), + TopicName: out.TopicName, + Payload: out.Payload, + Sent: sent, + }) + } } s.writeClient(client, out) @@ -470,6 +418,9 @@ func (s *Server) processPuback(cl *clients.Client, pk packets.Packet) error { if q { atomic.AddInt64(&s.System.Inflight, -1) } + if s.Store != nil { + s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID))) + } return nil } @@ -509,6 +460,10 @@ func (s *Server) processPubrel(cl *clients.Client, pk packets.Packet) error { atomic.AddInt64(&s.System.Inflight, -1) } + if s.Store != nil { + s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID))) + } + return nil } @@ -518,6 +473,9 @@ func (s *Server) processPubcomp(cl *clients.Client, pk packets.Packet) error { if q { atomic.AddInt64(&s.System.Inflight, -1) } + if s.Store != nil { + s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID))) + } return nil } @@ -534,6 +492,16 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error { } cl.NoteSubscription(pk.Topics[i], pk.Qoss[i]) retCodes[i] = pk.Qoss[i] + + if s.Store != nil { + s.Store.WriteSubscription(persistence.Subscription{ + ID: "sub_" + cl.ID + ":" + pk.Topics[i], + T: persistence.KSubscription, + Filter: pk.Topics[i], + Client: cl.ID, + QoS: pk.Qoss[i], + }) + } } } @@ -585,8 +553,6 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error // Due to the int to string conversions this method is not as cheap as // some of the others so the publishing interval should be set appropriately. func (s *Server) publishSysTopics() { - s.System.Uptime = time.Now().Unix() - s.System.Started - pk := packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, @@ -594,6 +560,7 @@ func (s *Server) publishSysTopics() { }, } + s.System.Uptime = time.Now().Unix() - s.System.Started topics := map[string]string{ "$SYS/broker/version": s.System.Version, "$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)), @@ -629,7 +596,63 @@ func (s *Server) publishSysTopics() { persistence.KServerInfo, }) } +} +// ResendClientInflight attempts to resend all undelivered inflight messages +// to a client. +func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error { + if cl.Inflight.Len() == 0 { + return nil + } + + nt := time.Now().Unix() + for _, tk := range cl.Inflight.GetAll() { + + // After a reasonable time, drop inflight packets. + if tk.Resends >= inflightMaxResends { + cl.Inflight.Delete(tk.Packet.PacketID) + if tk.Packet.FixedHeader.Type == packets.Publish { + atomic.AddInt64(&s.System.PublishDropped, 1) + } + + if s.Store != nil { + s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID))) + } + + continue + } + + // Only continue if the resend backoff time has passed and there's a backoff time. + if !force && (nt-tk.Sent < inflightResendBackoff[tk.Resends] || len(inflightResendBackoff) < tk.Resends) { + continue + } + + if tk.Packet.FixedHeader.Type == packets.Publish { + tk.Packet.FixedHeader.Dup = true + } + + tk.Resends++ + tk.Sent = nt + cl.Inflight.Set(tk.Packet.PacketID, tk) + _, err := cl.WritePacket(tk.Packet) + if err != nil { + return err + } + + if s.Store != nil { + s.Store.WriteInflight(persistence.Message{ + ID: "if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)), + T: persistence.KRetained, + FixedHeader: persistence.FixedHeader(tk.Packet.FixedHeader), + TopicName: tk.Packet.TopicName, + Payload: tk.Packet.Payload, + Sent: tk.Sent, + Resends: tk.Resends, + }) + } + } + + return nil } // Close attempts to gracefully shutdown the server, all listeners, clients, and stores. @@ -665,10 +688,104 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error { TopicName: cl.LWT.Topic, Payload: cl.LWT.Message, }) - // omit errors, since we're not logging and need to close the client in either case. } cl.Stop() return nil } + +// readStore reads in any data from the persistent datastore (if applicable). +func (s *Server) readStore() error { + info, err := s.Store.ReadServerInfo() + if err != nil { + return fmt.Errorf("load server info; %w", err) + } + s.loadServerInfo(info) + + clients, err := s.Store.ReadClients() + if err != nil { + return fmt.Errorf("load clients; %w", err) + } + fmt.Println("loading clients", clients) + s.loadClients(clients) + + subs, err := s.Store.ReadSubscriptions() + if err != nil { + return fmt.Errorf("load subscriptions; %w", err) + } + s.loadSubscriptions(subs) + + inflight, err := s.Store.ReadInflight() + if err != nil { + return fmt.Errorf("load inflight; %w", err) + } + s.loadInflight(inflight) + + retained, err := s.Store.ReadRetained() + if err != nil { + return fmt.Errorf("load retained; %w", err) + } + s.loadRetained(retained) + + return nil +} + +// loadServerInfo restores server info from the datastore. +func (s *Server) loadServerInfo(v persistence.ServerInfo) { + version := s.System.Version + s.System = &v.Info + s.System.Version = version +} + +// loadSubscriptions restores subscriptions from the datastore. +func (s *Server) loadSubscriptions(v []persistence.Subscription) { + for _, sub := range v { + s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) + if cl, ok := s.Clients.Get(sub.Client); ok { + cl.NoteSubscription(sub.Filter, sub.QoS) + } + } + +} + +// loadClients restores clients from the datastore. +func (s *Server) loadClients(v []persistence.Client) { + for _, c := range v { + cl := clients.NewClientStub(s.System) + cl.ID = c.ClientID + cl.Listener = c.Listener + cl.Username = c.Username + cl.LWT = clients.LWT(c.LWT) + s.Clients.Add(cl) + } +} + +// loadInflight restores inflight messages from the datastore. +func (s *Server) loadInflight(v []persistence.Message) { + for _, msg := range v { + if client, ok := s.Clients.Get(msg.Client); ok { + client.Inflight.Set(msg.PacketID, clients.InflightMessage{ + Packet: packets.Packet{ + FixedHeader: packets.FixedHeader(msg.FixedHeader), + PacketID: msg.PacketID, + TopicName: msg.TopicName, + Payload: msg.Payload, + }, + Sent: msg.Sent, + Resends: msg.Resends, + }) + } + } +} + +// loadRetained restores retained messages from the datastore. +func (s *Server) loadRetained(v []persistence.Message) { + for _, msg := range v { + s.Topics.RetainMessage(packets.Packet{ + FixedHeader: packets.FixedHeader(msg.FixedHeader), + TopicName: msg.TopicName, + Payload: msg.Payload, + }) + } +} diff --git a/server/server_test.go b/server/server_test.go index ec96bd7..4d6984e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,7 +1,7 @@ package server import ( - "errors" + //"errors" "io/ioutil" "net" "strconv" @@ -25,6 +25,7 @@ const defaultPort = ":18882" func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) { s = New() + s.Store = new(persistence.MockStore) r, w = net.Pipe() cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System) cl.ID = "mochi" @@ -120,234 +121,6 @@ func BenchmarkServerAddListener(b *testing.B) { } } -func TestServerReadStore(t *testing.T) { - s := New() - require.NotNil(t, s) - - s.Store = new(persistence.MockStore) - err := s.readStore() - require.NoError(t, err) - - require.Equal(t, int64(100), s.System.Started) - require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c")) - - cl1, ok := s.Clients.Get("client1") - require.Equal(t, true, ok) - - msg, ok := cl1.Inflight.Get(100) - require.Equal(t, true, ok) - require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload) - -} - -func TestServerReadStoreFailures(t *testing.T) { - s := New() - require.NotNil(t, s) - - s.Store = new(persistence.MockStore) - s.Store.(*persistence.MockStore).Fail = map[string]bool{ - "read_subs": true, - "read_clients": true, - "read_inflight": true, - "read_retained": true, - "read_info": true, - } - - err := s.readStore() - require.Error(t, err) - require.Equal(t, errors.New("test_info"), err) - delete(s.Store.(*persistence.MockStore).Fail, "read_info") - - err = s.readStore() - require.Error(t, err) - require.Equal(t, errors.New("test_subs"), err) - delete(s.Store.(*persistence.MockStore).Fail, "read_subs") - - err = s.readStore() - require.Error(t, err) - require.Equal(t, errors.New("test_clients"), err) - delete(s.Store.(*persistence.MockStore).Fail, "read_clients") - - err = s.readStore() - require.Error(t, err) - require.Equal(t, errors.New("test_inflight"), err) - delete(s.Store.(*persistence.MockStore).Fail, "read_inflight") - - err = s.readStore() - require.Error(t, err) - require.Equal(t, errors.New("test_retained"), err) - delete(s.Store.(*persistence.MockStore).Fail, "read_retained") -} - -func TestServerLoadServerInfo(t *testing.T) { - s := New() - require.NotNil(t, s) - - s.System.Version = "original" - - s.loadServerInfo(persistence.ServerInfo{ - system.Info{ - Version: "test", - Started: 100, - }, persistence.KServerInfo, - }) - - require.Equal(t, "original", s.System.Version) - require.Equal(t, int64(100), s.System.Started) -} - -func TestServerLoadSubscriptions(t *testing.T) { - s := New() - require.NotNil(t, s) - - subs := []persistence.Subscription{ - persistence.Subscription{ - ID: "test:a/b/c", - Client: "test", - Filter: "a/b/c", - QoS: 1, - T: persistence.KSubscription, - }, - persistence.Subscription{ - ID: "test:d/e/f", - Client: "test", - Filter: "d/e/f", - QoS: 0, - T: persistence.KSubscription, - }, - } - - s.loadSubscriptions(subs) - require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c")) - require.Equal(t, topics.Subscriptions{"test": 0}, s.Topics.Subscribers("d/e/f")) -} - -func TestServerLoadClients(t *testing.T) { - s := New() - require.NotNil(t, s) - - clients := []persistence.Client{ - persistence.Client{ - ID: "client1", - T: persistence.KClient, - Listener: "tcp1", - CleanSession: true, - Subscriptions: map[string]byte{ - "a/b/c": 0, - "d/e/f": 1, - }, - }, - persistence.Client{ - ID: "client2", - T: persistence.KClient, - Listener: "tcp1", - Subscriptions: map[string]byte{ - "q/w/e": 2, - }, - }, - } - - s.loadClients(clients) - - cl1, ok := s.Clients.Get("client1") - require.Equal(t, true, ok) - require.NotNil(t, cl1) - - cl2, ok2 := s.Clients.Get("client2") - require.Equal(t, true, ok2) - require.NotNil(t, cl2) - -} - -func TestServerLoadInflight(t *testing.T) { - s := New() - require.NotNil(t, s) - - msgs := []persistence.Message{ - persistence.Message{ - ID: "client1_if_0", - T: persistence.KInflight, - Client: "client1", - PacketID: 0, - TopicName: "a/b/c", - Payload: []byte{'h', 'e', 'l', 'l', 'o'}, - Sent: 100, - Resends: 0, - }, - persistence.Message{ - ID: "client1_if_100", - T: persistence.KInflight, - Client: "client1", - PacketID: 100, - TopicName: "d/e/f", - Payload: []byte{'y', 'e', 's'}, - Sent: 200, - Resends: 1, - }, - } - - w, _ := net.Pipe() - defer w.Close() - c1 := clients.NewClient(w, nil, nil, nil) - c1.ID = "client1" - s.Clients.Add(c1) - - s.loadInflight(msgs) - - cl1, ok := s.Clients.Get("client1") - require.Equal(t, true, ok) - require.Equal(t, "client1", cl1.ID) - - msg, ok := cl1.Inflight.Get(100) - require.Equal(t, true, ok) - require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload) - -} - -func TestServerLoadRetained(t *testing.T) { - s := New() - require.NotNil(t, s) - - msgs := []persistence.Message{ - persistence.Message{ - ID: "client1_ret_200", - T: persistence.KRetained, - FixedHeader: persistence.FixedHeader{ - Retain: true, - }, - PacketID: 200, - TopicName: "a/b/c", - Payload: []byte{'h', 'e', 'l', 'l', 'o'}, - Sent: 100, - Resends: 0, - }, - persistence.Message{ - ID: "client1_ret_300", - T: persistence.KRetained, - FixedHeader: persistence.FixedHeader{ - Retain: true, - }, - PacketID: 100, - TopicName: "d/e/f", - Payload: []byte{'y', 'e', 's'}, - Sent: 200, - Resends: 1, - }, - } - - s.loadRetained(msgs) - - require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) - require.Equal(t, 1, len(s.Topics.Messages("d/e/f"))) - - msg := s.Topics.Messages("a/b/c") - require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, msg[0].Payload) - - msg = s.Topics.Messages("d/e/f") - require.Equal(t, []byte{'y', 'e', 's'}, msg[0].Payload) - -} - func TestServerServe(t *testing.T) { s := New() require.NotNil(t, s) @@ -877,6 +650,35 @@ func TestServerProcessPublishQoS2(t *testing.T) { require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained)) } +func TestServerProcessPublishUnretain(t *testing.T) { + s, cl1, r1, w1 := setupClient() + s.Clients.Add(cl1) + + ack1 := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(r1) + if err != nil { + panic(err) + } + ack1 <- buf + }() + + err := s.processPacket(cl1, packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Retain: true, + }, + TopicName: "a/b/c", + Payload: []byte{}, + }) + + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + w1.Close() + + require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained)) +} + func TestServerProcessPublishOfflineQueuing(t *testing.T) { s, cl1, r1, w1 := setupClient() cl1.ID = "mochi1" @@ -1451,3 +1253,395 @@ func TestServerCloseClientClosed(t *testing.T) { err := s.closeClient(cl, true) require.NoError(t, err) } + +func TestServerReadStore(t *testing.T) { + s := New() + require.NotNil(t, s) + + s.Store = new(persistence.MockStore) + err := s.readStore() + require.NoError(t, err) + + require.Equal(t, int64(100), s.System.Started) + require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c")) + + cl1, ok := s.Clients.Get("client1") + require.Equal(t, true, ok) + + msg, ok := cl1.Inflight.Get(100) + require.Equal(t, true, ok) + require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload) + +} + +func TestServerReadStoreFailures(t *testing.T) { + s := New() + require.NotNil(t, s) + + s.Store = new(persistence.MockStore) + s.Store.(*persistence.MockStore).Fail = map[string]bool{ + "read_subs": true, + "read_clients": true, + "read_inflight": true, + "read_retained": true, + "read_info": true, + } + + err := s.readStore() + require.Error(t, err) + delete(s.Store.(*persistence.MockStore).Fail, "read_info") + + err = s.readStore() + require.Error(t, err) + delete(s.Store.(*persistence.MockStore).Fail, "read_subs") + + err = s.readStore() + require.Error(t, err) + delete(s.Store.(*persistence.MockStore).Fail, "read_clients") + + err = s.readStore() + require.Error(t, err) + delete(s.Store.(*persistence.MockStore).Fail, "read_inflight") + + err = s.readStore() + require.Error(t, err) + delete(s.Store.(*persistence.MockStore).Fail, "read_retained") +} + +func TestServerLoadServerInfo(t *testing.T) { + s := New() + require.NotNil(t, s) + + s.System.Version = "original" + + s.loadServerInfo(persistence.ServerInfo{ + system.Info{ + Version: "test", + Started: 100, + }, persistence.KServerInfo, + }) + + require.Equal(t, "original", s.System.Version) + require.Equal(t, int64(100), s.System.Started) +} + +func TestServerLoadSubscriptions(t *testing.T) { + s := New() + require.NotNil(t, s) + + cl := clients.NewClientStub(s.System) + cl.ID = "test" + s.Clients.Add(cl) + + subs := []persistence.Subscription{ + persistence.Subscription{ + ID: "test:a/b/c", + Client: "test", + Filter: "a/b/c", + QoS: 1, + T: persistence.KSubscription, + }, + persistence.Subscription{ + ID: "test:d/e/f", + Client: "test", + Filter: "d/e/f", + QoS: 0, + T: persistence.KSubscription, + }, + } + + s.loadSubscriptions(subs) + require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c")) + require.Equal(t, topics.Subscriptions{"test": 0}, s.Topics.Subscribers("d/e/f")) +} + +func TestServerLoadClients(t *testing.T) { + s := New() + require.NotNil(t, s) + + clients := []persistence.Client{ + persistence.Client{ + ID: "cl_client1", + ClientID: "client1", + T: persistence.KClient, + Listener: "tcp1", + }, + persistence.Client{ + ID: "cl_client2", + ClientID: "client2", + T: persistence.KClient, + Listener: "tcp1", + }, + } + + s.loadClients(clients) + + cl1, ok := s.Clients.Get("client1") + require.Equal(t, true, ok) + require.NotNil(t, cl1) + + cl2, ok2 := s.Clients.Get("client2") + require.Equal(t, true, ok2) + require.NotNil(t, cl2) + +} + +func TestServerLoadInflight(t *testing.T) { + s := New() + require.NotNil(t, s) + + msgs := []persistence.Message{ + persistence.Message{ + ID: "client1_if_0", + T: persistence.KInflight, + Client: "client1", + PacketID: 0, + TopicName: "a/b/c", + Payload: []byte{'h', 'e', 'l', 'l', 'o'}, + Sent: 100, + Resends: 0, + }, + persistence.Message{ + ID: "client1_if_100", + T: persistence.KInflight, + Client: "client1", + PacketID: 100, + TopicName: "d/e/f", + Payload: []byte{'y', 'e', 's'}, + Sent: 200, + Resends: 1, + }, + } + + w, _ := net.Pipe() + defer w.Close() + c1 := clients.NewClient(w, nil, nil, nil) + c1.ID = "client1" + s.Clients.Add(c1) + + s.loadInflight(msgs) + + cl1, ok := s.Clients.Get("client1") + require.Equal(t, true, ok) + require.Equal(t, "client1", cl1.ID) + + msg, ok := cl1.Inflight.Get(100) + require.Equal(t, true, ok) + require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload) + +} + +func TestServerLoadRetained(t *testing.T) { + s := New() + require.NotNil(t, s) + + msgs := []persistence.Message{ + persistence.Message{ + ID: "client1_ret_200", + T: persistence.KRetained, + FixedHeader: persistence.FixedHeader{ + Retain: true, + }, + PacketID: 200, + TopicName: "a/b/c", + Payload: []byte{'h', 'e', 'l', 'l', 'o'}, + Sent: 100, + Resends: 0, + }, + persistence.Message{ + ID: "client1_ret_300", + T: persistence.KRetained, + FixedHeader: persistence.FixedHeader{ + Retain: true, + }, + PacketID: 100, + TopicName: "d/e/f", + Payload: []byte{'y', 'e', 's'}, + Sent: 200, + Resends: 1, + }, + } + + s.loadRetained(msgs) + + require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) + require.Equal(t, 1, len(s.Topics.Messages("d/e/f"))) + + msg := s.Topics.Messages("a/b/c") + require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, msg[0].Payload) + + msg = s.Topics.Messages("d/e/f") + require.Equal(t, []byte{'y', 'e', 's'}, msg[0].Payload) +} + +func TestServerResendClientInflight(t *testing.T) { + s := New() + s.Store = new(persistence.MockStore) + require.NotNil(t, s) + + r, w := net.Pipe() + cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) + cl.Start() + s.Clients.Add(cl) + + o := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(w) + require.NoError(t, err) + o <- buf + }() + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 11, + } + + cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{ + Packet: pk1, + Sent: time.Now().Unix(), + }) + + err := s.ResendClientInflight(cl, true) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + r.Close() + + rcv := <-o + require.Equal(t, []byte{ + byte(packets.Publish<<4 | 1<<1 | 1<<3), 14, + 0, 5, + 'a', '/', 'b', '/', 'c', + 0, 11, + 'h', 'e', 'l', 'l', 'o', + }, rcv) + + m := cl.Inflight.GetAll() + require.Equal(t, 1, m[11].Resends) // index is packet id + +} + +func TestServerResendClientInflightBackoff(t *testing.T) { + s := New() + s.Store = new(persistence.MockStore) + require.NotNil(t, s) + + r, w := net.Pipe() + cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) + cl.Start() + s.Clients.Add(cl) + + o := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(w) + require.NoError(t, err) + o <- buf + }() + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 11, + } + + cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{ + Packet: pk1, + Sent: time.Now().Unix(), + Resends: 0, + }) + + err := s.ResendClientInflight(cl, true) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + + // Attempt to send twice, but backoff should kick in stopping second resend. + err = s.ResendClientInflight(cl, false) + require.NoError(t, err) + + r.Close() + + rcv := <-o + require.Equal(t, []byte{ + byte(packets.Publish<<4 | 1<<1 | 1<<3), 14, + 0, 5, + 'a', '/', 'b', '/', 'c', + 0, 11, + 'h', 'e', 'l', 'l', 'o', + }, rcv) + + m := cl.Inflight.GetAll() + require.Equal(t, 1, m[11].Resends) // index is packet id +} + +func TestServerResendClientInflightNoMessages(t *testing.T) { + s := New() + s.Store = new(persistence.MockStore) + require.NotNil(t, s) + + r, _ := net.Pipe() + cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) + out := []packets.Packet{} + err := s.ResendClientInflight(cl, true) + require.NoError(t, err) + require.Equal(t, 0, len(out)) + r.Close() +} + +func TestServerResendClientInflightDropMessage(t *testing.T) { + s := New() + s.Store = new(persistence.MockStore) + require.NotNil(t, s) + + r, _ := net.Pipe() + cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 11, + } + + cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{ + Packet: pk1, + Sent: time.Now().Unix(), + Resends: inflightMaxResends, + }) + + err := s.ResendClientInflight(cl, true) + require.NoError(t, err) + r.Close() + + m := cl.Inflight.GetAll() + require.Equal(t, 0, len(m)) + require.Equal(t, int64(1), atomic.LoadInt64(&s.System.PublishDropped)) +} + +func TestServerResendClientInflightError(t *testing.T) { + s := New() + require.NotNil(t, s) + + r, _ := net.Pipe() + cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info)) + + cl.Inflight.Set(1, clients.InflightMessage{ + Packet: packets.Packet{}, + Sent: time.Now().Unix(), + }) + r.Close() + err := s.ResendClientInflight(cl, true) + require.Error(t, err) +}