diff --git a/all.go b/all.go index 99579ca..699a072 100644 --- a/all.go +++ b/all.go @@ -53,12 +53,12 @@ func (n *node) allByIndex(tx *bolt.Tx, fieldName string, cfg *structConfig, ref return ErrNotFound } - idxInfo, ok := cfg.Fields[fieldName] + fieldCfg, ok := cfg.Fields[fieldName] if !ok { return ErrNotFound } - idx, err := getIndex(bucket, idxInfo.Type, fieldName) + idx, err := getIndex(bucket, fieldCfg.Index, fieldName) if err != nil { return err } diff --git a/delete_struct.go b/delete_struct.go index cfec84b..556d4ca 100644 --- a/delete_struct.go +++ b/delete_struct.go @@ -36,8 +36,12 @@ func (n *node) deleteStruct(tx *bolt.Tx, cfg *structConfig, id []byte) error { return ErrNotFound } - for fieldName, idxInfo := range cfg.Fields { - idx, err := getIndex(bucket, idxInfo.Type, fieldName) + for fieldName, fieldCfg := range cfg.Fields { + if fieldCfg.Index == "" { + continue + } + + idx, err := getIndex(bucket, fieldCfg.Index, fieldName) if err != nil { return err } diff --git a/extract.go b/extract.go index b4940d7..718a964 100644 --- a/extract.go +++ b/extract.go @@ -14,12 +14,13 @@ const ( tagIdx = "index" tagUniqueIdx = "unique" tagInline = "inline" + tagIncrement = "increment" indexPrefix = "__storm_index_" ) type fieldConfig struct { Name string - Type string + Index string IsZero bool IsID bool Increment bool @@ -34,18 +35,6 @@ type structConfig struct { ID *fieldConfig } -// helper -func (m *structConfig) AllByType(indexType string) []*fieldConfig { - var idx []*fieldConfig - for k := range m.Fields { - if m.Fields[k].Type == indexType { - idx = append(idx, m.Fields[k]) - } - } - - return idx -} - func extract(s *reflect.Value, mi ...*structConfig) (*structConfig, error) { if s.Kind() == reflect.Ptr { e := s.Elem() @@ -130,7 +119,9 @@ func extractField(value *reflect.Value, field *reflect.StructField, m *structCon case "id": f.IsID = true case tagUniqueIdx, tagIdx: - f.Type = tag + f.Index = tag + case tagIncrement: + f.Increment = true case tagInline: if value.Kind() == reflect.Ptr { e := value.Elem() @@ -150,10 +141,8 @@ func extractField(value *reflect.Value, field *reflect.StructField, m *structCon } } - if f.Type != "" { - if _, ok := m.Fields[f.Name]; !ok || !isChild { - m.Fields[f.Name] = f - } + if _, ok := m.Fields[f.Name]; !ok || !isChild { + m.Fields[f.Name] = f } } diff --git a/extract_test.go b/extract_test.go index 0b12909..248a9c5 100644 --- a/extract_test.go +++ b/extract_test.go @@ -7,6 +7,17 @@ import ( "github.com/stretchr/testify/assert" ) +func allByType(m *structConfig, indexType string) []*fieldConfig { + var idx []*fieldConfig + for k := range m.Fields { + if m.Fields[k].Index == indexType { + idx = append(idx, m.Fields[k]) + } + } + + return idx +} + func TestExtractNoTags(t *testing.T) { s := ClassicNoTags{} r := reflect.ValueOf(&s) @@ -33,8 +44,8 @@ func TestExtractUniqueTags(t *testing.T) { assert.NotNil(t, infos.ID) assert.False(t, infos.ID.IsZero) assert.Equal(t, "ClassicUnique", infos.Name) - assert.Len(t, infos.AllByType("index"), 0) - assert.Len(t, infos.AllByType("unique"), 4) + assert.Len(t, allByType(infos, "index"), 0) + assert.Len(t, allByType(infos, "unique"), 4) } func TestExtractIndexTags(t *testing.T) { @@ -46,8 +57,8 @@ func TestExtractIndexTags(t *testing.T) { assert.NotNil(t, infos.ID) assert.False(t, infos.ID.IsZero) assert.Equal(t, "ClassicIndex", infos.Name) - assert.Len(t, infos.AllByType("index"), 5) - assert.Len(t, infos.AllByType("unique"), 0) + assert.Len(t, allByType(infos, "index"), 5) + assert.Len(t, allByType(infos, "unique"), 0) } func TestExtractInlineWithIndex(t *testing.T) { @@ -58,6 +69,6 @@ func TestExtractInlineWithIndex(t *testing.T) { assert.NotNil(t, infos) assert.NotNil(t, infos.ID) assert.Equal(t, "ClassicInline", infos.Name) - assert.Len(t, infos.AllByType("index"), 3) - assert.Len(t, infos.AllByType("unique"), 2) + assert.Len(t, allByType(infos, "index"), 3) + assert.Len(t, allByType(infos, "unique"), 2) } diff --git a/init.go b/init.go index 343162a..6efe006 100644 --- a/init.go +++ b/init.go @@ -32,8 +32,11 @@ func (n *node) init(tx *bolt.Tx, cfg *structConfig) error { return err } - for fieldName, idxInfo := range cfg.Fields { - switch idxInfo.Type { + for fieldName, fieldCfg := range cfg.Fields { + if fieldCfg.Index == "" { + continue + } + switch fieldCfg.Index { case tagUniqueIdx: _, err = index.NewUniqueIndex(bucket, []byte(indexPrefix+fieldName)) case tagIdx: diff --git a/one.go b/one.go index 7c2ec7e..3a382d5 100644 --- a/one.go +++ b/one.go @@ -3,6 +3,7 @@ package storm import ( "fmt" "reflect" + "strings" "github.com/asdine/storm/index" "github.com/asdine/storm/q" @@ -56,8 +57,14 @@ func (n *node) One(fieldName string, value interface{}, to interface{}) error { return err } + var isID bool + if tag != "" { + tags := strings.Split(tag, ",") + isID = tags[0] == "" + } + return n.readTx(func(tx *bolt.Tx) error { - return n.one(tx, bucketName, fieldName, tag, to, val, fieldName == "ID" || tag == "id") + return n.one(tx, bucketName, fieldName, tag, to, val, fieldName == "ID" || isID) }) } diff --git a/save.go b/save.go index d29dbbe..9759901 100644 --- a/save.go +++ b/save.go @@ -81,13 +81,17 @@ func (n *node) save(tx *bolt.Tx, cfg *structConfig, id []byte, raw []byte, data } } - for fieldName, idxInfo := range cfg.Fields { - idx, err := getIndex(bucket, idxInfo.Type, fieldName) + for fieldName, fieldCfg := range cfg.Fields { + if fieldCfg.Index == "" { + continue + } + + idx, err := getIndex(bucket, fieldCfg.Index, fieldName) if err != nil { return err } - if idxInfo.IsZero { + if fieldCfg.IsZero { err = idx.RemoveID(id) if err != nil { return err @@ -95,7 +99,7 @@ func (n *node) save(tx *bolt.Tx, cfg *structConfig, id []byte, raw []byte, data continue } - value, err := toBytes(idxInfo.Value.Interface(), n.s.codec) + value, err := toBytes(fieldCfg.Value.Interface(), n.s.codec) if err != nil { return err } diff --git a/sink.go b/sink.go index ef20629..f6cd4fd 100644 --- a/sink.go +++ b/sink.go @@ -235,8 +235,11 @@ func (d *deleteSink) add(bucket *bolt.Bucket, k []byte, v []byte, elem reflect.V return false, err } - for fieldName, idxInfo := range info.Fields { - idx, err := getIndex(bucket, idxInfo.Type, fieldName) + for fieldName, fieldCfg := range info.Fields { + if fieldCfg.Index == "" { + continue + } + idx, err := getIndex(bucket, fieldCfg.Index, fieldName) if err != nil { return false, err }