Persistence V1

This commit is contained in:
Mochi
2020-02-04 21:18:29 +00:00
parent 85f269af5d
commit d1daa843df
8 changed files with 727 additions and 627 deletions

View File

@@ -42,7 +42,12 @@ func main() {
} }
// Start broker... // Start broker...
go server.Serve() go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! ")) fmt.Println(aurora.BgMagenta(" Started! "))
// Wait for signals... // Wait for signals...

View File

@@ -20,9 +20,9 @@ import (
) )
var ( var (
defaultKeepalive uint16 = 10 // in seconds. defaultKeepalive uint16 = 10 // in seconds.
resendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours //resendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours
maxResends = 6 // maximum number of times to retry sending QoS packets. //maxResends = 6 // maximum number of times to retry sending QoS packets.
ErrConnectionClosed = errors.New("Connection not open") ErrConnectionClosed = errors.New("Connection not open")
) )
@@ -109,28 +109,26 @@ type State struct {
started *sync.WaitGroup // tracks the goroutines which have been started. started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended. endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended. endedR *sync.WaitGroup // tracks when the reader has ended.
endOnce sync.Once // endOnce sync.Once // only end once.
} }
// NewClient returns a new instance of Client. // NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client { func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{ cl := &Client{
conn: c, conn: c,
r: r, r: r,
w: w, w: w,
system: s,
keepalive: defaultKeepalive, keepalive: defaultKeepalive,
Inflight: Inflight{ Inflight: Inflight{
internal: make(map[uint16]InflightMessage), internal: make(map[uint16]InflightMessage),
}, },
Subscriptions: make(map[string]byte), Subscriptions: make(map[string]byte),
State: State{ State: State{
started: new(sync.WaitGroup), started: new(sync.WaitGroup),
endedW: new(sync.WaitGroup), endedW: new(sync.WaitGroup),
endedR: new(sync.WaitGroup), endedR: new(sync.WaitGroup),
}, },
system: s,
} }
cl.refreshDeadline(cl.keepalive) cl.refreshDeadline(cl.keepalive)
@@ -146,6 +144,9 @@ func NewClientStub(s *system.Info) *Client {
internal: make(map[uint16]InflightMessage), internal: make(map[uint16]InflightMessage),
}, },
Subscriptions: make(map[string]byte), Subscriptions: make(map[string]byte),
State: State{
Done: 1,
},
} }
} }
@@ -240,6 +241,10 @@ func (cl *Client) Start() {
// Stop instructs the client to shut down all processing goroutines and disconnect. // Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() { func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() { cl.State.endOnce.Do(func() {
cl.r.Stop() cl.r.Stop()
cl.w.Stop() cl.w.Stop()
@@ -324,9 +329,6 @@ func (cl *Client) Read(h func(*Client, packets.Packet) error) error {
if err != nil { if err != nil {
return err return err
} }
// Attempt periodic resend of inflight messages (where applicable).
//cl.ResendInflight(false)
} }
} }
@@ -448,43 +450,6 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
return return
} }
// ResendInflight will attempt resend send any in-flight messages stored for a client.
func (cl *Client) ResendInflight(force bool) error {
if cl.Inflight.Len() == 0 {
return nil
}
nt := time.Now().Unix()
for _, tk := range cl.Inflight.GetAll() {
if tk.Resends >= maxResends { // After a reasonable time, drop inflight packets.
cl.Inflight.Delete(tk.Packet.PacketID)
if tk.Packet.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.system.PublishDropped, 1)
}
continue
}
// Only continue if the resend backoff time has passed and there's a backoff time.
if !force && (nt-tk.Sent < resendBackoff[tk.Resends] || len(resendBackoff) < tk.Resends) {
continue
}
if tk.Packet.FixedHeader.Type == packets.Publish {
tk.Packet.FixedHeader.Dup = true
}
tk.Resends++
tk.Sent = nt
cl.Inflight.Set(tk.Packet.PacketID, tk)
_, err := cl.WritePacket(tk.Packet)
if err != nil {
return err
}
}
return nil
}
// LWT contains the last will and testament details for a client connection. // LWT contains the last will and testament details for a client connection.
type LWT struct { type LWT struct {
Topic string // the topic the will message shall be sent to. Topic string // the topic the will message shall be sent to.

View File

@@ -685,157 +685,6 @@ func TestClientWritePacketInvalidPacket(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestClientResendInflight(t *testing.T) {
r, w := net.Pipe()
cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Start()
o := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
require.NoError(t, err)
o <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
})
err := cl.ResendInflight(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
r.Close()
rcv := <-o
require.Equal(t, []byte{
byte(packets.Publish<<4 | 1<<1 | 1<<3), 14,
0, 5,
'a', '/', 'b', '/', 'c',
0, 11,
'h', 'e', 'l', 'l', 'o',
}, rcv)
m := cl.Inflight.GetAll()
require.Equal(t, 1, m[11].Resends) // index is packet id
}
func TestClientResendBackoff(t *testing.T) {
r, w := net.Pipe()
cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Start()
o := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
require.NoError(t, err)
o <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
Resends: 0,
})
err := cl.ResendInflight(false)
require.NoError(t, err)
time.Sleep(time.Millisecond)
// Attempt to send twice, but backoff should kick in stopping second resend.
err = cl.ResendInflight(false)
require.NoError(t, err)
r.Close()
rcv := <-o
require.Equal(t, []byte{
byte(packets.Publish<<4 | 1<<1 | 1<<3), 14,
0, 5,
'a', '/', 'b', '/', 'c',
0, 11,
'h', 'e', 'l', 'l', 'o',
}, rcv)
m := cl.Inflight.GetAll()
require.Equal(t, 1, m[11].Resends) // index is packet id
}
func TestClientResendInflightNoMessages(t *testing.T) {
r, _ := net.Pipe()
cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
out := []packets.Packet{}
err := cl.ResendInflight(true)
require.NoError(t, err)
require.Equal(t, 0, len(out))
r.Close()
}
func TestClientResendInflightDropMessage(t *testing.T) {
r, _ := net.Pipe()
cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
Resends: maxResends,
})
err := cl.ResendInflight(true)
require.NoError(t, err)
r.Close()
m := cl.Inflight.GetAll()
require.Equal(t, 0, len(m))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishDropped))
}
func TestClientResendInflightError(t *testing.T) {
r, _ := net.Pipe()
cl := NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, InflightMessage{
Packet: packets.Packet{},
Sent: time.Now().Unix(),
})
r.Close()
err := cl.ResendInflight(true)
require.Error(t, err)
}
///// /////
func TestInflightSet(t *testing.T) { func TestInflightSet(t *testing.T) {
@@ -893,6 +742,20 @@ func BenchmarkInflightGetAll(b *testing.B) {
} }
} }
func TestInflightLen(t *testing.T) {
cl := genClient()
cl.Inflight.Set(2, InflightMessage{Packet: packets.Packet{}, Sent: 0})
require.Equal(t, 1, cl.Inflight.Len())
}
func BenchmarkInflightLen(b *testing.B) {
cl := genClient()
cl.Inflight.Set(2, InflightMessage{Packet: packets.Packet{}, Sent: 0})
for n := 0; n < b.N; n++ {
cl.Inflight.Len()
}
}
func TestInflightDelete(t *testing.T) { func TestInflightDelete(t *testing.T) {
cl := genClient() cl := genClient()
cl.Inflight.Set(3, InflightMessage{Packet: packets.Packet{}, Sent: 0}) cl.Inflight.Set(3, InflightMessage{Packet: packets.Packet{}, Sent: 0})

View File

@@ -1,10 +1,11 @@
package bolt package bolt
import ( import (
//"encoding/gob"
"errors" "errors"
"time" "time"
"fmt"
sgob "github.com/asdine/storm/codec/gob" sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3" "github.com/asdine/storm/v3"
"go.etcd.io/bbolt" "go.etcd.io/bbolt"
@@ -17,6 +18,10 @@ const (
defaultTimeout = 250 * time.Millisecond defaultTimeout = 250 * time.Millisecond
) )
var (
errNotFound = "not found"
)
// Store is a backend for writing and reading to bolt persistent storage. // Store is a backend for writing and reading to bolt persistent storage.
type Store struct { type Store struct {
path string // the path on which to store the db file. path string // the path on which to store the db file.
@@ -190,74 +195,70 @@ func (s *Store) DeleteRetained(id string) error {
// ReadSubscriptions loads all the subscriptions from the boltdb instance. // ReadSubscriptions loads all the subscriptions from the boltdb instance.
func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) { func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
if s.db == nil { if s.db == nil {
err = errors.New("boltdb not opened") return v, errors.New("boltdb not opened")
return
} }
err = s.db.Find("T", persistence.KSubscription, &v) err = s.db.Find("T", persistence.KSubscription, &v)
if err != nil { if err != nil && err.Error() != errNotFound {
return return
} }
return return v, nil
} }
// ReadClients loads all the clients from the boltdb instance. // ReadClients loads all the clients from the boltdb instance.
func (s *Store) ReadClients() (v []persistence.Client, err error) { func (s *Store) ReadClients() (v []persistence.Client, err error) {
if s.db == nil { if s.db == nil {
err = errors.New("boltdb not opened") return v, errors.New("boltdb not opened")
return
} }
err = s.db.Find("T", persistence.KClient, &v) err = s.db.Find("T", persistence.KClient, &v)
if err != nil { if err != nil && err.Error() != errNotFound {
return return
} }
return return v, nil
} }
// ReadInflight loads all the inflight messages from the boltdb instance. // ReadInflight loads all the inflight messages from the boltdb instance.
func (s *Store) ReadInflight() (v []persistence.Message, err error) { func (s *Store) ReadInflight() (v []persistence.Message, err error) {
if s.db == nil { if s.db == nil {
err = errors.New("boltdb not opened") return v, errors.New("boltdb not opened")
return
} }
err = s.db.Find("T", persistence.KInflight, &v) err = s.db.Find("T", persistence.KInflight, &v)
if err != nil { if err != nil && err.Error() != errNotFound {
return return
} }
return return v, nil
} }
// ReadRetained loads all the retained messages from the boltdb instance. // ReadRetained loads all the retained messages from the boltdb instance.
func (s *Store) ReadRetained() (v []persistence.Message, err error) { func (s *Store) ReadRetained() (v []persistence.Message, err error) {
if s.db == nil { if s.db == nil {
err = errors.New("boltdb not opened") return v, errors.New("boltdb not opened")
return
} }
err = s.db.Find("T", persistence.KRetained, &v) err = s.db.Find("T", persistence.KRetained, &v)
if err != nil { if err != nil && err.Error() != errNotFound {
return return
} }
return return v, nil
} }
//ReadServerInfo loads the server info from the boltdb instance. //ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) { func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil { if s.db == nil {
err = errors.New("boltdb not opened") return v, errors.New("boltdb not opened")
return
} }
err = s.db.One("ID", persistence.KServerInfo, &v) err = s.db.One("ID", persistence.KServerInfo, &v)
if err != nil { fmt.Println(err)
if err != nil && err.Error() != errNotFound {
return return
} }
return return v, nil
} }

View File

@@ -111,10 +111,7 @@ func TestReadServerInfoFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
_, err = s.ReadServerInfo() _, err = s.ReadServerInfo()
require.Error(t, err) require.Error(t, err)
} }
@@ -168,10 +165,7 @@ func TestWriteSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.WriteSubscription(persistence.Subscription{}) err = s.WriteSubscription(persistence.Subscription{})
require.Error(t, err) require.Error(t, err)
} }
@@ -186,10 +180,7 @@ func TestReadSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
_, err = s.ReadSubscriptions() _, err = s.ReadSubscriptions()
require.Error(t, err) require.Error(t, err)
} }
@@ -248,10 +239,7 @@ func TestWriteInflightFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{}) err = s.WriteInflight(persistence.Message{})
require.Error(t, err) require.Error(t, err)
} }
@@ -266,10 +254,7 @@ func TestReadInflightFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
_, err = s.ReadInflight() _, err = s.ReadInflight()
require.Error(t, err) require.Error(t, err)
} }
@@ -352,10 +337,7 @@ func TestReadRetainedFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
_, err = s.ReadRetained() _, err = s.ReadRetained()
require.Error(t, err) require.Error(t, err)
} }
@@ -367,15 +349,11 @@ func TestWriteRetrieveDeleteClients(t *testing.T) {
defer teardown(s, t) defer teardown(s, t)
v := persistence.Client{ v := persistence.Client{
ID: "client1", ID: "cl_client1",
T: persistence.KClient, ClientID: "client1",
Listener: "tcp1", T: persistence.KClient,
Username: []byte{'m', 'o', 'c', 'h', 'i'}, Listener: "tcp1",
CleanSession: true, Username: []byte{'m', 'o', 'c', 'h', 'i'},
Subscriptions: map[string]byte{
"a/b/c": 0,
"d/e/f": 1,
},
LWT: persistence.LWT{ LWT: persistence.LWT{
Topic: "a/b/c", Topic: "a/b/c",
Message: []byte{'h', 'e', 'l', 'l', 'o'}, Message: []byte{'h', 'e', 'l', 'l', 'o'},
@@ -391,10 +369,10 @@ func TestWriteRetrieveDeleteClients(t *testing.T) {
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, clients[0].Username) require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, clients[0].Username)
require.Equal(t, "a/b/c", clients[0].LWT.Topic) require.Equal(t, "a/b/c", clients[0].LWT.Topic)
require.Equal(t, uint8(1), clients[0].Subscriptions["d/e/f"])
v2 := persistence.Client{ v2 := persistence.Client{
ID: "client2", ID: "cl_client2",
ClientID: "client2",
T: persistence.KClient, T: persistence.KClient,
Listener: "tcp1", Listener: "tcp1",
} }
@@ -406,7 +384,7 @@ func TestWriteRetrieveDeleteClients(t *testing.T) {
require.Equal(t, persistence.KClient, clients[0].T) require.Equal(t, persistence.KClient, clients[0].T)
require.Equal(t, 2, len(clients)) require.Equal(t, 2, len(clients))
err = s.DeleteClient("client2") err = s.DeleteClient("cl_client2")
require.NoError(t, err) require.NoError(t, err)
clients, err = s.ReadClients() clients, err = s.ReadClients()
@@ -424,10 +402,7 @@ func TestWriteClientFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.WriteClient(persistence.Client{}) err = s.WriteClient(persistence.Client{})
require.Error(t, err) require.Error(t, err)
} }
@@ -442,10 +417,7 @@ func TestReadClientFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
_, err = s.ReadClients() _, err = s.ReadClients()
require.Error(t, err) require.Error(t, err)
} }
@@ -460,10 +432,7 @@ func TestDeleteSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.DeleteSubscription("a") err = s.DeleteSubscription("a")
require.Error(t, err) require.Error(t, err)
} }
@@ -478,10 +447,7 @@ func TestDeleteClientFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.DeleteClient("a") err = s.DeleteClient("a")
require.Error(t, err) require.Error(t, err)
} }
@@ -496,10 +462,7 @@ func TestDeleteInflightFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.DeleteInflight("a") err = s.DeleteInflight("a")
require.Error(t, err) require.Error(t, err)
} }
@@ -514,10 +477,7 @@ func TestDeleteRetainedFail(t *testing.T) {
s := New(tmpPath, nil) s := New(tmpPath, nil)
err := s.Open() err := s.Open()
require.NoError(t, err) require.NoError(t, err)
s.Close()
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.DeleteRetained("a") err = s.DeleteRetained("a")
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -75,13 +75,12 @@ type FixedHeader struct {
// Client contains client data that can be persistently stored. // Client contains client data that can be persistently stored.
type Client struct { type Client struct {
ID string // the id of the client ID string // the storage key.
T string // the type of the stored data. ClientID string // the id of the client.
Listener string // the last known listener id for the client T string // the type of the stored data.
Username []byte // the username the client authenticated with. Listener string // the last known listener id for the client
CleanSession bool // indicates if the client connected expecting a clean-session. Username []byte // the username the client authenticated with.
Subscriptions map[string]byte // a list of the subscriptions the user has (qos keyed on filter). LWT LWT // the last-will-and-testament message for the client.
LWT LWT // the last-will-and-testament message for the client.
} }
// LWT contains details about a clients LWT payload. // LWT contains details about a clients LWT payload.
@@ -216,14 +215,10 @@ func (s *MockStore) ReadClients() (v []Client, err error) {
return []Client{ return []Client{
Client{ Client{
ID: "client1", ID: "cl_client1",
T: KClient, ClientID: "client1",
Listener: "tcp1", T: KClient,
CleanSession: true, Listener: "tcp1",
Subscriptions: map[string]byte{
"a/b/c": 0,
"d/e/f": 1,
},
}, },
}, nil }, nil
} }

View File

@@ -22,6 +22,7 @@ const (
Version = "0.1.0" // the server version. Version = "0.1.0" // the server version.
maxPacketID = 65535 // the maximum value of a 16-bit packet ID. maxPacketID = 65535 // the maximum value of a 16-bit packet ID.
) )
var ( var (
@@ -31,6 +32,9 @@ var (
ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics") ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics")
SysTopicInterval time.Duration = 30000 // the default number of milliseconds between $SYS topic publishes. SysTopicInterval time.Duration = 30000 // the default number of milliseconds between $SYS topic publishes.
inflightResendBackoff = []int64{0, 1, 2, 10, 60, 120, 600, 3600, 21600} // <1 second to 6 hours
inflightMaxResends = 6 // maximum number of times to retry sending QoS packets.
) )
// Server is an MQTT broker server. // Server is an MQTT broker server.
@@ -113,100 +117,6 @@ func (s *Server) Serve() error {
return nil return nil
} }
// readStore reads in any data from the persistent datastore (if applicable).
func (s *Server) readStore() error {
info, err := s.Store.ReadServerInfo()
if err != nil {
return err
}
s.loadServerInfo(info)
subs, err := s.Store.ReadSubscriptions()
if err != nil {
return err
}
s.loadSubscriptions(subs)
clients, err := s.Store.ReadClients()
if err != nil {
return err
}
s.loadClients(clients)
inflight, err := s.Store.ReadInflight()
if err != nil {
return err
}
s.loadInflight(inflight)
retained, err := s.Store.ReadRetained()
if err != nil {
return err
}
s.loadRetained(retained)
return nil
}
// loadServerInfo restores server info from the datastore.
func (s *Server) loadServerInfo(v persistence.ServerInfo) {
version := s.System.Version
s.System = &v.Info
s.System.Version = version
}
// loadSubscriptions restores subscriptions from the datastore.
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
for _, sub := range v {
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
}
}
// loadClients restores clients from the datastore.
func (s *Server) loadClients(v []persistence.Client) {
for _, cl := range v {
c := clients.NewClientStub(s.System)
c.ID = cl.ID
c.Listener = cl.Listener
c.Username = cl.Username
c.LWT = clients.LWT(cl.LWT)
c.Subscriptions = cl.Subscriptions
s.Clients.Add(c)
}
}
// loadInflight restores inflight messages from the datastore.
func (s *Server) loadInflight(v []persistence.Message) {
for _, msg := range v {
if client, ok := s.Clients.Get(msg.Client); ok {
client.Inflight.Set(msg.PacketID, clients.InflightMessage{
Packet: packets.Packet{
FixedHeader: packets.FixedHeader(msg.FixedHeader),
PacketID: msg.PacketID,
TopicName: msg.TopicName,
Payload: msg.Payload,
},
Sent: msg.Sent,
Resends: msg.Resends,
})
}
}
}
// loadRetained restores retained messages from the datastore.
func (s *Server) loadRetained(v []persistence.Message) {
for _, msg := range v {
s.Topics.RetainMessage(
packets.Packet{
FixedHeader: packets.FixedHeader(msg.FixedHeader),
TopicName: msg.TopicName,
Payload: msg.Payload,
},
)
}
}
// eventLoop runs server processes at intervals. // eventLoop runs server processes at intervals.
func (s *Server) eventLoop() { func (s *Server) eventLoop() {
for { for {
@@ -264,6 +174,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
atomic.AddInt64(&s.System.ClientsDisconnected, -1) atomic.AddInt64(&s.System.ClientsDisconnected, -1)
} }
existing.Stop() existing.Stop()
fmt.Printf("%+v\n", existing.Subscriptions)
if pk.CleanSession { if pk.CleanSession {
for k := range existing.Subscriptions { for k := range existing.Subscriptions {
delete(existing.Subscriptions, k) delete(existing.Subscriptions, k)
@@ -298,7 +209,18 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
return err return err
} }
cl.ResendInflight(true) s.ResendClientInflight(cl, true)
if s.Store != nil {
s.Store.WriteClient(persistence.Client{
ID: "cl_" + cl.ID,
ClientID: cl.ID,
T: persistence.KClient,
Listener: cl.Listener,
Username: cl.Username,
LWT: persistence.LWT(cl.LWT),
})
}
err = cl.Read(s.processPacket) err = cl.Read(s.processPacket)
if err != nil { if err != nil {
@@ -404,8 +326,22 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
} }
if pk.FixedHeader.Retain { if pk.FixedHeader.Retain {
q := s.Topics.RetainMessage(pk.PublishCopy()) out := pk.PublishCopy()
q := s.Topics.RetainMessage(out)
atomic.AddInt64(&s.System.Retained, q) atomic.AddInt64(&s.System.Retained, q)
if s.Store != nil {
if q == 1 {
s.Store.WriteRetained(persistence.Message{
ID: "ret_" + out.TopicName,
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
})
} else {
s.Store.DeleteRetained("ret_" + out.TopicName)
}
}
} }
if pk.FixedHeader.Qos > 0 { if pk.FixedHeader.Qos > 0 {
@@ -450,13 +386,25 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
// the client at some point, one way or another. Store the publish // the client at some point, one way or another. Store the publish
// packet in the client's inflight queue and attempt to redeliver // packet in the client's inflight queue and attempt to redeliver
// if an appropriate ack is not received (or if the client is offline). // if an appropriate ack is not received (or if the client is offline).
sent := time.Now().Unix()
q := client.Inflight.Set(out.PacketID, clients.InflightMessage{ q := client.Inflight.Set(out.PacketID, clients.InflightMessage{
Packet: out, Packet: out,
Sent: time.Now().Unix(), Sent: sent,
}) })
if q { if q {
atomic.AddInt64(&s.System.Inflight, 1) atomic.AddInt64(&s.System.Inflight, 1)
} }
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + client.ID + "_" + strconv.Itoa(int(out.PacketID)),
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
Sent: sent,
})
}
} }
s.writeClient(client, out) s.writeClient(client, out)
@@ -470,6 +418,9 @@ func (s *Server) processPuback(cl *clients.Client, pk packets.Packet) error {
if q { if q {
atomic.AddInt64(&s.System.Inflight, -1) atomic.AddInt64(&s.System.Inflight, -1)
} }
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
}
return nil return nil
} }
@@ -509,6 +460,10 @@ func (s *Server) processPubrel(cl *clients.Client, pk packets.Packet) error {
atomic.AddInt64(&s.System.Inflight, -1) atomic.AddInt64(&s.System.Inflight, -1)
} }
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
}
return nil return nil
} }
@@ -518,6 +473,9 @@ func (s *Server) processPubcomp(cl *clients.Client, pk packets.Packet) error {
if q { if q {
atomic.AddInt64(&s.System.Inflight, -1) atomic.AddInt64(&s.System.Inflight, -1)
} }
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
}
return nil return nil
} }
@@ -534,6 +492,16 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
} }
cl.NoteSubscription(pk.Topics[i], pk.Qoss[i]) cl.NoteSubscription(pk.Topics[i], pk.Qoss[i])
retCodes[i] = pk.Qoss[i] retCodes[i] = pk.Qoss[i]
if s.Store != nil {
s.Store.WriteSubscription(persistence.Subscription{
ID: "sub_" + cl.ID + ":" + pk.Topics[i],
T: persistence.KSubscription,
Filter: pk.Topics[i],
Client: cl.ID,
QoS: pk.Qoss[i],
})
}
} }
} }
@@ -585,8 +553,6 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
// Due to the int to string conversions this method is not as cheap as // Due to the int to string conversions this method is not as cheap as
// some of the others so the publishing interval should be set appropriately. // some of the others so the publishing interval should be set appropriately.
func (s *Server) publishSysTopics() { func (s *Server) publishSysTopics() {
s.System.Uptime = time.Now().Unix() - s.System.Started
pk := packets.Packet{ pk := packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,
@@ -594,6 +560,7 @@ func (s *Server) publishSysTopics() {
}, },
} }
s.System.Uptime = time.Now().Unix() - s.System.Started
topics := map[string]string{ topics := map[string]string{
"$SYS/broker/version": s.System.Version, "$SYS/broker/version": s.System.Version,
"$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)), "$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)),
@@ -629,7 +596,63 @@ func (s *Server) publishSysTopics() {
persistence.KServerInfo, persistence.KServerInfo,
}) })
} }
}
// ResendClientInflight attempts to resend all undelivered inflight messages
// to a client.
func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
if cl.Inflight.Len() == 0 {
return nil
}
nt := time.Now().Unix()
for _, tk := range cl.Inflight.GetAll() {
// After a reasonable time, drop inflight packets.
if tk.Resends >= inflightMaxResends {
cl.Inflight.Delete(tk.Packet.PacketID)
if tk.Packet.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&s.System.PublishDropped, 1)
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)))
}
continue
}
// Only continue if the resend backoff time has passed and there's a backoff time.
if !force && (nt-tk.Sent < inflightResendBackoff[tk.Resends] || len(inflightResendBackoff) < tk.Resends) {
continue
}
if tk.Packet.FixedHeader.Type == packets.Publish {
tk.Packet.FixedHeader.Dup = true
}
tk.Resends++
tk.Sent = nt
cl.Inflight.Set(tk.Packet.PacketID, tk)
_, err := cl.WritePacket(tk.Packet)
if err != nil {
return err
}
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)),
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader(tk.Packet.FixedHeader),
TopicName: tk.Packet.TopicName,
Payload: tk.Packet.Payload,
Sent: tk.Sent,
Resends: tk.Resends,
})
}
}
return nil
} }
// Close attempts to gracefully shutdown the server, all listeners, clients, and stores. // Close attempts to gracefully shutdown the server, all listeners, clients, and stores.
@@ -665,10 +688,104 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
TopicName: cl.LWT.Topic, TopicName: cl.LWT.Topic,
Payload: cl.LWT.Message, Payload: cl.LWT.Message,
}) })
// omit errors, since we're not logging and need to close the client in either case.
} }
cl.Stop() cl.Stop()
return nil return nil
} }
// readStore reads in any data from the persistent datastore (if applicable).
func (s *Server) readStore() error {
info, err := s.Store.ReadServerInfo()
if err != nil {
return fmt.Errorf("load server info; %w", err)
}
s.loadServerInfo(info)
clients, err := s.Store.ReadClients()
if err != nil {
return fmt.Errorf("load clients; %w", err)
}
fmt.Println("loading clients", clients)
s.loadClients(clients)
subs, err := s.Store.ReadSubscriptions()
if err != nil {
return fmt.Errorf("load subscriptions; %w", err)
}
s.loadSubscriptions(subs)
inflight, err := s.Store.ReadInflight()
if err != nil {
return fmt.Errorf("load inflight; %w", err)
}
s.loadInflight(inflight)
retained, err := s.Store.ReadRetained()
if err != nil {
return fmt.Errorf("load retained; %w", err)
}
s.loadRetained(retained)
return nil
}
// loadServerInfo restores server info from the datastore.
func (s *Server) loadServerInfo(v persistence.ServerInfo) {
version := s.System.Version
s.System = &v.Info
s.System.Version = version
}
// loadSubscriptions restores subscriptions from the datastore.
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
for _, sub := range v {
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
}
}
}
// loadClients restores clients from the datastore.
func (s *Server) loadClients(v []persistence.Client) {
for _, c := range v {
cl := clients.NewClientStub(s.System)
cl.ID = c.ClientID
cl.Listener = c.Listener
cl.Username = c.Username
cl.LWT = clients.LWT(c.LWT)
s.Clients.Add(cl)
}
}
// loadInflight restores inflight messages from the datastore.
func (s *Server) loadInflight(v []persistence.Message) {
for _, msg := range v {
if client, ok := s.Clients.Get(msg.Client); ok {
client.Inflight.Set(msg.PacketID, clients.InflightMessage{
Packet: packets.Packet{
FixedHeader: packets.FixedHeader(msg.FixedHeader),
PacketID: msg.PacketID,
TopicName: msg.TopicName,
Payload: msg.Payload,
},
Sent: msg.Sent,
Resends: msg.Resends,
})
}
}
}
// loadRetained restores retained messages from the datastore.
func (s *Server) loadRetained(v []persistence.Message) {
for _, msg := range v {
s.Topics.RetainMessage(packets.Packet{
FixedHeader: packets.FixedHeader(msg.FixedHeader),
TopicName: msg.TopicName,
Payload: msg.Payload,
})
}
}

View File

@@ -1,7 +1,7 @@
package server package server
import ( import (
"errors" //"errors"
"io/ioutil" "io/ioutil"
"net" "net"
"strconv" "strconv"
@@ -25,6 +25,7 @@ const defaultPort = ":18882"
func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) { func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
s = New() s = New()
s.Store = new(persistence.MockStore)
r, w = net.Pipe() r, w = net.Pipe()
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System) cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
cl.ID = "mochi" cl.ID = "mochi"
@@ -120,234 +121,6 @@ func BenchmarkServerAddListener(b *testing.B) {
} }
} }
func TestServerReadStore(t *testing.T) {
s := New()
require.NotNil(t, s)
s.Store = new(persistence.MockStore)
err := s.readStore()
require.NoError(t, err)
require.Equal(t, int64(100), s.System.Started)
require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c"))
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
msg, ok := cl1.Inflight.Get(100)
require.Equal(t, true, ok)
require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload)
}
func TestServerReadStoreFailures(t *testing.T) {
s := New()
require.NotNil(t, s)
s.Store = new(persistence.MockStore)
s.Store.(*persistence.MockStore).Fail = map[string]bool{
"read_subs": true,
"read_clients": true,
"read_inflight": true,
"read_retained": true,
"read_info": true,
}
err := s.readStore()
require.Error(t, err)
require.Equal(t, errors.New("test_info"), err)
delete(s.Store.(*persistence.MockStore).Fail, "read_info")
err = s.readStore()
require.Error(t, err)
require.Equal(t, errors.New("test_subs"), err)
delete(s.Store.(*persistence.MockStore).Fail, "read_subs")
err = s.readStore()
require.Error(t, err)
require.Equal(t, errors.New("test_clients"), err)
delete(s.Store.(*persistence.MockStore).Fail, "read_clients")
err = s.readStore()
require.Error(t, err)
require.Equal(t, errors.New("test_inflight"), err)
delete(s.Store.(*persistence.MockStore).Fail, "read_inflight")
err = s.readStore()
require.Error(t, err)
require.Equal(t, errors.New("test_retained"), err)
delete(s.Store.(*persistence.MockStore).Fail, "read_retained")
}
func TestServerLoadServerInfo(t *testing.T) {
s := New()
require.NotNil(t, s)
s.System.Version = "original"
s.loadServerInfo(persistence.ServerInfo{
system.Info{
Version: "test",
Started: 100,
}, persistence.KServerInfo,
})
require.Equal(t, "original", s.System.Version)
require.Equal(t, int64(100), s.System.Started)
}
func TestServerLoadSubscriptions(t *testing.T) {
s := New()
require.NotNil(t, s)
subs := []persistence.Subscription{
persistence.Subscription{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
},
persistence.Subscription{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
QoS: 0,
T: persistence.KSubscription,
},
}
s.loadSubscriptions(subs)
require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c"))
require.Equal(t, topics.Subscriptions{"test": 0}, s.Topics.Subscribers("d/e/f"))
}
func TestServerLoadClients(t *testing.T) {
s := New()
require.NotNil(t, s)
clients := []persistence.Client{
persistence.Client{
ID: "client1",
T: persistence.KClient,
Listener: "tcp1",
CleanSession: true,
Subscriptions: map[string]byte{
"a/b/c": 0,
"d/e/f": 1,
},
},
persistence.Client{
ID: "client2",
T: persistence.KClient,
Listener: "tcp1",
Subscriptions: map[string]byte{
"q/w/e": 2,
},
},
}
s.loadClients(clients)
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
require.NotNil(t, cl1)
cl2, ok2 := s.Clients.Get("client2")
require.Equal(t, true, ok2)
require.NotNil(t, cl2)
}
func TestServerLoadInflight(t *testing.T) {
s := New()
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
ID: "client1_if_0",
T: persistence.KInflight,
Client: "client1",
PacketID: 0,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
},
persistence.Message{
ID: "client1_if_100",
T: persistence.KInflight,
Client: "client1",
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
},
}
w, _ := net.Pipe()
defer w.Close()
c1 := clients.NewClient(w, nil, nil, nil)
c1.ID = "client1"
s.Clients.Add(c1)
s.loadInflight(msgs)
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
require.Equal(t, "client1", cl1.ID)
msg, ok := cl1.Inflight.Get(100)
require.Equal(t, true, ok)
require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload)
}
func TestServerLoadRetained(t *testing.T) {
s := New()
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 200,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
},
persistence.Message{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
},
}
s.loadRetained(msgs)
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
require.Equal(t, 1, len(s.Topics.Messages("d/e/f")))
msg := s.Topics.Messages("a/b/c")
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, msg[0].Payload)
msg = s.Topics.Messages("d/e/f")
require.Equal(t, []byte{'y', 'e', 's'}, msg[0].Payload)
}
func TestServerServe(t *testing.T) { func TestServerServe(t *testing.T) {
s := New() s := New()
require.NotNil(t, s) require.NotNil(t, s)
@@ -877,6 +650,35 @@ func TestServerProcessPublishQoS2(t *testing.T) {
require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained)) require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained))
} }
func TestServerProcessPublishUnretain(t *testing.T) {
s, cl1, r1, w1 := setupClient()
s.Clients.Add(cl1)
ack1 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r1)
if err != nil {
panic(err)
}
ack1 <- buf
}()
err := s.processPacket(cl1, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: true,
},
TopicName: "a/b/c",
Payload: []byte{},
})
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
w1.Close()
require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained))
}
func TestServerProcessPublishOfflineQueuing(t *testing.T) { func TestServerProcessPublishOfflineQueuing(t *testing.T) {
s, cl1, r1, w1 := setupClient() s, cl1, r1, w1 := setupClient()
cl1.ID = "mochi1" cl1.ID = "mochi1"
@@ -1451,3 +1253,395 @@ func TestServerCloseClientClosed(t *testing.T) {
err := s.closeClient(cl, true) err := s.closeClient(cl, true)
require.NoError(t, err) require.NoError(t, err)
} }
func TestServerReadStore(t *testing.T) {
s := New()
require.NotNil(t, s)
s.Store = new(persistence.MockStore)
err := s.readStore()
require.NoError(t, err)
require.Equal(t, int64(100), s.System.Started)
require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c"))
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
msg, ok := cl1.Inflight.Get(100)
require.Equal(t, true, ok)
require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload)
}
func TestServerReadStoreFailures(t *testing.T) {
s := New()
require.NotNil(t, s)
s.Store = new(persistence.MockStore)
s.Store.(*persistence.MockStore).Fail = map[string]bool{
"read_subs": true,
"read_clients": true,
"read_inflight": true,
"read_retained": true,
"read_info": true,
}
err := s.readStore()
require.Error(t, err)
delete(s.Store.(*persistence.MockStore).Fail, "read_info")
err = s.readStore()
require.Error(t, err)
delete(s.Store.(*persistence.MockStore).Fail, "read_subs")
err = s.readStore()
require.Error(t, err)
delete(s.Store.(*persistence.MockStore).Fail, "read_clients")
err = s.readStore()
require.Error(t, err)
delete(s.Store.(*persistence.MockStore).Fail, "read_inflight")
err = s.readStore()
require.Error(t, err)
delete(s.Store.(*persistence.MockStore).Fail, "read_retained")
}
func TestServerLoadServerInfo(t *testing.T) {
s := New()
require.NotNil(t, s)
s.System.Version = "original"
s.loadServerInfo(persistence.ServerInfo{
system.Info{
Version: "test",
Started: 100,
}, persistence.KServerInfo,
})
require.Equal(t, "original", s.System.Version)
require.Equal(t, int64(100), s.System.Started)
}
func TestServerLoadSubscriptions(t *testing.T) {
s := New()
require.NotNil(t, s)
cl := clients.NewClientStub(s.System)
cl.ID = "test"
s.Clients.Add(cl)
subs := []persistence.Subscription{
persistence.Subscription{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
},
persistence.Subscription{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
QoS: 0,
T: persistence.KSubscription,
},
}
s.loadSubscriptions(subs)
require.Equal(t, topics.Subscriptions{"test": 1}, s.Topics.Subscribers("a/b/c"))
require.Equal(t, topics.Subscriptions{"test": 0}, s.Topics.Subscribers("d/e/f"))
}
func TestServerLoadClients(t *testing.T) {
s := New()
require.NotNil(t, s)
clients := []persistence.Client{
persistence.Client{
ID: "cl_client1",
ClientID: "client1",
T: persistence.KClient,
Listener: "tcp1",
},
persistence.Client{
ID: "cl_client2",
ClientID: "client2",
T: persistence.KClient,
Listener: "tcp1",
},
}
s.loadClients(clients)
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
require.NotNil(t, cl1)
cl2, ok2 := s.Clients.Get("client2")
require.Equal(t, true, ok2)
require.NotNil(t, cl2)
}
func TestServerLoadInflight(t *testing.T) {
s := New()
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
ID: "client1_if_0",
T: persistence.KInflight,
Client: "client1",
PacketID: 0,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
},
persistence.Message{
ID: "client1_if_100",
T: persistence.KInflight,
Client: "client1",
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
},
}
w, _ := net.Pipe()
defer w.Close()
c1 := clients.NewClient(w, nil, nil, nil)
c1.ID = "client1"
s.Clients.Add(c1)
s.loadInflight(msgs)
cl1, ok := s.Clients.Get("client1")
require.Equal(t, true, ok)
require.Equal(t, "client1", cl1.ID)
msg, ok := cl1.Inflight.Get(100)
require.Equal(t, true, ok)
require.Equal(t, []byte{'y', 'e', 's'}, msg.Packet.Payload)
}
func TestServerLoadRetained(t *testing.T) {
s := New()
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 200,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
},
persistence.Message{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
},
}
s.loadRetained(msgs)
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
require.Equal(t, 1, len(s.Topics.Messages("d/e/f")))
msg := s.Topics.Messages("a/b/c")
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, msg[0].Payload)
msg = s.Topics.Messages("d/e/f")
require.Equal(t, []byte{'y', 'e', 's'}, msg[0].Payload)
}
func TestServerResendClientInflight(t *testing.T) {
s := New()
s.Store = new(persistence.MockStore)
require.NotNil(t, s)
r, w := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Start()
s.Clients.Add(cl)
o := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
require.NoError(t, err)
o <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
})
err := s.ResendClientInflight(cl, true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
r.Close()
rcv := <-o
require.Equal(t, []byte{
byte(packets.Publish<<4 | 1<<1 | 1<<3), 14,
0, 5,
'a', '/', 'b', '/', 'c',
0, 11,
'h', 'e', 'l', 'l', 'o',
}, rcv)
m := cl.Inflight.GetAll()
require.Equal(t, 1, m[11].Resends) // index is packet id
}
func TestServerResendClientInflightBackoff(t *testing.T) {
s := New()
s.Store = new(persistence.MockStore)
require.NotNil(t, s)
r, w := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Start()
s.Clients.Add(cl)
o := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
require.NoError(t, err)
o <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
Resends: 0,
})
err := s.ResendClientInflight(cl, true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
// Attempt to send twice, but backoff should kick in stopping second resend.
err = s.ResendClientInflight(cl, false)
require.NoError(t, err)
r.Close()
rcv := <-o
require.Equal(t, []byte{
byte(packets.Publish<<4 | 1<<1 | 1<<3), 14,
0, 5,
'a', '/', 'b', '/', 'c',
0, 11,
'h', 'e', 'l', 'l', 'o',
}, rcv)
m := cl.Inflight.GetAll()
require.Equal(t, 1, m[11].Resends) // index is packet id
}
func TestServerResendClientInflightNoMessages(t *testing.T) {
s := New()
s.Store = new(persistence.MockStore)
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
out := []packets.Packet{}
err := s.ResendClientInflight(cl, true)
require.NoError(t, err)
require.Equal(t, 0, len(out))
r.Close()
}
func TestServerResendClientInflightDropMessage(t *testing.T) {
s := New()
s.Store = new(persistence.MockStore)
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 11,
}
cl.Inflight.Set(pk1.PacketID, clients.InflightMessage{
Packet: pk1,
Sent: time.Now().Unix(),
Resends: inflightMaxResends,
})
err := s.ResendClientInflight(cl, true)
require.NoError(t, err)
r.Close()
m := cl.Inflight.GetAll()
require.Equal(t, 0, len(m))
require.Equal(t, int64(1), atomic.LoadInt64(&s.System.PublishDropped))
}
func TestServerResendClientInflightError(t *testing.T) {
s := New()
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, clients.InflightMessage{
Packet: packets.Packet{},
Sent: time.Now().Unix(),
})
r.Close()
err := s.ResendClientInflight(cl, true)
require.Error(t, err)
}