Implement multi-field OrderBy

Results are sorted by the specified fields with descending precedence.

Fixes #152
This commit is contained in:
Peter Fern
2017-05-06 05:45:31 +10:00
parent f09e3cd6c6
commit dbf518c94e
5 changed files with 261 additions and 117 deletions

View File

@@ -378,23 +378,23 @@ err = db.Select(q.Or(
),
)).Find(&users)
query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age")
query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name")
// Find multiple records
err = query.Find(&users)
// or
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").Find(&users)
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").Find(&users)
// Find first record
err = query.First(&user)
// or
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").First(&user)
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").First(&user)
// Delete all matching records
err = query.Delete(new(User))
// Fetching records one by one (useful when the bucket contains a lot of records)
query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age")
query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age", "Name")
err = query.Each(new(User), func(record interface{}) error) {
u := record.(*User)

View File

@@ -175,9 +175,6 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
if bucket == nil {
return ErrNotFound
}
sorter := newSorter(n)
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
return err
@@ -193,19 +190,19 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
sorter := newSorter(n, sink, nil, nil, false)
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
_, err = sorter.filter(sink, nil, bucket, list[i], raw)
if err != nil {
if _, err := sorter.filter(bucket, list[i], raw); err != nil {
return err
}
}
return sink.flush()
return sorter.flush()
}
// AllByIndex gets all the records of a bucket that are indexed in the specified index
@@ -381,8 +378,6 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
return nil
}
sorter := newSorter(n)
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
return err
@@ -394,20 +389,19 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
}
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
sorter := newSorter(n, sink, nil, nil, false)
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
_, err = sorter.filter(sink, nil, bucket, list[i], raw)
if err != nil {
if _, err := sorter.filter(bucket, list[i], raw); err != nil {
return err
}
}
return sink.flush()
return sorter.flush()
}
// Count counts all the records of a bucket

View File

@@ -20,8 +20,8 @@ type Query interface {
// Limit the results by the given number
Limit(int) Query
// Order by the given field.
OrderBy(string) Query
// Order by the given fields, in descending precedence, left-to-right.
OrderBy(...string) Query
// Reverse the order of the results
Reverse() Query
@@ -53,11 +53,10 @@ type Query interface {
func newQuery(n *node, tree q.Matcher) *query {
return &query{
skip: 0,
limit: -1,
node: n,
tree: tree,
sorter: newSorter(n),
skip: 0,
limit: -1,
node: n,
tree: tree,
}
}
@@ -68,7 +67,7 @@ type query struct {
tree q.Matcher
node *node
bucket string
sorter *sorter
orderBy []string
}
func (q *query) Skip(nb int) Query {
@@ -81,14 +80,13 @@ func (q *query) Limit(nb int) Query {
return q
}
func (q *query) OrderBy(field string) Query {
q.sorter.orderBy = field
func (q *query) OrderBy(field ...string) Query {
q.orderBy = field
return q
}
func (q *query) Reverse() Query {
q.reverse = true
q.sorter.reverse = true
return q
}
@@ -206,9 +204,10 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
bucket := q.node.GetBucket(tx, bucketName)
if q.limit == 0 {
return q.sorter.flush(sink)
return sink.flush()
}
sorter := newSorter(q.node, sink, q.tree, q.orderBy, q.reverse)
if bucket != nil {
c := internal.Cursor{C: bucket.Cursor(), Reverse: q.reverse}
for k, v := c.First(); k != nil; k, v = c.Next() {
@@ -216,7 +215,7 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
continue
}
stop, err := q.sorter.filter(sink, q.tree, bucket, k, v)
stop, err := sorter.filter(bucket, k, v)
if err != nil {
return err
}
@@ -227,5 +226,5 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
}
}
return q.sorter.flush(sink)
return sorter.flush()
}

View File

@@ -213,9 +213,9 @@ func TestSelectFindOrderBy(t *testing.T) {
Rnd int
}
strs := []string{"e", "b", "a", "c", "d"}
ints := []int{2, 3, 1, 4, 5}
for i := 0; i < 5; i++ {
strs := []string{"e", "b", "d", "a", "c", "d"}
ints := []int{2, 3, 5, 4, 2, 1}
for i := 0; i < 6; i++ {
record := T{
Str: strs[i],
Int: ints[i],
@@ -231,39 +231,52 @@ func TestSelectFindOrderBy(t *testing.T) {
var list []T
err := db.Select().OrderBy("ID").Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 5)
for i := 0; i < 5; i++ {
assert.Len(t, list, 6)
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
if i == 2 {
j--
}
assert.Equal(t, i+1, list[i].ID)
}
err = db.Select().OrderBy("Str").Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 5)
for i := 0; i < 5; i++ {
assert.Equal(t, string([]byte{'a' + byte(i)}), list[i].Str)
assert.Len(t, list, 6)
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
if i == 4 {
j--
}
assert.Equal(t, string([]byte{'a' + byte(j)}), list[i].Str)
}
err = db.Select().OrderBy("Int").Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 5)
for i := 0; i < 5; i++ {
assert.Equal(t, i+1, list[i].Int)
assert.Len(t, list, 6)
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
if i == 2 {
j--
}
assert.Equal(t, j+1, list[i].Int)
}
err = db.Select().OrderBy("Rnd").Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 5)
assert.Len(t, list, 6)
assert.Equal(t, 1, list[0].ID)
assert.Equal(t, 2, list[1].ID)
assert.Equal(t, 3, list[2].ID)
assert.Equal(t, 5, list[3].ID)
assert.Equal(t, 4, list[4].ID)
assert.Equal(t, 6, list[4].ID)
assert.Equal(t, 4, list[5].ID)
err = db.Select().OrderBy("Int").Reverse().Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 5)
for i := 0; i < 5; i++ {
assert.Equal(t, 5-i, list[i].Int)
assert.Len(t, list, 6)
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
if i == 4 {
j--
}
assert.Equal(t, 5-j, list[i].Int)
}
err = db.Select().OrderBy("Int").Reverse().Limit(2).Find(&list)
@@ -275,15 +288,34 @@ func TestSelectFindOrderBy(t *testing.T) {
err = db.Select().OrderBy("Int").Reverse().Skip(2).Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 3)
for i := 0; i < 2; i++ {
assert.Equal(t, 3-i, list[i].Int)
assert.Len(t, list, 4)
for i, j := 0, 0; i < 3; i, j = i+1, j+1 {
if i == 2 {
j--
}
assert.Equal(t, 3-j, list[i].Int)
}
err = db.Select().OrderBy("Int").Reverse().Skip(4).Limit(2).Find(&list)
err = db.Select().OrderBy("Int").Reverse().Skip(5).Limit(2).Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 1)
assert.Equal(t, 1, list[0].Int)
err = db.Select().OrderBy("Str", "Int").Find(&list)
assert.NoError(t, err)
assert.Len(t, list, 6)
assert.Equal(t, "a", list[0].Str)
assert.Equal(t, 4, list[0].Int)
assert.Equal(t, "b", list[1].Str)
assert.Equal(t, 3, list[1].Int)
assert.Equal(t, "c", list[2].Str)
assert.Equal(t, 2, list[2].Int)
assert.Equal(t, "d", list[3].Str)
assert.Equal(t, 1, list[3].Int)
assert.Equal(t, "d", list[4].Str)
assert.Equal(t, 5, list[4].Int)
assert.Equal(t, "e", list[5].Str)
assert.Equal(t, 2, list[5].Int)
}
func TestSelectFirst(t *testing.T) {

253
sink.go
View File

@@ -1,14 +1,12 @@
package storm
import (
"encoding/binary"
"reflect"
"sort"
"github.com/asdine/storm/index"
"github.com/asdine/storm/q"
"github.com/boltdb/bolt"
rbt "github.com/emirpasic/gods/trees/redblacktree"
)
type item struct {
@@ -18,97 +16,190 @@ type item struct {
v []byte
}
func newSorter(node Node) *sorter {
func newSorter(n Node, snk sink, tree q.Matcher, orderBy []string, reverse bool) *sorter {
return &sorter{
node: node,
rbTree: rbt.NewWithStringComparator(),
node: n,
sink: snk,
tree: tree,
orderBy: orderBy,
reverse: reverse,
list: make([]*item, 0),
err: make(chan error),
done: make(chan struct{}),
}
}
type sorter struct {
node Node
rbTree *rbt.Tree
orderBy string
sink sink
tree q.Matcher
list []*item
orderBy []string
reverse bool
counter int64
err chan error
done chan struct{}
}
func (s *sorter) filter(snk sink, tree q.Matcher, bucket *bolt.Bucket, k, v []byte) (bool, error) {
s.counter++
rsnk, ok := snk.(reflectSink)
func (s *sorter) filter(bucket *bolt.Bucket, k, v []byte) (bool, error) {
rsink, ok := s.sink.(reflectSink)
if !ok {
return snk.add(&item{
return s.sink.add(&item{
bucket: bucket,
k: k,
v: v,
})
}
newElem := rsnk.elem()
err := s.node.Codec().Unmarshal(v, newElem.Interface())
if err != nil {
newElem := rsink.elem()
if err := s.node.Codec().Unmarshal(v, newElem.Interface()); err != nil {
return false, err
}
ok = tree == nil
if !ok {
ok, err = tree.Match(newElem.Interface())
itm := &item{
bucket: bucket,
value: &newElem,
k: k,
v: v,
}
if s.tree == nil {
if len(s.orderBy) == 0 {
return s.sink.add(itm)
}
} else {
ok, err := s.tree.Match(newElem.Interface())
if err != nil {
return false, err
}
}
if ok {
it := item{
bucket: bucket,
value: &newElem,
k: k,
v: v,
}
if s.orderBy != "" {
elm := reflect.Indirect(newElem).FieldByName(s.orderBy)
if !elm.IsValid() {
return false, ErrNotFound
}
raw, err := toBytes(elm.Interface(), s.node.Codec())
if err != nil {
return false, err
}
key := make([]byte, len(raw)+8)
for i := 0; i < len(raw); i++ {
key[i] = raw[i]
}
binary.PutVarint(key[len(raw):], s.counter)
s.rbTree.Put(string(key), &it)
if !ok {
return false, nil
}
return snk.add(&it)
}
if len(s.orderBy) == 0 {
return s.sink.add(itm)
}
s.list = append(s.list, itm)
return false, nil
}
func (s *sorter) flush(snk sink) error {
if s.orderBy == "" {
return snk.flush()
func (s *sorter) compareValue(left reflect.Value, right reflect.Value) int {
if !left.IsValid() || !right.IsValid() {
if left.IsValid() {
return 1
}
return -1
}
s.orderBy = ""
var err error
var stop bool
it := s.rbTree.Iterator()
if s.reverse {
it.End()
} else {
it.Begin()
switch left.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
l, r := left.Int(), right.Int()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
l, r := left.Uint(), right.Uint()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.Float32, reflect.Float64:
l, r := left.Float(), right.Float()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.String:
l, r := left.String(), right.String()
if l < r {
return -1
}
if l > r {
return 1
}
default:
rawLeft, err := toBytes(left.Interface(), s.node.Codec())
if err != nil {
return -1
}
rawRight, err := toBytes(right.Interface(), s.node.Codec())
if err != nil {
return 1
}
l, r := string(rawLeft), string(rawRight)
if l < r {
return -1
}
if l > r {
return 1
}
}
for (s.reverse && it.Prev()) || (!s.reverse && it.Next()) {
item := it.Value().(*item)
stop, err = snk.add(item)
return 0
}
func (s *sorter) less(leftElem reflect.Value, rightElem reflect.Value) bool {
for _, orderBy := range s.orderBy {
leftField := reflect.Indirect(leftElem).FieldByName(orderBy)
if !leftField.IsValid() {
s.err <- ErrNotFound
return false
}
rightField := reflect.Indirect(rightElem).FieldByName(orderBy)
if !rightField.IsValid() {
s.err <- ErrNotFound
return false
}
direction := 1
if s.reverse {
direction = -1
}
switch s.compareValue(leftField, rightField) * direction {
case -1:
return true
case 1:
return false
default:
continue
}
}
return false
}
func (s *sorter) flush() error {
if len(s.orderBy) == 0 {
return s.sink.flush()
}
go func() {
sort.Sort(s)
close(s.err)
}()
err := <-s.err
close(s.done)
if err != nil {
return err
}
for _, itm := range s.list {
if itm == nil {
break
}
stop, err := s.sink.add(itm)
if err != nil {
return err
}
@@ -117,7 +208,38 @@ func (s *sorter) flush(snk sink) error {
}
}
return snk.flush()
return s.sink.flush()
}
func (s *sorter) Len() int {
// skip if we encountered an earlier error
select {
case <-s.done:
return 0
default:
return len(s.list)
}
}
func (s *sorter) Swap(i, j int) {
// skip if we encountered an earlier error
select {
case <-s.done:
return
default:
s.list[i], s.list[j] = s.list[j], s.list[i]
}
}
func (s *sorter) Less(i, j int) bool {
// skip if we encountered an earlier error
select {
case <-s.done:
return false
default:
}
return s.less(*s.list[i].value, *s.list[j].value)
}
type sink interface {
@@ -156,6 +278,7 @@ func newListSink(node Node, to interface{}) (*listSink, error) {
elemType: elemType,
name: elemType.Name(),
limit: -1,
results: reflect.MakeSlice(reflect.Indirect(ref).Type(), 0, 0),
}, nil
}
@@ -192,10 +315,6 @@ func (l *listSink) add(i *item) (bool, error) {
return false, nil
}
if !l.results.IsValid() {
l.results = reflect.MakeSlice(reflect.Indirect(l.ref).Type(), 0, 0)
}
if l.limit > 0 {
l.limit--
}