diff --git a/README.md b/README.md index cf7a7fe..4c192e0 100644 --- a/README.md +++ b/README.md @@ -378,23 +378,23 @@ err = db.Select(q.Or( ), )).Find(&users) -query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age") +query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name") // Find multiple records err = query.Find(&users) // or -err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").Find(&users) +err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").Find(&users) // Find first record err = query.First(&user) // or -err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").First(&user) +err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").First(&user) // Delete all matching records err = query.Delete(new(User)) // Fetching records one by one (useful when the bucket contains a lot of records) -query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age") +query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age", "Name") err = query.Each(new(User), func(record interface{}) error) { u := record.(*User) diff --git a/finder.go b/finder.go index 557c516..933b568 100644 --- a/finder.go +++ b/finder.go @@ -175,9 +175,6 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig if bucket == nil { return ErrNotFound } - - sorter := newSorter(n) - idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName) if err != nil { return err @@ -193,19 +190,19 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list)) + sorter := newSorter(n, sink, nil, nil, false) for i := range list { raw := bucket.Get(list[i]) if raw == nil { return ErrNotFound } - _, err = sorter.filter(sink, nil, bucket, list[i], raw) - if err != nil { + if _, err := sorter.filter(bucket, list[i], raw); err != nil { return err } } - return sink.flush() + return sorter.flush() } // AllByIndex gets all the records of a bucket that are indexed in the specified index @@ -381,8 +378,6 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig return nil } - sorter := newSorter(n) - idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName) if err != nil { return err @@ -394,20 +389,19 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig } sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list)) - + sorter := newSorter(n, sink, nil, nil, false) for i := range list { raw := bucket.Get(list[i]) if raw == nil { return ErrNotFound } - _, err = sorter.filter(sink, nil, bucket, list[i], raw) - if err != nil { + if _, err := sorter.filter(bucket, list[i], raw); err != nil { return err } } - return sink.flush() + return sorter.flush() } // Count counts all the records of a bucket diff --git a/query.go b/query.go index e09f574..b52db37 100644 --- a/query.go +++ b/query.go @@ -20,8 +20,8 @@ type Query interface { // Limit the results by the given number Limit(int) Query - // Order by the given field. - OrderBy(string) Query + // Order by the given fields, in descending precedence, left-to-right. + OrderBy(...string) Query // Reverse the order of the results Reverse() Query @@ -53,11 +53,10 @@ type Query interface { func newQuery(n *node, tree q.Matcher) *query { return &query{ - skip: 0, - limit: -1, - node: n, - tree: tree, - sorter: newSorter(n), + skip: 0, + limit: -1, + node: n, + tree: tree, } } @@ -68,7 +67,7 @@ type query struct { tree q.Matcher node *node bucket string - sorter *sorter + orderBy []string } func (q *query) Skip(nb int) Query { @@ -81,14 +80,13 @@ func (q *query) Limit(nb int) Query { return q } -func (q *query) OrderBy(field string) Query { - q.sorter.orderBy = field +func (q *query) OrderBy(field ...string) Query { + q.orderBy = field return q } func (q *query) Reverse() Query { q.reverse = true - q.sorter.reverse = true return q } @@ -206,9 +204,10 @@ func (q *query) query(tx *bolt.Tx, sink sink) error { bucket := q.node.GetBucket(tx, bucketName) if q.limit == 0 { - return q.sorter.flush(sink) + return sink.flush() } + sorter := newSorter(q.node, sink, q.tree, q.orderBy, q.reverse) if bucket != nil { c := internal.Cursor{C: bucket.Cursor(), Reverse: q.reverse} for k, v := c.First(); k != nil; k, v = c.Next() { @@ -216,7 +215,7 @@ func (q *query) query(tx *bolt.Tx, sink sink) error { continue } - stop, err := q.sorter.filter(sink, q.tree, bucket, k, v) + stop, err := sorter.filter(bucket, k, v) if err != nil { return err } @@ -227,5 +226,5 @@ func (q *query) query(tx *bolt.Tx, sink sink) error { } } - return q.sorter.flush(sink) + return sorter.flush() } diff --git a/query_test.go b/query_test.go index ba0a806..2557eb0 100644 --- a/query_test.go +++ b/query_test.go @@ -213,9 +213,9 @@ func TestSelectFindOrderBy(t *testing.T) { Rnd int } - strs := []string{"e", "b", "a", "c", "d"} - ints := []int{2, 3, 1, 4, 5} - for i := 0; i < 5; i++ { + strs := []string{"e", "b", "d", "a", "c", "d"} + ints := []int{2, 3, 5, 4, 2, 1} + for i := 0; i < 6; i++ { record := T{ Str: strs[i], Int: ints[i], @@ -231,39 +231,52 @@ func TestSelectFindOrderBy(t *testing.T) { var list []T err := db.Select().OrderBy("ID").Find(&list) assert.NoError(t, err) - assert.Len(t, list, 5) - for i := 0; i < 5; i++ { + assert.Len(t, list, 6) + for i, j := 0, 0; i < 6; i, j = i+1, j+1 { + if i == 2 { + j-- + } assert.Equal(t, i+1, list[i].ID) } err = db.Select().OrderBy("Str").Find(&list) assert.NoError(t, err) - assert.Len(t, list, 5) - for i := 0; i < 5; i++ { - assert.Equal(t, string([]byte{'a' + byte(i)}), list[i].Str) + assert.Len(t, list, 6) + for i, j := 0, 0; i < 6; i, j = i+1, j+1 { + if i == 4 { + j-- + } + assert.Equal(t, string([]byte{'a' + byte(j)}), list[i].Str) } err = db.Select().OrderBy("Int").Find(&list) assert.NoError(t, err) - assert.Len(t, list, 5) - for i := 0; i < 5; i++ { - assert.Equal(t, i+1, list[i].Int) + assert.Len(t, list, 6) + for i, j := 0, 0; i < 6; i, j = i+1, j+1 { + if i == 2 { + j-- + } + assert.Equal(t, j+1, list[i].Int) } err = db.Select().OrderBy("Rnd").Find(&list) assert.NoError(t, err) - assert.Len(t, list, 5) + assert.Len(t, list, 6) assert.Equal(t, 1, list[0].ID) assert.Equal(t, 2, list[1].ID) assert.Equal(t, 3, list[2].ID) assert.Equal(t, 5, list[3].ID) - assert.Equal(t, 4, list[4].ID) + assert.Equal(t, 6, list[4].ID) + assert.Equal(t, 4, list[5].ID) err = db.Select().OrderBy("Int").Reverse().Find(&list) assert.NoError(t, err) - assert.Len(t, list, 5) - for i := 0; i < 5; i++ { - assert.Equal(t, 5-i, list[i].Int) + assert.Len(t, list, 6) + for i, j := 0, 0; i < 6; i, j = i+1, j+1 { + if i == 4 { + j-- + } + assert.Equal(t, 5-j, list[i].Int) } err = db.Select().OrderBy("Int").Reverse().Limit(2).Find(&list) @@ -275,15 +288,34 @@ func TestSelectFindOrderBy(t *testing.T) { err = db.Select().OrderBy("Int").Reverse().Skip(2).Find(&list) assert.NoError(t, err) - assert.Len(t, list, 3) - for i := 0; i < 2; i++ { - assert.Equal(t, 3-i, list[i].Int) + assert.Len(t, list, 4) + for i, j := 0, 0; i < 3; i, j = i+1, j+1 { + if i == 2 { + j-- + } + assert.Equal(t, 3-j, list[i].Int) } - err = db.Select().OrderBy("Int").Reverse().Skip(4).Limit(2).Find(&list) + err = db.Select().OrderBy("Int").Reverse().Skip(5).Limit(2).Find(&list) assert.NoError(t, err) assert.Len(t, list, 1) assert.Equal(t, 1, list[0].Int) + + err = db.Select().OrderBy("Str", "Int").Find(&list) + assert.NoError(t, err) + assert.Len(t, list, 6) + assert.Equal(t, "a", list[0].Str) + assert.Equal(t, 4, list[0].Int) + assert.Equal(t, "b", list[1].Str) + assert.Equal(t, 3, list[1].Int) + assert.Equal(t, "c", list[2].Str) + assert.Equal(t, 2, list[2].Int) + assert.Equal(t, "d", list[3].Str) + assert.Equal(t, 1, list[3].Int) + assert.Equal(t, "d", list[4].Str) + assert.Equal(t, 5, list[4].Int) + assert.Equal(t, "e", list[5].Str) + assert.Equal(t, 2, list[5].Int) } func TestSelectFirst(t *testing.T) { diff --git a/sink.go b/sink.go index 2f08a84..b0dbdca 100644 --- a/sink.go +++ b/sink.go @@ -1,14 +1,12 @@ package storm import ( - "encoding/binary" "reflect" + "sort" "github.com/asdine/storm/index" "github.com/asdine/storm/q" "github.com/boltdb/bolt" - - rbt "github.com/emirpasic/gods/trees/redblacktree" ) type item struct { @@ -18,97 +16,190 @@ type item struct { v []byte } -func newSorter(node Node) *sorter { +func newSorter(n Node, snk sink, tree q.Matcher, orderBy []string, reverse bool) *sorter { return &sorter{ - node: node, - rbTree: rbt.NewWithStringComparator(), + node: n, + sink: snk, + tree: tree, + orderBy: orderBy, + reverse: reverse, + list: make([]*item, 0), + err: make(chan error), + done: make(chan struct{}), } } type sorter struct { node Node - rbTree *rbt.Tree - orderBy string + sink sink + tree q.Matcher + list []*item + orderBy []string reverse bool - counter int64 + err chan error + done chan struct{} } -func (s *sorter) filter(snk sink, tree q.Matcher, bucket *bolt.Bucket, k, v []byte) (bool, error) { - s.counter++ - - rsnk, ok := snk.(reflectSink) +func (s *sorter) filter(bucket *bolt.Bucket, k, v []byte) (bool, error) { + rsink, ok := s.sink.(reflectSink) if !ok { - return snk.add(&item{ + return s.sink.add(&item{ bucket: bucket, k: k, v: v, }) } - newElem := rsnk.elem() - err := s.node.Codec().Unmarshal(v, newElem.Interface()) - if err != nil { + newElem := rsink.elem() + if err := s.node.Codec().Unmarshal(v, newElem.Interface()); err != nil { return false, err } - ok = tree == nil - if !ok { - ok, err = tree.Match(newElem.Interface()) + itm := &item{ + bucket: bucket, + value: &newElem, + k: k, + v: v, + } + + if s.tree == nil { + if len(s.orderBy) == 0 { + return s.sink.add(itm) + } + } else { + ok, err := s.tree.Match(newElem.Interface()) if err != nil { return false, err } - } - - if ok { - it := item{ - bucket: bucket, - value: &newElem, - k: k, - v: v, - } - - if s.orderBy != "" { - elm := reflect.Indirect(newElem).FieldByName(s.orderBy) - if !elm.IsValid() { - return false, ErrNotFound - } - raw, err := toBytes(elm.Interface(), s.node.Codec()) - if err != nil { - return false, err - } - - key := make([]byte, len(raw)+8) - for i := 0; i < len(raw); i++ { - key[i] = raw[i] - } - binary.PutVarint(key[len(raw):], s.counter) - s.rbTree.Put(string(key), &it) + if !ok { return false, nil } - - return snk.add(&it) } + if len(s.orderBy) == 0 { + return s.sink.add(itm) + } + + s.list = append(s.list, itm) + return false, nil } -func (s *sorter) flush(snk sink) error { - if s.orderBy == "" { - return snk.flush() +func (s *sorter) compareValue(left reflect.Value, right reflect.Value) int { + if !left.IsValid() || !right.IsValid() { + if left.IsValid() { + return 1 + } + return -1 } - s.orderBy = "" - var err error - var stop bool - it := s.rbTree.Iterator() - if s.reverse { - it.End() - } else { - it.Begin() + switch left.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + l, r := left.Int(), right.Int() + if l < r { + return -1 + } + if l > r { + return 1 + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + l, r := left.Uint(), right.Uint() + if l < r { + return -1 + } + if l > r { + return 1 + } + case reflect.Float32, reflect.Float64: + l, r := left.Float(), right.Float() + if l < r { + return -1 + } + if l > r { + return 1 + } + case reflect.String: + l, r := left.String(), right.String() + if l < r { + return -1 + } + if l > r { + return 1 + } + default: + rawLeft, err := toBytes(left.Interface(), s.node.Codec()) + if err != nil { + return -1 + } + rawRight, err := toBytes(right.Interface(), s.node.Codec()) + if err != nil { + return 1 + } + + l, r := string(rawLeft), string(rawRight) + if l < r { + return -1 + } + if l > r { + return 1 + } } - for (s.reverse && it.Prev()) || (!s.reverse && it.Next()) { - item := it.Value().(*item) - stop, err = snk.add(item) + + return 0 +} + +func (s *sorter) less(leftElem reflect.Value, rightElem reflect.Value) bool { + for _, orderBy := range s.orderBy { + leftField := reflect.Indirect(leftElem).FieldByName(orderBy) + if !leftField.IsValid() { + s.err <- ErrNotFound + return false + } + rightField := reflect.Indirect(rightElem).FieldByName(orderBy) + if !rightField.IsValid() { + s.err <- ErrNotFound + return false + } + + direction := 1 + if s.reverse { + direction = -1 + } + + switch s.compareValue(leftField, rightField) * direction { + case -1: + return true + case 1: + return false + default: + continue + } + } + + return false +} + +func (s *sorter) flush() error { + if len(s.orderBy) == 0 { + return s.sink.flush() + } + + go func() { + sort.Sort(s) + close(s.err) + }() + err := <-s.err + close(s.done) + + if err != nil { + return err + } + + for _, itm := range s.list { + if itm == nil { + break + } + stop, err := s.sink.add(itm) if err != nil { return err } @@ -117,7 +208,38 @@ func (s *sorter) flush(snk sink) error { } } - return snk.flush() + return s.sink.flush() +} + +func (s *sorter) Len() int { + // skip if we encountered an earlier error + select { + case <-s.done: + return 0 + default: + return len(s.list) + } +} + +func (s *sorter) Swap(i, j int) { + // skip if we encountered an earlier error + select { + case <-s.done: + return + default: + s.list[i], s.list[j] = s.list[j], s.list[i] + } +} + +func (s *sorter) Less(i, j int) bool { + // skip if we encountered an earlier error + select { + case <-s.done: + return false + default: + } + + return s.less(*s.list[i].value, *s.list[j].value) } type sink interface { @@ -156,6 +278,7 @@ func newListSink(node Node, to interface{}) (*listSink, error) { elemType: elemType, name: elemType.Name(), limit: -1, + results: reflect.MakeSlice(reflect.Indirect(ref).Type(), 0, 0), }, nil } @@ -192,10 +315,6 @@ func (l *listSink) add(i *item) (bool, error) { return false, nil } - if !l.results.IsValid() { - l.results = reflect.MakeSlice(reflect.Indirect(l.ref).Type(), 0, 0) - } - if l.limit > 0 { l.limit-- }