From c4903dafef4ce92f53b6ac44c8869cd18c0ce40b Mon Sep 17 00:00:00 2001 From: Asdine El Hrychy Date: Sun, 9 Oct 2016 02:10:15 +0200 Subject: [PATCH] Multi tag support everywhere --- bench_test.go | 120 ++++++++++++++++++++++++------------------------ extract.go | 30 ++++++++---- extract_test.go | 25 ++++++++++ finder.go | 71 +++++++++++----------------- finder_test.go | 12 +++-- metadata.go | 1 + store_test.go | 6 +++ structs_test.go | 13 +++--- 8 files changed, 156 insertions(+), 122 deletions(-) diff --git a/bench_test.go b/bench_test.go index e87a602..e6131d0 100644 --- a/bench_test.go +++ b/bench_test.go @@ -6,6 +6,66 @@ import ( "time" ) +func BenchmarkFindWithIndex(b *testing.B) { + db, cleanup := createDB(b, AutoIncrement()) + defer cleanup() + + var users []User + for i := 0; i < 100; i++ { + var w User + + if i%2 == 0 { + w.Name = "John" + w.Group = "Staff" + } else { + w.Name = "Jack" + w.Group = "Admin" + } + err := db.Save(&w) + if err != nil { + b.Error(err) + } + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := db.Find("Name", "John", &users) + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkFindWithoutIndex(b *testing.B) { + db, cleanup := createDB(b, AutoIncrement()) + defer cleanup() + + var users []User + for i := 0; i < 100; i++ { + var w User + + if i%2 == 0 { + w.Name = "John" + w.Group = "Staff" + } else { + w.Name = "Jack" + w.Group = "Admin" + } + err := db.Save(&w) + if err != nil { + b.Error(err) + } + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := db.Find("Group", "Staff", &users) + if err != nil { + b.Error(err) + } + } +} + func BenchmarkOneWithIndex(b *testing.B) { db, cleanup := createDB(b, AutoIncrement()) defer cleanup() @@ -81,66 +141,6 @@ func BenchmarkOneWithoutIndex(b *testing.B) { } } -func BenchmarkFindWithIndex(b *testing.B) { - db, cleanup := createDB(b, AutoIncrement()) - defer cleanup() - - var users []User - for i := 0; i < 100; i++ { - var w User - - if i%2 == 0 { - w.Name = "John" - w.Group = "Staff" - } else { - w.Name = "Jack" - w.Group = "Admin" - } - err := db.Save(&w) - if err != nil { - b.Error(err) - } - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := db.Find("Name", "John", &users) - if err != nil { - b.Error(err) - } - } -} - -func BenchmarkFindWithoutIndex(b *testing.B) { - db, cleanup := createDB(b, AutoIncrement()) - defer cleanup() - - var users []User - for i := 0; i < 100; i++ { - var w User - - if i%2 == 0 { - w.Name = "John" - w.Group = "Staff" - } else { - w.Name = "Jack" - w.Group = "Admin" - } - err := db.Save(&w) - if err != nil { - b.Error(err) - } - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := db.Find("Group", "Staff", &users) - if err != nil { - b.Error(err) - } - } -} - func BenchmarkSave(b *testing.B) { db, cleanup := createDB(b, AutoIncrement()) defer cleanup() diff --git a/extract.go b/extract.go index 718a964..27a7ae2 100644 --- a/extract.go +++ b/extract.go @@ -1,6 +1,7 @@ package storm import ( + "fmt" "reflect" "strings" @@ -76,15 +77,6 @@ func extract(s *reflect.Value, mi ...*structConfig) (*structConfig, error) { } } - // ID field or tag detected - if m.ID != nil { - zero := reflect.Zero(m.ID.Value.Type()).Interface() - current := m.ID.Value.Interface() - if reflect.DeepEqual(current, zero) { - m.ID.IsZero = true - } - } - if child { return m, nil } @@ -157,8 +149,10 @@ func extractField(value *reflect.Value, field *reflect.StructField, m *structCon Name: field.Name, IsZero: isZero(value), IsInteger: isInteger(value), + IsID: true, Value: value, } + m.Fields[field.Name] = f } m.ID = f } @@ -166,6 +160,24 @@ func extractField(value *reflect.Value, field *reflect.StructField, m *structCon return nil } +func extractSingleField(ref *reflect.Value, fieldName string) (*structConfig, error) { + var cfg structConfig + cfg.Fields = make(map[string]*fieldConfig) + + f, ok := ref.Type().FieldByName(fieldName) + if !ok || f.PkgPath != "" { + return nil, fmt.Errorf("field %s not found", fieldName) + } + + v := ref.FieldByName(fieldName) + err := extractField(&v, &f, &cfg, false) + if err != nil { + return nil, err + } + + return &cfg, nil +} + func getIndex(bucket *bolt.Bucket, idxKind string, fieldName string) (index.Index, error) { var idx index.Index var err error diff --git a/extract_test.go b/extract_test.go index 248a9c5..ddbd39c 100644 --- a/extract_test.go +++ b/extract_test.go @@ -72,3 +72,28 @@ func TestExtractInlineWithIndex(t *testing.T) { assert.Len(t, allByType(infos, "index"), 3) assert.Len(t, allByType(infos, "unique"), 2) } + +func TestExtractMultipleTags(t *testing.T) { + type User struct { + ID uint64 `storm:"id,increment"` + Age uint16 `storm:"index,increment"` + unexportedField int32 `storm:"index,increment"` + Pos string `storm:"unique,increment"` + } + + s := User{} + r := reflect.ValueOf(&s) + infos, err := extract(&r) + assert.NoError(t, err) + assert.NotNil(t, infos) + assert.NotNil(t, infos.ID) + assert.Equal(t, "User", infos.Name) + assert.Len(t, allByType(infos, "index"), 1) + assert.Len(t, allByType(infos, "unique"), 1) + assert.True(t, infos.Fields["Age"].Increment) + assert.Equal(t, "index", infos.Fields["Age"].Index) + assert.False(t, infos.Fields["Age"].IsID) + assert.True(t, infos.Fields["Age"].IsInteger) + assert.True(t, infos.Fields["Age"].IsZero) + assert.NotNil(t, infos.Fields["Age"].Value) +} diff --git a/finder.go b/finder.go index 34faa23..b29abbd 100644 --- a/finder.go +++ b/finder.go @@ -1,9 +1,7 @@ package storm import ( - "fmt" "reflect" - "strings" "github.com/asdine/storm/index" "github.com/asdine/storm/q" @@ -51,15 +49,14 @@ func (n *node) One(fieldName string, value interface{}, to interface{}) error { return ErrNotFound } - typ := reflect.Indirect(sink.ref).Type() - - field, ok := typ.FieldByName(fieldName) - if !ok { - return fmt.Errorf("field %s not found", fieldName) + ref := reflect.Indirect(sink.ref) + cfg, err := extractSingleField(&ref, fieldName) + if err != nil { + return err } - tag := field.Tag.Get("storm") - if tag == "" && fieldName != "ID" { + field, ok := cfg.Fields[fieldName] + if !ok || (!field.IsID && field.Index == "") { query := newQuery(n, q.StrictEq(fieldName, value)) if n.tx != nil { @@ -82,18 +79,12 @@ func (n *node) One(fieldName string, value interface{}, to interface{}) error { return err } - var isID bool - if tag != "" { - tags := strings.Split(tag, ",") - isID = tags[0] == "id" - } - return n.readTx(func(tx *bolt.Tx) error { - return n.one(tx, bucketName, fieldName, tag, to, val, fieldName == "ID" || isID) + return n.one(tx, bucketName, fieldName, cfg, to, val, field.IsID) }) } -func (n *node) one(tx *bolt.Tx, bucketName, fieldName, tag string, to interface{}, val []byte, skipIndex bool) error { +func (n *node) one(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, to interface{}, val []byte, skipIndex bool) error { bucket := n.GetBucket(tx, bucketName) if bucket == nil { return ErrNotFound @@ -101,7 +92,7 @@ func (n *node) one(tx *bolt.Tx, bucketName, fieldName, tag string, to interface{ var id []byte if !skipIndex { - idx, err := getIndex(bucket, tag, fieldName) + idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName) if err != nil { if err == index.ErrNotFound { return ErrNotFound @@ -137,11 +128,10 @@ func (n *node) Find(fieldName string, value interface{}, to interface{}, options return ErrNoName } - typ := reflect.Indirect(sink.ref).Type().Elem() - - field, ok := typ.FieldByName(fieldName) - if !ok { - return fmt.Errorf("field %s not found", fieldName) + ref := reflect.Indirect(reflect.New(sink.elemType)) + cfg, err := extractSingleField(&ref, fieldName) + if err != nil { + return err } opts := index.NewOptions() @@ -149,8 +139,8 @@ func (n *node) Find(fieldName string, value interface{}, to interface{}, options fn(opts) } - tag := field.Tag.Get("storm") - if tag == "" { + field, ok := cfg.Fields[fieldName] + if !ok || (!field.IsID && field.Index == "") { sink.limit = opts.Limit sink.skip = opts.Skip query := newQuery(n, q.StrictEq(fieldName, value)) @@ -176,17 +166,17 @@ func (n *node) Find(fieldName string, value interface{}, to interface{}, options } return n.readTx(func(tx *bolt.Tx) error { - return n.find(tx, bucketName, fieldName, tag, sink, val, opts) + return n.find(tx, bucketName, fieldName, cfg, sink, val, opts) }) } -func (n *node) find(tx *bolt.Tx, bucketName, fieldName, tag string, sink *listSink, val []byte, opts *index.Options) error { +func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, sink *listSink, val []byte, opts *index.Options) error { bucket := n.GetBucket(tx, bucketName) if bucket == nil { return ErrNotFound } - idx, err := getIndex(bucket, tag, fieldName) + idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName) if err != nil { return err } @@ -341,11 +331,10 @@ func (n *node) Range(fieldName string, min, max, to interface{}, options ...func return ErrNoName } - typ := reflect.Indirect(sink.ref).Type().Elem() - - field, ok := typ.FieldByName(fieldName) - if !ok { - return fmt.Errorf("field %s not found", fieldName) + ref := reflect.Indirect(reflect.New(sink.elemType)) + cfg, err := extractSingleField(&ref, fieldName) + if err != nil { + return err } opts := index.NewOptions() @@ -353,8 +342,8 @@ func (n *node) Range(fieldName string, min, max, to interface{}, options ...func fn(opts) } - tag := field.Tag.Get("storm") - if tag == "" { + field, ok := cfg.Fields[fieldName] + if !ok || (!field.IsID && field.Index == "") { sink.limit = opts.Limit sink.skip = opts.Skip query := newQuery(n, q.And(q.Gte(fieldName, min), q.Lte(fieldName, max))) @@ -384,23 +373,19 @@ func (n *node) Range(fieldName string, min, max, to interface{}, options ...func return err } - if n.tx != nil { - return n.rnge(n.tx, bucketName, fieldName, tag, sink, mn, mx, opts) - } - - return n.s.Bolt.View(func(tx *bolt.Tx) error { - return n.rnge(tx, bucketName, fieldName, tag, sink, mn, mx, opts) + return n.readTx(func(tx *bolt.Tx) error { + return n.rnge(tx, bucketName, fieldName, cfg, sink, mn, mx, opts) }) } -func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName, tag string, sink *listSink, min, max []byte, opts *index.Options) error { +func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, sink *listSink, min, max []byte, opts *index.Options) error { bucket := n.GetBucket(tx, bucketName) if bucket == nil { reflect.Indirect(sink.ref).SetLen(0) return nil } - idx, err := getIndex(bucket, tag, fieldName) + idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName) if err != nil { return err } diff --git a/finder_test.go b/finder_test.go index 75868c9..e5d2a48 100644 --- a/finder_test.go +++ b/finder_test.go @@ -51,9 +51,9 @@ func TestFind(t *testing.T) { users := []User{} - err = db.Find("Age", "John", &users) + err = db.Find("unexportedField", "John", &users) assert.Error(t, err) - assert.EqualError(t, err, "field Age not found") + assert.EqualError(t, err, "field unexportedField not found") err = db.Find("DateOfBirth", "John", &users) assert.Error(t, err) @@ -107,6 +107,10 @@ func TestFind(t *testing.T) { assert.Len(t, users, 10) assert.Equal(t, 21, users[0].ID) assert.Equal(t, 30, users[9].ID) + + // err = db.Find("Age", 10, &users) + // assert.NoError(t, err) + } func TestFindIntIndex(t *testing.T) { @@ -153,7 +157,7 @@ func TestAllByIndex(t *testing.T) { err = db.AllByIndex("Unknown field", &users) assert.Error(t, err) - assert.True(t, ErrNotFound == err) + assert.Equal(t, ErrNotFound, err) err = db.AllByIndex("DateOfBirth", &users) assert.NoError(t, err) @@ -526,7 +530,7 @@ func TestRange(t *testing.T) { err = db.Range("Age", min, max, &users) assert.Error(t, err) - assert.EqualError(t, err, "field Age not found") + assert.EqualError(t, err, "not found") dateMin := time.Now().Add(-time.Duration(50) * time.Hour) dateMax := dateMin.Add(time.Duration(3) * time.Hour) diff --git a/metadata.go b/metadata.go index fae24cb..5da8cb2 100644 --- a/metadata.go +++ b/metadata.go @@ -64,5 +64,6 @@ func (m *meta) increment(field *fieldConfig) error { } field.Value.Set(reflect.ValueOf(counter).Convert(field.Value.Type())) + field.IsZero = false return nil } diff --git a/store_test.go b/store_test.go index c4e7b57..74537c0 100644 --- a/store_test.go +++ b/store_test.go @@ -321,6 +321,12 @@ func TestSaveIncrement(t *testing.T) { err = db.One("Identifier", i, &s2) require.NoError(t, err) require.Equal(t, s1, s2) + + var list []User + err = db.Find("Age", i, &list) + require.NoError(t, err) + require.Len(t, list, 1) + require.Equal(t, s1, list[0]) } } diff --git a/structs_test.go b/structs_test.go index 867cd9a..233dcb6 100644 --- a/structs_test.go +++ b/structs_test.go @@ -62,12 +62,13 @@ type ClassicInline struct { } type User struct { - ID int `storm:"id"` - Name string `storm:"index"` - age int - DateOfBirth time.Time `storm:"index"` - Group string - Slug string `storm:"unique"` + ID int `storm:"id"` + Name string `storm:"index"` + Age int `storm:"index,increment"` + DateOfBirth time.Time `storm:"index"` + Group string + unexportedField int + Slug string `storm:"unique"` } type ToEmbed struct {