Save and find use Index

This commit is contained in:
Asdine El Hrychy
2016-03-09 19:24:30 +01:00
parent 9535e45031
commit a8cacd81cb
5 changed files with 82 additions and 101 deletions

View File

@@ -9,4 +9,9 @@ var (
ErrBadType = errors.New("provided data must be a struct or a pointer to struct") ErrBadType = errors.New("provided data must be a struct or a pointer to struct")
ErrAlreadyExists = errors.New("already exists") ErrAlreadyExists = errors.New("already exists")
ErrNilParam = errors.New("param must not be nil") ErrNilParam = errors.New("param must not be nil")
ErrBadIndexType = errors.New("bad index type")
ErrBadTarget = errors.New("provided target must be a pointer to a slice")
ErrNoName = errors.New("provided target must have a name")
ErrIndexNotFound = errors.New("index not found")
ErrNotFound = errors.New("not found")
) )

58
find.go
View File

@@ -2,7 +2,6 @@ package storm
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -11,11 +10,11 @@ import (
) )
// Find returns one or more records by the specified index // Find returns one or more records by the specified index
func (s *DB) Find(index string, value interface{}, to interface{}) error { func (s *DB) Find(fieldName string, value interface{}, to interface{}) error {
ref := reflect.ValueOf(to) ref := reflect.ValueOf(to)
if ref.Kind() != reflect.Ptr || reflect.Indirect(ref).Kind() != reflect.Slice { if ref.Kind() != reflect.Ptr || reflect.Indirect(ref).Kind() != reflect.Slice {
return errors.New("provided target must be a pointer to a slice") return ErrBadTarget
} }
typ := reflect.Indirect(ref).Type().Elem() typ := reflect.Indirect(ref).Type().Elem()
@@ -24,17 +23,17 @@ func (s *DB) Find(index string, value interface{}, to interface{}) error {
d := structs.New(newElem.Interface()) d := structs.New(newElem.Interface())
bucketName := d.Name() bucketName := d.Name()
if bucketName == "" { if bucketName == "" {
return errors.New("provided target must have a name") return ErrNoName
} }
field, ok := d.FieldOk(index) field, ok := d.FieldOk(fieldName)
if !ok { if !ok {
return fmt.Errorf("field %s not found", index) return fmt.Errorf("field %s not found", fieldName)
} }
tag := field.Tag("storm") tag := field.Tag("storm")
if tag == "" { if tag == "" {
return fmt.Errorf("index %s not found", index) return fmt.Errorf("index %s not found", fieldName)
} }
return s.Bolt.View(func(tx *bolt.Tx) error { return s.Bolt.View(func(tx *bolt.Tx) error {
@@ -43,9 +42,22 @@ func (s *DB) Find(index string, value interface{}, to interface{}) error {
return fmt.Errorf("bucket %s not found", bucketName) return fmt.Errorf("bucket %s not found", bucketName)
} }
idx := bucket.Bucket([]byte(index)) var idx Index
if idx == nil { var err error
return fmt.Errorf("index %s not found", index) switch tag {
case "unique":
idx, err = NewUniqueIndex(bucket, []byte(fieldName))
case "index":
idx, err = NewListIndex(bucket, []byte(fieldName))
default:
err = ErrBadIndexType
}
if err != nil {
if err == ErrIndexNotFound {
return ErrNotFound
}
return err
} }
val, err := toBytes(value) val, err := toBytes(value)
@@ -53,34 +65,20 @@ func (s *DB) Find(index string, value interface{}, to interface{}) error {
return err return err
} }
raw := idx.Get(val) list, err := idx.All(val)
if raw == nil {
return errors.New("not found")
}
var list [][]byte
if tag == "unique" {
list = append(list, raw)
} else if tag == "index" {
err = json.Unmarshal(raw, &list)
if err != nil { if err != nil {
if err == ErrIndexNotFound {
return ErrNotFound
}
return err return err
} }
if list == nil || len(list) == 0 {
return errors.New("not found")
}
} else {
return fmt.Errorf("unsupported struct tag %s", tag)
}
results := reflect.MakeSlice(reflect.Indirect(ref).Type(), len(list), len(list)) results := reflect.MakeSlice(reflect.Indirect(ref).Type(), len(list), len(list))
for i := range list { for i := range list {
raw = bucket.Get(list[i]) raw := bucket.Get(list[i])
if raw == nil { if raw == nil {
return errors.New("not found") return ErrNotFound
} }
err = json.Unmarshal(raw, results.Index(i).Addr().Interface()) err = json.Unmarshal(raw, results.Index(i).Addr().Interface())

View File

@@ -22,10 +22,17 @@ type Index interface {
// NewUniqueIndex loads a UniqueIndex // NewUniqueIndex loads a UniqueIndex
func NewUniqueIndex(parent *bolt.Bucket, indexName []byte) (*UniqueIndex, error) { func NewUniqueIndex(parent *bolt.Bucket, indexName []byte) (*UniqueIndex, error) {
b, err := parent.CreateBucketIfNotExists(indexName) var err error
b := parent.Bucket(indexName)
if b == nil {
if !parent.Writable() {
return nil, ErrIndexNotFound
}
b, err = parent.CreateBucket(indexName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
return &UniqueIndex{ return &UniqueIndex{
IndexBucket: b, IndexBucket: b,
@@ -115,10 +122,17 @@ func (idx *UniqueIndex) first() []byte {
// NewListIndex loads a ListIndex // NewListIndex loads a ListIndex
func NewListIndex(parent *bolt.Bucket, indexName []byte) (*ListIndex, error) { func NewListIndex(parent *bolt.Bucket, indexName []byte) (*ListIndex, error) {
b, err := parent.CreateBucketIfNotExists(indexName) var err error
b := parent.Bucket(indexName)
if b == nil {
if !parent.Writable() {
return nil, ErrIndexNotFound
}
b, err = parent.CreateBucket(indexName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
ids, err := NewUniqueIndex(b, []byte("storm__ids")) ids, err := NewUniqueIndex(b, []byte("storm__ids"))
if err != nil { if err != nil {

57
save.go
View File

@@ -13,69 +13,58 @@ func (s *DB) Save(data interface{}) error {
return ErrBadType return ErrBadType
} }
t, err := extractTags(data) info, err := extract(data)
if err != nil { if err != nil {
return err return err
} }
if t.ZeroID { if info.ID == nil {
return ErrZeroID
}
if t.ID == nil {
return ErrNoID return ErrNoID
} }
id, err := toBytes(t.ID) if info.ID.IsZero() {
return ErrZeroID
}
id, err := toBytes(info.ID.Value())
if err != nil { if err != nil {
return err return err
} }
err = s.Bolt.Update(func(tx *bolt.Tx) error { err = s.Bolt.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte(t.Name)) bucket, err := tx.CreateBucketIfNotExists([]byte(info.Name))
if err != nil { if err != nil {
return err return err
} }
if len(t.Uniques) > 0 { var idx Index
err = s.deleteOldIndexes(bucket, id, t.Uniques, true) for fieldName, idxInfo := range info.Indexes {
if err != nil { switch idxInfo.Type {
return err case "unique":
} idx, err = NewUniqueIndex(bucket, []byte(fieldName))
case "index":
idx, err = NewListIndex(bucket, []byte(fieldName))
default:
err = ErrBadIndexType
} }
if len(t.Indexes) > 0 {
err = s.deleteOldIndexes(bucket, id, t.Indexes, false)
if err != nil {
return err
}
}
if t.Uniques != nil {
for _, field := range t.Uniques {
key, err := toBytes(field.Value())
if err != nil { if err != nil {
return err return err
} }
err = s.addToUniqueIndex([]byte(field.Name()), id, key, bucket) err = idx.RemoveID(id)
if err != nil {
return err
}
}
}
if t.Indexes != nil {
for _, field := range t.Indexes {
key, err := toBytes(field.Value())
if err != nil { if err != nil {
return err return err
} }
err = s.addToListIndex([]byte(field.Name()), id, key, bucket) value, err := toBytes(idxInfo.Field.Value())
if err != nil { if err != nil {
return err return err
} }
err = idx.Add(value, id)
if err != nil {
return err
} }
} }

View File

@@ -17,6 +17,7 @@ func TestSave(t *testing.T) {
db, _ := Open(filepath.Join(dir, "storm.db")) db, _ := Open(filepath.Join(dir, "storm.db"))
err := db.Save(&SimpleUser{ID: 10, Name: "John"}) err := db.Save(&SimpleUser{ID: 10, Name: "John"})
assert.NoError(t, err)
err = db.Save(&SimpleUser{Name: "John"}) err = db.Save(&SimpleUser{Name: "John"})
assert.Error(t, err) assert.Error(t, err)
@@ -117,32 +118,6 @@ func TestSaveIndex(t *testing.T) {
err = db.Save(&u2) err = db.Save(&u2)
assert.NoError(t, err) assert.NoError(t, err)
db.Bolt.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte("IndexedNameUser"))
listBucket := bucket.Bucket([]byte("Name"))
assert.NotNil(t, listBucket)
raw := listBucket.Get([]byte("John"))
assert.NotNil(t, raw)
var list [][]byte
err = json.Unmarshal(raw, &list)
assert.NoError(t, err)
assert.Len(t, list, 2)
id1, err := toBytes(u1.ID)
assert.NoError(t, err)
id2, err := toBytes(u2.ID)
assert.NoError(t, err)
assert.Equal(t, id1, list[0])
assert.Equal(t, id2, list[1])
return nil
})
name1 := "Jake" name1 := "Jake"
name2 := "Jane" name2 := "Jane"
name3 := "James" name3 := "James"
@@ -170,7 +145,7 @@ func TestSaveIndex(t *testing.T) {
err = db.Find("Name", name3, &users) err = db.Find("Name", name3, &users)
assert.Error(t, err) assert.Error(t, err)
assert.EqualError(t, err, "not found") assert.Equal(t, ErrNotFound, err)
err = db.Save(nil) err = db.Save(nil)
assert.Error(t, err) assert.Error(t, err)