diff --git a/mysql/mysql.go b/mysql/mysql.go index 1c3fd19a..2ecee278 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -14,6 +14,7 @@ import ( type Storage struct { db *sql.DB gcInterval time.Duration + done chan struct{} sqlSelect string sqlInsert string @@ -80,6 +81,7 @@ func New(config ...Config) *Storage { store := &Storage{ gcInterval: cfg.GCInterval, db: db, + done: make(chan struct{}), sqlSelect: fmt.Sprintf("SELECT v, e FROM %s WHERE k=?;", cfg.Table), sqlInsert: fmt.Sprintf("INSERT INTO %s (k, v, e) VALUES (?,?,?) ON DUPLICATE KEY UPDATE v = ?, e = ?", cfg.Table), sqlDelete: fmt.Sprintf("DELETE FROM %s WHERE k=?", cfg.Table), @@ -150,16 +152,20 @@ func (s *Storage) Reset() error { // Close the database func (s *Storage) Close() error { + s.done <- struct{}{} return s.db.Close() } // Garbage collector to delete expired keys func (s *Storage) gc() { - tick := time.NewTicker(s.gcInterval) + ticker := time.NewTicker(s.gcInterval) + defer ticker.Stop() for { - <-tick.C - if _, err := s.db.Exec(s.sqlGC); err != nil { - panic(err) + select { + case <-s.done: + return + case t := <-ticker.C: + _, _ = s.db.Exec(s.sqlGC, t.Unix()) } } } diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go index 8b155112..f7c40717 100644 --- a/mysql/mysql_test.go +++ b/mysql/mysql_test.go @@ -121,3 +121,8 @@ func Test_MYSQL_Clear(t *testing.T) { utils.AssertEqual(t, ErrNotExist, err) utils.AssertEqual(t, true, len(result) == 0) } + +func Test_Mysql_Close(t *testing.T) { + err := testStore.Close() + utils.AssertEqual(t, nil, err) +} diff --git a/postgres/postgres.go b/postgres/postgres.go index aacee212..a04ee9bc 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -15,6 +15,7 @@ import ( type Storage struct { db *sql.DB gcInterval time.Duration + done chan struct{} sqlSelect string sqlInsert string @@ -97,6 +98,7 @@ func New(config ...Config) *Storage { store := &Storage{ db: db, gcInterval: cfg.GCInterval, + done: make(chan struct{}), sqlSelect: fmt.Sprintf(`SELECT v, e FROM %s WHERE k=$1;`, cfg.Table), sqlInsert: fmt.Sprintf("INSERT INTO %s (k, v, e) VALUES ($1, $2, $3) ON CONFLICT (k) DO UPDATE SET v = $4, e = $5", cfg.Table), sqlDelete: fmt.Sprintf("DELETE FROM %s WHERE k=$1", cfg.Table), @@ -164,16 +166,20 @@ func (s *Storage) Reset() error { // Close the database func (s *Storage) Close() error { + s.done <- struct{}{} return s.db.Close() } // GC deletes all expired entries func (s *Storage) gc() { - tick := time.NewTicker(s.gcInterval) + ticker := time.NewTicker(s.gcInterval) + defer ticker.Stop() for { - <-tick.C - if _, err := s.db.Exec(s.sqlGC); err != nil { - panic(err) + select { + case <-s.done: + return + case t := <-ticker.C: + _, _ = s.db.Exec(s.sqlGC, t.Unix()) } } } diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index 5c93c2e9..881648bf 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -121,3 +121,8 @@ func Test_Postgres_Clear(t *testing.T) { utils.AssertEqual(t, ErrNotExist, err) utils.AssertEqual(t, true, len(result) == 0) } + +func Test_Postgres_Close(t *testing.T) { + err := testStore.Close() + utils.AssertEqual(t, nil, err) +} diff --git a/sqlite3/sqlite3.go b/sqlite3/sqlite3.go index a61d1bdc..51991ed0 100644 --- a/sqlite3/sqlite3.go +++ b/sqlite3/sqlite3.go @@ -15,6 +15,7 @@ import ( type Storage struct { db *sql.DB gcInterval time.Duration + done chan struct{} sqlSelect string sqlInsert string @@ -80,6 +81,7 @@ func New(config ...Config) *Storage { store := &Storage{ db: db, gcInterval: cfg.GCInterval, + done: make(chan struct{}), sqlSelect: fmt.Sprintf(`SELECT v, e FROM %s WHERE k=?;`, cfg.Table), sqlInsert: fmt.Sprintf("INSERT OR REPLACE INTO %s (k, v, e) VALUES (?,?,?)", cfg.Table), sqlDelete: fmt.Sprintf("DELETE FROM %s WHERE k=?", cfg.Table), @@ -143,16 +145,20 @@ func (s *Storage) Reset() error { // Close the database func (s *Storage) Close() error { + s.done <- struct{}{} return s.db.Close() } // GC deletes all expired entries func (s *Storage) gc() { - tick := time.NewTicker(s.gcInterval) + ticker := time.NewTicker(s.gcInterval) + defer ticker.Stop() for { - <-tick.C - if _, err := s.db.Exec(s.sqlGC); err != nil { - panic(err) + select { + case <-s.done: + return + case t := <-ticker.C: + _, _ = s.db.Exec(s.sqlGC, t.Unix()) } } } diff --git a/sqlite3/sqlite3_test.go b/sqlite3/sqlite3_test.go index e5e49301..390bc6c5 100644 --- a/sqlite3/sqlite3_test.go +++ b/sqlite3/sqlite3_test.go @@ -118,3 +118,8 @@ func Test_SQLite3_Clear(t *testing.T) { utils.AssertEqual(t, ErrNotExist, err) utils.AssertEqual(t, true, len(result) == 0) } + +func Test_SQLite3_Close(t *testing.T) { + err := testStore.Close() + utils.AssertEqual(t, nil, err) +}