mirror of
https://github.com/asdine/storm.git
synced 2025-09-26 19:01:14 +08:00
Implement multi-field OrderBy
Results are sorted by the specified fields with descending precedence. Fixes #152
This commit is contained in:
@@ -378,23 +378,23 @@ err = db.Select(q.Or(
|
||||
),
|
||||
)).Find(&users)
|
||||
|
||||
query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age")
|
||||
query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name")
|
||||
|
||||
// Find multiple records
|
||||
err = query.Find(&users)
|
||||
// or
|
||||
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").Find(&users)
|
||||
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").Find(&users)
|
||||
|
||||
// Find first record
|
||||
err = query.First(&user)
|
||||
// or
|
||||
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age").First(&user)
|
||||
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").First(&user)
|
||||
|
||||
// Delete all matching records
|
||||
err = query.Delete(new(User))
|
||||
|
||||
// Fetching records one by one (useful when the bucket contains a lot of records)
|
||||
query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age")
|
||||
query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age", "Name")
|
||||
|
||||
err = query.Each(new(User), func(record interface{}) error) {
|
||||
u := record.(*User)
|
||||
|
18
finder.go
18
finder.go
@@ -175,9 +175,6 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
|
||||
if bucket == nil {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
sorter := newSorter(n)
|
||||
|
||||
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -193,19 +190,19 @@ func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
|
||||
|
||||
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
|
||||
|
||||
sorter := newSorter(n, sink, nil, nil, false)
|
||||
for i := range list {
|
||||
raw := bucket.Get(list[i])
|
||||
if raw == nil {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
_, err = sorter.filter(sink, nil, bucket, list[i], raw)
|
||||
if err != nil {
|
||||
if _, err := sorter.filter(bucket, list[i], raw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return sink.flush()
|
||||
return sorter.flush()
|
||||
}
|
||||
|
||||
// AllByIndex gets all the records of a bucket that are indexed in the specified index
|
||||
@@ -381,8 +378,6 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
sorter := newSorter(n)
|
||||
|
||||
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -394,20 +389,19 @@ func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig
|
||||
}
|
||||
|
||||
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
|
||||
|
||||
sorter := newSorter(n, sink, nil, nil, false)
|
||||
for i := range list {
|
||||
raw := bucket.Get(list[i])
|
||||
if raw == nil {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
_, err = sorter.filter(sink, nil, bucket, list[i], raw)
|
||||
if err != nil {
|
||||
if _, err := sorter.filter(bucket, list[i], raw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return sink.flush()
|
||||
return sorter.flush()
|
||||
}
|
||||
|
||||
// Count counts all the records of a bucket
|
||||
|
27
query.go
27
query.go
@@ -20,8 +20,8 @@ type Query interface {
|
||||
// Limit the results by the given number
|
||||
Limit(int) Query
|
||||
|
||||
// Order by the given field.
|
||||
OrderBy(string) Query
|
||||
// Order by the given fields, in descending precedence, left-to-right.
|
||||
OrderBy(...string) Query
|
||||
|
||||
// Reverse the order of the results
|
||||
Reverse() Query
|
||||
@@ -53,11 +53,10 @@ type Query interface {
|
||||
|
||||
func newQuery(n *node, tree q.Matcher) *query {
|
||||
return &query{
|
||||
skip: 0,
|
||||
limit: -1,
|
||||
node: n,
|
||||
tree: tree,
|
||||
sorter: newSorter(n),
|
||||
skip: 0,
|
||||
limit: -1,
|
||||
node: n,
|
||||
tree: tree,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +67,7 @@ type query struct {
|
||||
tree q.Matcher
|
||||
node *node
|
||||
bucket string
|
||||
sorter *sorter
|
||||
orderBy []string
|
||||
}
|
||||
|
||||
func (q *query) Skip(nb int) Query {
|
||||
@@ -81,14 +80,13 @@ func (q *query) Limit(nb int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) OrderBy(field string) Query {
|
||||
q.sorter.orderBy = field
|
||||
func (q *query) OrderBy(field ...string) Query {
|
||||
q.orderBy = field
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) Reverse() Query {
|
||||
q.reverse = true
|
||||
q.sorter.reverse = true
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -206,9 +204,10 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
|
||||
bucket := q.node.GetBucket(tx, bucketName)
|
||||
|
||||
if q.limit == 0 {
|
||||
return q.sorter.flush(sink)
|
||||
return sink.flush()
|
||||
}
|
||||
|
||||
sorter := newSorter(q.node, sink, q.tree, q.orderBy, q.reverse)
|
||||
if bucket != nil {
|
||||
c := internal.Cursor{C: bucket.Cursor(), Reverse: q.reverse}
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
@@ -216,7 +215,7 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
|
||||
continue
|
||||
}
|
||||
|
||||
stop, err := q.sorter.filter(sink, q.tree, bucket, k, v)
|
||||
stop, err := sorter.filter(bucket, k, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -227,5 +226,5 @@ func (q *query) query(tx *bolt.Tx, sink sink) error {
|
||||
}
|
||||
}
|
||||
|
||||
return q.sorter.flush(sink)
|
||||
return sorter.flush()
|
||||
}
|
||||
|
@@ -213,9 +213,9 @@ func TestSelectFindOrderBy(t *testing.T) {
|
||||
Rnd int
|
||||
}
|
||||
|
||||
strs := []string{"e", "b", "a", "c", "d"}
|
||||
ints := []int{2, 3, 1, 4, 5}
|
||||
for i := 0; i < 5; i++ {
|
||||
strs := []string{"e", "b", "d", "a", "c", "d"}
|
||||
ints := []int{2, 3, 5, 4, 2, 1}
|
||||
for i := 0; i < 6; i++ {
|
||||
record := T{
|
||||
Str: strs[i],
|
||||
Int: ints[i],
|
||||
@@ -231,39 +231,52 @@ func TestSelectFindOrderBy(t *testing.T) {
|
||||
var list []T
|
||||
err := db.Select().OrderBy("ID").Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Len(t, list, 6)
|
||||
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
|
||||
if i == 2 {
|
||||
j--
|
||||
}
|
||||
assert.Equal(t, i+1, list[i].ID)
|
||||
}
|
||||
|
||||
err = db.Select().OrderBy("Str").Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Equal(t, string([]byte{'a' + byte(i)}), list[i].Str)
|
||||
assert.Len(t, list, 6)
|
||||
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
|
||||
if i == 4 {
|
||||
j--
|
||||
}
|
||||
assert.Equal(t, string([]byte{'a' + byte(j)}), list[i].Str)
|
||||
}
|
||||
|
||||
err = db.Select().OrderBy("Int").Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Equal(t, i+1, list[i].Int)
|
||||
assert.Len(t, list, 6)
|
||||
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
|
||||
if i == 2 {
|
||||
j--
|
||||
}
|
||||
assert.Equal(t, j+1, list[i].Int)
|
||||
}
|
||||
|
||||
err = db.Select().OrderBy("Rnd").Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 5)
|
||||
assert.Len(t, list, 6)
|
||||
assert.Equal(t, 1, list[0].ID)
|
||||
assert.Equal(t, 2, list[1].ID)
|
||||
assert.Equal(t, 3, list[2].ID)
|
||||
assert.Equal(t, 5, list[3].ID)
|
||||
assert.Equal(t, 4, list[4].ID)
|
||||
assert.Equal(t, 6, list[4].ID)
|
||||
assert.Equal(t, 4, list[5].ID)
|
||||
|
||||
err = db.Select().OrderBy("Int").Reverse().Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Equal(t, 5-i, list[i].Int)
|
||||
assert.Len(t, list, 6)
|
||||
for i, j := 0, 0; i < 6; i, j = i+1, j+1 {
|
||||
if i == 4 {
|
||||
j--
|
||||
}
|
||||
assert.Equal(t, 5-j, list[i].Int)
|
||||
}
|
||||
|
||||
err = db.Select().OrderBy("Int").Reverse().Limit(2).Find(&list)
|
||||
@@ -275,15 +288,34 @@ func TestSelectFindOrderBy(t *testing.T) {
|
||||
|
||||
err = db.Select().OrderBy("Int").Reverse().Skip(2).Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 3)
|
||||
for i := 0; i < 2; i++ {
|
||||
assert.Equal(t, 3-i, list[i].Int)
|
||||
assert.Len(t, list, 4)
|
||||
for i, j := 0, 0; i < 3; i, j = i+1, j+1 {
|
||||
if i == 2 {
|
||||
j--
|
||||
}
|
||||
assert.Equal(t, 3-j, list[i].Int)
|
||||
}
|
||||
|
||||
err = db.Select().OrderBy("Int").Reverse().Skip(4).Limit(2).Find(&list)
|
||||
err = db.Select().OrderBy("Int").Reverse().Skip(5).Limit(2).Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 1)
|
||||
assert.Equal(t, 1, list[0].Int)
|
||||
|
||||
err = db.Select().OrderBy("Str", "Int").Find(&list)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, list, 6)
|
||||
assert.Equal(t, "a", list[0].Str)
|
||||
assert.Equal(t, 4, list[0].Int)
|
||||
assert.Equal(t, "b", list[1].Str)
|
||||
assert.Equal(t, 3, list[1].Int)
|
||||
assert.Equal(t, "c", list[2].Str)
|
||||
assert.Equal(t, 2, list[2].Int)
|
||||
assert.Equal(t, "d", list[3].Str)
|
||||
assert.Equal(t, 1, list[3].Int)
|
||||
assert.Equal(t, "d", list[4].Str)
|
||||
assert.Equal(t, 5, list[4].Int)
|
||||
assert.Equal(t, "e", list[5].Str)
|
||||
assert.Equal(t, 2, list[5].Int)
|
||||
}
|
||||
|
||||
func TestSelectFirst(t *testing.T) {
|
||||
|
253
sink.go
253
sink.go
@@ -1,14 +1,12 @@
|
||||
package storm
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"github.com/asdine/storm/index"
|
||||
"github.com/asdine/storm/q"
|
||||
"github.com/boltdb/bolt"
|
||||
|
||||
rbt "github.com/emirpasic/gods/trees/redblacktree"
|
||||
)
|
||||
|
||||
type item struct {
|
||||
@@ -18,97 +16,190 @@ type item struct {
|
||||
v []byte
|
||||
}
|
||||
|
||||
func newSorter(node Node) *sorter {
|
||||
func newSorter(n Node, snk sink, tree q.Matcher, orderBy []string, reverse bool) *sorter {
|
||||
return &sorter{
|
||||
node: node,
|
||||
rbTree: rbt.NewWithStringComparator(),
|
||||
node: n,
|
||||
sink: snk,
|
||||
tree: tree,
|
||||
orderBy: orderBy,
|
||||
reverse: reverse,
|
||||
list: make([]*item, 0),
|
||||
err: make(chan error),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type sorter struct {
|
||||
node Node
|
||||
rbTree *rbt.Tree
|
||||
orderBy string
|
||||
sink sink
|
||||
tree q.Matcher
|
||||
list []*item
|
||||
orderBy []string
|
||||
reverse bool
|
||||
counter int64
|
||||
err chan error
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (s *sorter) filter(snk sink, tree q.Matcher, bucket *bolt.Bucket, k, v []byte) (bool, error) {
|
||||
s.counter++
|
||||
|
||||
rsnk, ok := snk.(reflectSink)
|
||||
func (s *sorter) filter(bucket *bolt.Bucket, k, v []byte) (bool, error) {
|
||||
rsink, ok := s.sink.(reflectSink)
|
||||
if !ok {
|
||||
return snk.add(&item{
|
||||
return s.sink.add(&item{
|
||||
bucket: bucket,
|
||||
k: k,
|
||||
v: v,
|
||||
})
|
||||
}
|
||||
|
||||
newElem := rsnk.elem()
|
||||
err := s.node.Codec().Unmarshal(v, newElem.Interface())
|
||||
if err != nil {
|
||||
newElem := rsink.elem()
|
||||
if err := s.node.Codec().Unmarshal(v, newElem.Interface()); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ok = tree == nil
|
||||
if !ok {
|
||||
ok, err = tree.Match(newElem.Interface())
|
||||
itm := &item{
|
||||
bucket: bucket,
|
||||
value: &newElem,
|
||||
k: k,
|
||||
v: v,
|
||||
}
|
||||
|
||||
if s.tree == nil {
|
||||
if len(s.orderBy) == 0 {
|
||||
return s.sink.add(itm)
|
||||
}
|
||||
} else {
|
||||
ok, err := s.tree.Match(newElem.Interface())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
it := item{
|
||||
bucket: bucket,
|
||||
value: &newElem,
|
||||
k: k,
|
||||
v: v,
|
||||
}
|
||||
|
||||
if s.orderBy != "" {
|
||||
elm := reflect.Indirect(newElem).FieldByName(s.orderBy)
|
||||
if !elm.IsValid() {
|
||||
return false, ErrNotFound
|
||||
}
|
||||
raw, err := toBytes(elm.Interface(), s.node.Codec())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
key := make([]byte, len(raw)+8)
|
||||
for i := 0; i < len(raw); i++ {
|
||||
key[i] = raw[i]
|
||||
}
|
||||
binary.PutVarint(key[len(raw):], s.counter)
|
||||
s.rbTree.Put(string(key), &it)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return snk.add(&it)
|
||||
}
|
||||
|
||||
if len(s.orderBy) == 0 {
|
||||
return s.sink.add(itm)
|
||||
}
|
||||
|
||||
s.list = append(s.list, itm)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *sorter) flush(snk sink) error {
|
||||
if s.orderBy == "" {
|
||||
return snk.flush()
|
||||
func (s *sorter) compareValue(left reflect.Value, right reflect.Value) int {
|
||||
if !left.IsValid() || !right.IsValid() {
|
||||
if left.IsValid() {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
s.orderBy = ""
|
||||
var err error
|
||||
var stop bool
|
||||
|
||||
it := s.rbTree.Iterator()
|
||||
if s.reverse {
|
||||
it.End()
|
||||
} else {
|
||||
it.Begin()
|
||||
switch left.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
l, r := left.Int(), right.Int()
|
||||
if l < r {
|
||||
return -1
|
||||
}
|
||||
if l > r {
|
||||
return 1
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
l, r := left.Uint(), right.Uint()
|
||||
if l < r {
|
||||
return -1
|
||||
}
|
||||
if l > r {
|
||||
return 1
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
l, r := left.Float(), right.Float()
|
||||
if l < r {
|
||||
return -1
|
||||
}
|
||||
if l > r {
|
||||
return 1
|
||||
}
|
||||
case reflect.String:
|
||||
l, r := left.String(), right.String()
|
||||
if l < r {
|
||||
return -1
|
||||
}
|
||||
if l > r {
|
||||
return 1
|
||||
}
|
||||
default:
|
||||
rawLeft, err := toBytes(left.Interface(), s.node.Codec())
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
rawRight, err := toBytes(right.Interface(), s.node.Codec())
|
||||
if err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
l, r := string(rawLeft), string(rawRight)
|
||||
if l < r {
|
||||
return -1
|
||||
}
|
||||
if l > r {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
for (s.reverse && it.Prev()) || (!s.reverse && it.Next()) {
|
||||
item := it.Value().(*item)
|
||||
stop, err = snk.add(item)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *sorter) less(leftElem reflect.Value, rightElem reflect.Value) bool {
|
||||
for _, orderBy := range s.orderBy {
|
||||
leftField := reflect.Indirect(leftElem).FieldByName(orderBy)
|
||||
if !leftField.IsValid() {
|
||||
s.err <- ErrNotFound
|
||||
return false
|
||||
}
|
||||
rightField := reflect.Indirect(rightElem).FieldByName(orderBy)
|
||||
if !rightField.IsValid() {
|
||||
s.err <- ErrNotFound
|
||||
return false
|
||||
}
|
||||
|
||||
direction := 1
|
||||
if s.reverse {
|
||||
direction = -1
|
||||
}
|
||||
|
||||
switch s.compareValue(leftField, rightField) * direction {
|
||||
case -1:
|
||||
return true
|
||||
case 1:
|
||||
return false
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *sorter) flush() error {
|
||||
if len(s.orderBy) == 0 {
|
||||
return s.sink.flush()
|
||||
}
|
||||
|
||||
go func() {
|
||||
sort.Sort(s)
|
||||
close(s.err)
|
||||
}()
|
||||
err := <-s.err
|
||||
close(s.done)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, itm := range s.list {
|
||||
if itm == nil {
|
||||
break
|
||||
}
|
||||
stop, err := s.sink.add(itm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +208,38 @@ func (s *sorter) flush(snk sink) error {
|
||||
}
|
||||
}
|
||||
|
||||
return snk.flush()
|
||||
return s.sink.flush()
|
||||
}
|
||||
|
||||
func (s *sorter) Len() int {
|
||||
// skip if we encountered an earlier error
|
||||
select {
|
||||
case <-s.done:
|
||||
return 0
|
||||
default:
|
||||
return len(s.list)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sorter) Swap(i, j int) {
|
||||
// skip if we encountered an earlier error
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
s.list[i], s.list[j] = s.list[j], s.list[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sorter) Less(i, j int) bool {
|
||||
// skip if we encountered an earlier error
|
||||
select {
|
||||
case <-s.done:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
return s.less(*s.list[i].value, *s.list[j].value)
|
||||
}
|
||||
|
||||
type sink interface {
|
||||
@@ -156,6 +278,7 @@ func newListSink(node Node, to interface{}) (*listSink, error) {
|
||||
elemType: elemType,
|
||||
name: elemType.Name(),
|
||||
limit: -1,
|
||||
results: reflect.MakeSlice(reflect.Indirect(ref).Type(), 0, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -192,10 +315,6 @@ func (l *listSink) add(i *item) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if !l.results.IsValid() {
|
||||
l.results = reflect.MakeSlice(reflect.Indirect(l.ref).Type(), 0, 0)
|
||||
}
|
||||
|
||||
if l.limit > 0 {
|
||||
l.limit--
|
||||
}
|
||||
|
Reference in New Issue
Block a user