diff --git a/db.go b/db.go index 62058e8f..03846a5f 100644 --- a/db.go +++ b/db.go @@ -25,6 +25,36 @@ func (db DB) Begin(writable bool) (*Tx, error) { }, nil } +func (db DB) View(fn func(tx *Tx) error) error { + tx, err := db.Begin(false) + if err != nil { + return err + } + defer tx.Rollback() + + err = fn(tx) + if err != nil { + return err + } + + return tx.Rollback() +} + +func (db DB) Update(fn func(tx *Tx) error) error { + tx, err := db.Begin(true) + if err != nil { + return err + } + defer tx.Rollback() + + err = fn(tx) + if err != nil { + return err + } + + return tx.Commit() +} + type Tx struct { engine.Transaction } diff --git a/db_test.go b/db_test.go index 188687da..45aae997 100644 --- a/db_test.go +++ b/db_test.go @@ -11,56 +11,58 @@ import ( ) func TestTable(t *testing.T) { - db := genji.New(memory.NewEngine()) - t.Run("Table/Insert/NoIndex", func(t *testing.T) { - tx, err := db.Begin(true) - require.NoError(t, err) - defer tx.Rollback() + db := genji.New(memory.NewEngine()) - tb, err := tx.CreateTable("test") - require.NoError(t, err) + err := db.Update(func(tx *genji.Tx) error { + tb, err := tx.CreateTable("test") + require.NoError(t, err) - rowid, err := tb.Insert(record.FieldBuffer([]field.Field{ - field.NewString("name", "John"), - field.NewInt64("age", 10), - })) - require.NoError(t, err) - require.NotNil(t, rowid) + rowid, err := tb.Insert(record.FieldBuffer([]field.Field{ + field.NewString("name", "John"), + field.NewInt64("age", 10), + })) + require.NoError(t, err) + require.NotNil(t, rowid) - m, err := tx.Indexes("test") + m, err := tx.Indexes("test") + require.NoError(t, err) + require.Empty(t, m) + + return nil + }) require.NoError(t, err) - require.Empty(t, m) }) t.Run("Table/Insert/WithIndex", func(t *testing.T) { - tx, err := db.Begin(true) - require.NoError(t, err) - defer tx.Rollback() + db := genji.New(memory.NewEngine()) + defer db.Close() - tb, err := tx.CreateTable("test") - require.NoError(t, err) + err := db.Update(func(tx *genji.Tx) error { + tb, err := tx.CreateTable("test") + require.NoError(t, err) - _, err = tx.CreateIndex("test", "name") - require.NoError(t, err) + _, err = tx.CreateIndex("test", "name") + require.NoError(t, err) - rowid, err := tb.Insert(record.FieldBuffer([]field.Field{ - field.NewString("name", "John"), - field.NewInt64("age", 10), - })) - require.NoError(t, err) - require.NotNil(t, rowid) + rowid, err := tb.Insert(record.FieldBuffer([]field.Field{ + field.NewString("name", "John"), + field.NewInt64("age", 10), + })) + require.NoError(t, err) + require.NotNil(t, rowid) - m, err := tx.Indexes("test") - require.NoError(t, err) - require.NotEmpty(t, m) + m, err := tx.Indexes("test") + require.NoError(t, err) + require.NotEmpty(t, m) - c := m["name"].Cursor() - v, rid := c.Seek([]byte("John")) - require.Equal(t, []byte("John"), v) - require.Equal(t, rowid, rid) + c := m["name"].Cursor() + v, rid := c.Seek([]byte("John")) + require.Equal(t, []byte("John"), v) + require.Equal(t, rowid, rid) + + return nil + }) + require.NoError(t, err) }) - - err := db.Close() - require.NoError(t, err) }