diff --git a/server/persistence/bolt/bolt.go b/server/persistence/bolt/bolt.go index 024f351..57a26e8 100644 --- a/server/persistence/bolt/bolt.go +++ b/server/persistence/bolt/bolt.go @@ -2,6 +2,7 @@ package bolt import ( //"encoding/gob" + "errors" "time" sgob "github.com/asdine/storm/codec/gob" @@ -64,6 +65,10 @@ func (s *Store) Close() { // WriteServerInfo writes the server info to the boltdb instance. func (s *Store) WriteServerInfo(v persistence.ServerInfo) error { + if s.db == nil { + return errors.New("boltdb not opened") + } + err := s.db.Save(&v) if err != nil { return err @@ -73,6 +78,10 @@ func (s *Store) WriteServerInfo(v persistence.ServerInfo) error { // WriteSubscription writes a single subscription to the boltdb instance. func (s *Store) WriteSubscription(v persistence.Subscription) error { + if s.db == nil { + return errors.New("boltdb not opened") + } + err := s.db.Save(&v) if err != nil { return err @@ -82,6 +91,10 @@ func (s *Store) WriteSubscription(v persistence.Subscription) error { // WriteInflight writes a single inflight message to the boltdb instance. func (s *Store) WriteInflight(v persistence.Message) error { + if s.db == nil { + return errors.New("boltdb not opened") + } + err := s.db.Save(&v) if err != nil { return err @@ -91,6 +104,10 @@ func (s *Store) WriteInflight(v persistence.Message) error { // WriteRetained writes a single retained message to the boltdb instance. func (s *Store) WriteRetained(v persistence.Message) error { + if s.db == nil { + return errors.New("boltdb not opened") + } + err := s.db.Save(&v) if err != nil { return err @@ -99,12 +116,25 @@ func (s *Store) WriteRetained(v persistence.Message) error { } // WriteClient writes a single client to the boltdb instance. -func (s *Store) WriteClient() { +func (s *Store) WriteClient(v persistence.Client) error { + if s.db == nil { + return errors.New("boltdb not opened") + } + err := s.db.Save(&v) + if err != nil { + return err + } + return nil } // ReadSubscriptions loads all the subscriptions from the boltdb instance. func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) { + if s.db == nil { + err = errors.New("boltdb not opened") + return + } + err = s.db.Find("T", persistence.KSubscription, &v) if err != nil { return @@ -113,12 +143,26 @@ func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) { } // ReadClients loads all the clients from the boltdb instance. -func (s *Store) ReadClients() { +func (s *Store) ReadClients() (v []persistence.Client, err error) { + if s.db == nil { + err = errors.New("boltdb not opened") + return + } + err = s.db.Find("T", persistence.KClient, &v) + if err != nil { + return + } + return } // ReadInflight loads all the inflight messages from the boltdb instance. func (s *Store) ReadInflight() (v []persistence.Message, err error) { + if s.db == nil { + err = errors.New("boltdb not opened") + return + } + err = s.db.Find("T", persistence.KInflight, &v) if err != nil { return @@ -128,6 +172,11 @@ func (s *Store) ReadInflight() (v []persistence.Message, err error) { // ReadRetained loads all the retained messages from the boltdb instance. func (s *Store) ReadRetained() (v []persistence.Message, err error) { + if s.db == nil { + err = errors.New("boltdb not opened") + return + } + err = s.db.Find("T", persistence.KRetained, &v) if err != nil { return @@ -137,6 +186,11 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) { //ReadServerInfo loads the server info from the boltdb instance. func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) { + if s.db == nil { + err = errors.New("boltdb not opened") + return + } + err = s.db.One("ID", persistence.KServerInfo, &v) if err != nil { return diff --git a/server/persistence/bolt/bolt_test.go b/server/persistence/bolt/bolt_test.go index 8683edc..696eddc 100644 --- a/server/persistence/bolt/bolt_test.go +++ b/server/persistence/bolt/bolt_test.go @@ -14,11 +14,20 @@ import ( const tmpPath = "testbolt.db" -func teardown(t *testing.T) { +func teardown(s *Store, t *testing.T) { + s.Close() err := os.Remove(tmpPath) require.NoError(t, err) } +func TestSatsifies(t *testing.T) { + var x persistence.Store + x = New(tmpPath, &bbolt.Options{ + Timeout: 500 * time.Millisecond, + }) + require.NotNil(t, x) +} + func TestNew(t *testing.T) { s := New(tmpPath, &bbolt.Options{ Timeout: 500 * time.Millisecond, @@ -44,15 +53,21 @@ func TestOpen(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - defer teardown(t) + defer teardown(s, t) require.NotNil(t, s.db) } -func TestStoreAndRetrieveServerInfo(t *testing.T) { +func TestOpenFailure(t *testing.T) { + s := New("..", nil) + err := s.Open() + require.Error(t, err) +} + +func TestWriteAndRetrieveServerInfo(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - defer teardown(t) + defer teardown(s, t) v := system.Info{ Version: "test", @@ -68,11 +83,47 @@ func TestStoreAndRetrieveServerInfo(t *testing.T) { require.Equal(t, v.Started, r.Started) } +func TestWriteServerInfoNoDB(t *testing.T) { + s := New(tmpPath, nil) + err := s.WriteServerInfo(persistence.ServerInfo{}) + require.Error(t, err) +} + +func TestWriteServerInfoFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + err = s.WriteServerInfo(persistence.ServerInfo{}) + require.Error(t, err) +} + +func TestReadServerInfoNoDB(t *testing.T) { + s := New(tmpPath, nil) + _, err := s.ReadServerInfo() + require.Error(t, err) +} + +func TestReadServerInfoFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + _, err = s.ReadServerInfo() + require.Error(t, err) +} + func TestWriteAndRetrieveSubscription(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - defer teardown(t) + defer teardown(s, t) v := persistence.Subscription{ ID: "test:a/b/c", @@ -100,11 +151,47 @@ func TestWriteAndRetrieveSubscription(t *testing.T) { require.Equal(t, 2, len(subs)) } +func TestWriteSubscriptionNoDB(t *testing.T) { + s := New(tmpPath, nil) + err := s.WriteSubscription(persistence.Subscription{}) + require.Error(t, err) +} + +func TestWriteSubscriptionFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + err = s.WriteSubscription(persistence.Subscription{}) + require.Error(t, err) +} + +func TestReadSubscriptionNoDB(t *testing.T) { + s := New(tmpPath, nil) + _, err := s.ReadSubscriptions() + require.Error(t, err) +} + +func TestReadSubscriptionFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + _, err = s.ReadSubscriptions() + require.Error(t, err) +} + func TestWriteAndRetrieveInflight(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - defer teardown(t) + defer teardown(s, t) v := persistence.Message{ ID: "client1_if_0", @@ -136,11 +223,47 @@ func TestWriteAndRetrieveInflight(t *testing.T) { require.Equal(t, 2, len(msgs)) } -func TestWriteAndRetrievePersistent(t *testing.T) { +func TestWriteInflightNoDB(t *testing.T) { + s := New(tmpPath, nil) + err := s.WriteInflight(persistence.Message{}) + require.Error(t, err) +} + +func TestWriteInflightFail(t *testing.T) { s := New(tmpPath, nil) err := s.Open() require.NoError(t, err) - defer teardown(t) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + err = s.WriteInflight(persistence.Message{}) + require.Error(t, err) +} + +func TestReadInflightNoDB(t *testing.T) { + s := New(tmpPath, nil) + _, err := s.ReadInflight() + require.Error(t, err) +} + +func TestReadInflightFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + _, err = s.ReadInflight() + require.Error(t, err) +} + +func TestWriteAndRetrieveRetained(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + defer teardown(s, t) v := persistence.Message{ ID: "client1_ret_200", @@ -154,7 +277,7 @@ func TestWriteAndRetrievePersistent(t *testing.T) { Sent: 100, Resends: 0, } - err = s.WriteInflight(v) + err = s.WriteRetained(v) require.NoError(t, err) v2 := persistence.Message{ @@ -169,7 +292,7 @@ func TestWriteAndRetrievePersistent(t *testing.T) { Sent: 200, Resends: 1, } - err = s.WriteInflight(v2) + err = s.WriteRetained(v2) require.NoError(t, err) msgs, err := s.ReadRetained() @@ -178,3 +301,124 @@ func TestWriteAndRetrievePersistent(t *testing.T) { require.Equal(t, true, msgs[0].FixedHeader.Retain) require.Equal(t, 2, len(msgs)) } + +func TestWriteRetainedNoDB(t *testing.T) { + s := New(tmpPath, nil) + err := s.WriteRetained(persistence.Message{}) + require.Error(t, err) +} + +func TestWriteRetainedFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + err = s.WriteRetained(persistence.Message{}) + require.Error(t, err) +} + +func TestReadRetainedNoDB(t *testing.T) { + s := New(tmpPath, nil) + _, err := s.ReadRetained() + require.Error(t, err) +} + +func TestReadRetainedFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + _, err = s.ReadRetained() + require.Error(t, err) +} + +func TestWriteAndRetrieveClients(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + defer teardown(s, t) + + v := persistence.Client{ + ID: "client1", + T: persistence.KClient, + Listener: "tcp1", + Username: []byte{'m', 'o', 'c', 'h', 'i'}, + CleanSession: true, + Subscriptions: []persistence.Subscription{ + persistence.Subscription{ + Filter: "a/b/c", + QoS: 1, + }, + }, + LWT: persistence.LWT{ + Topic: "a/b/c", + Message: []byte{'h', 'e', 'l', 'l', 'o'}, + Qos: 1, + Retain: true, + }, + } + err = s.WriteClient(v) + require.NoError(t, err) + + clients, err := s.ReadClients() + require.NoError(t, err) + + require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, clients[0].Username) + require.Equal(t, "a/b/c", clients[0].LWT.Topic) + require.Equal(t, "a/b/c", clients[0].Subscriptions[0].Filter) + + v2 := persistence.Client{ + ID: "client2", + T: persistence.KClient, + Listener: "tcp1", + } + err = s.WriteClient(v2) + require.NoError(t, err) + + clients, err = s.ReadClients() + require.NoError(t, err) + require.Equal(t, persistence.KClient, clients[0].T) + require.Equal(t, 2, len(clients)) +} + +func TestWriteClientNoDB(t *testing.T) { + s := New(tmpPath, nil) + err := s.WriteClient(persistence.Client{}) + require.Error(t, err) +} + +func TestWriteClientFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + err = s.WriteClient(persistence.Client{}) + require.Error(t, err) +} + +func TestReadClientNoDB(t *testing.T) { + s := New(tmpPath, nil) + _, err := s.ReadClients() + require.Error(t, err) +} + +func TestReadClientFail(t *testing.T) { + s := New(tmpPath, nil) + err := s.Open() + require.NoError(t, err) + + err = os.Remove(tmpPath) + require.NoError(t, err) + + _, err = s.ReadClients() + require.Error(t, err) +} diff --git a/server/persistence/persistence.go b/server/persistence/persistence.go index e99279e..b59673f 100644 --- a/server/persistence/persistence.go +++ b/server/persistence/persistence.go @@ -19,18 +19,18 @@ type Store interface { Open() error Close() - WriteSubscription() // including retained - WriteClient() - WriteInflight() - WriteServerInfo() - WriteRetained() + WriteSubscription(v Subscription) error + WriteClient(v Client) error + WriteInflight(v Message) error + WriteServerInfo(v ServerInfo) error + WriteRetained(v Message) error ReadSubscriptions() (v []Subscription, err error) - ReadInflight() - ReadRetained() - ReadClients() + ReadInflight() (v []Message, err error) + ReadRetained() (v []Message, err error) + ReadClients() (v []Client, err error) - ReadServerInfo() + ReadServerInfo() (v ServerInfo, err error) } // ServerInfo contains information and statistics about the server. @@ -69,10 +69,10 @@ type FixedHeader struct { Remaining int // the number of remaining bytes in the payload. } -/* // Client contains client data that can be persistently stored. type Client struct { ID string // the id of the client + T string // the type of the stored data. Listener string // the last known listener id for the client Username []byte // the username the client authenticated with. CleanSession bool // indicates if the client connected expecting a clean-session. @@ -88,8 +88,6 @@ type LWT struct { Retain bool // indicates whether the will message should be retained } -*/ - /* * * @@ -119,16 +117,29 @@ func (s *MockStore) Close() { } // WriteSubscription writes a single subscription to the storage instance. -func (s *MockStore) WriteSubscription() {} +func (s *MockStore) WriteSubscription(v Subscription) error { + return nil +} // WriteClient writes a single client to the storage instance. -func (s *MockStore) WriteClient() {} +func (s *MockStore) WriteClient(v Client) error { + return nil +} // WriteInFlight writes a single InFlight message to the storage instance. -func (s *MockStore) WriteInflight() {} +func (s *MockStore) WriteInflight(v Message) error { + return nil +} // WriteRetained writes a single retained message to the storage instance. -func (s *MockStore) WriteRetained() {} +func (s *MockStore) WriteRetained(v Message) error { + return nil +} + +// WriteServerInfo writes server info to the storage instance. +func (s *MockStore) WriteServerInfo(v ServerInfo) error { + return nil +} // ReadSubscriptions loads the subscriptions from the storage instance. func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) { @@ -136,10 +147,19 @@ func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) { } // ReadClients loads the clients from the storage instance. -func (s *MockStore) ReadClients() {} +func (s *MockStore) ReadClients() (v []Client, err error) { + return +} // ReadInflight loads the inflight messages from the storage instance. -func (s *MockStore) ReadInflight() {} +func (s *MockStore) ReadInflight() (v []Message, err error) { + return +} + +// ReadRetained loads the retained messages from the storage instance. +func (s *MockStore) ReadRetained() (v []Message, err error) { + return +} //ReadServerInfo loads the server info from the storage instance. func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) { diff --git a/server/persistence/persistence_test.go b/server/persistence/persistence_test.go index 6e277c1..7ac937a 100644 --- a/server/persistence/persistence_test.go +++ b/server/persistence/persistence_test.go @@ -25,3 +25,63 @@ func TestMockStoreClose(t *testing.T) { s.Close() require.Equal(t, true, s.Closed) } + +func TestMockStoreWriteSubscription(t *testing.T) { + s := new(MockStore) + err := s.WriteSubscription(Subscription{}) + require.NoError(t, err) +} + +func TestMockStoreWriteClient(t *testing.T) { + s := new(MockStore) + err := s.WriteClient(Client{}) + require.NoError(t, err) +} + +func TestMockStoreWriteInflight(t *testing.T) { + s := new(MockStore) + err := s.WriteInflight(Message{}) + require.NoError(t, err) +} + +func TestMockStoreWriteRetained(t *testing.T) { + s := new(MockStore) + err := s.WriteRetained(Message{}) + require.NoError(t, err) +} + +func TestMockStoreWriteServerInfo(t *testing.T) { + s := new(MockStore) + err := s.WriteServerInfo(ServerInfo{}) + require.NoError(t, err) +} + +func TestMockStorReadServerInfo(t *testing.T) { + s := new(MockStore) + _, err := s.ReadServerInfo() + require.NoError(t, err) +} + +func TestMockStoreReadSubscriptions(t *testing.T) { + s := new(MockStore) + _, err := s.ReadSubscriptions() + require.NoError(t, err) +} + +func TestMockStoreReadClients(t *testing.T) { + s := new(MockStore) + _, err := s.ReadClients() + require.NoError(t, err) +} + +func TestMockStoreReadInflight(t *testing.T) { + s := new(MockStore) + _, err := s.ReadInflight() + require.NoError(t, err) +} + +func TestMockStoreReadRetained(t *testing.T) { + s := new(MockStore) + _, err := s.ReadRetained() + require.NoError(t, err) +}