mirror of
https://github.com/timshannon/bolthold.git
synced 2025-12-24 13:37:58 +08:00
600 lines
12 KiB
Go
600 lines
12 KiB
Go
// Copyright 2016 Tim Shannon. All rights reserved.
|
|
// Use of this source code is governed by the MIT license
|
|
// that can be found in the LICENSE file.
|
|
|
|
package bolthold
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
// record is a single query result: the key as read from the bucket and the
// decoded value stored under it.
type record struct {
	// key is the raw, still-encoded key the record is stored under.
	key []byte
	// value is the decoded record; runQuery creates it with reflect.New, so it
	// holds a pointer to the stored type.
	value reflect.Value
}
|
|
|
|
func (s *Store) runQuery(source BucketSource, dataType interface{}, query *Query, retrievedKeys keyList, skip int,
|
|
action func(r *record) error) error {
|
|
storer := s.newStorer(dataType)
|
|
|
|
bkt := source.Bucket([]byte(storer.Type()))
|
|
if bkt == nil {
|
|
// if the bucket doesn't exist or is empty then our job is really easy!
|
|
return nil
|
|
}
|
|
|
|
if query.index != "" && source.Bucket(indexBucketName(storer.Type(), query.index)) == nil {
|
|
return fmt.Errorf("The index %s does not exist", query.index)
|
|
}
|
|
|
|
tp := dataType
|
|
|
|
for reflect.TypeOf(tp).Kind() == reflect.Ptr {
|
|
tp = reflect.ValueOf(tp).Elem().Interface()
|
|
}
|
|
|
|
query.dataType = reflect.TypeOf(tp)
|
|
|
|
if len(query.sort) > 0 {
|
|
return s.runQuerySort(source, dataType, query, action)
|
|
}
|
|
|
|
iter := s.newIterator(source, storer.Type(), query)
|
|
|
|
newKeys := make(keyList, 0)
|
|
|
|
limit := query.limit - len(retrievedKeys)
|
|
|
|
for k, v := iter.Next(); k != nil; k, v = iter.Next() {
|
|
if len(retrievedKeys) != 0 {
|
|
// don't check this record if it's already been retrieved
|
|
if retrievedKeys.in(k) {
|
|
continue
|
|
}
|
|
}
|
|
|
|
val := reflect.New(reflect.TypeOf(tp))
|
|
|
|
err := s.decode(v, val.Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
query.source = source
|
|
|
|
ok, err := query.matchesAllFields(s, k, val, val.Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if ok {
|
|
if skip > 0 {
|
|
skip--
|
|
continue
|
|
}
|
|
|
|
err = action(&record{
|
|
key: k,
|
|
value: val,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// track that this key's entry has been added to the result list
|
|
newKeys.add(k)
|
|
|
|
if query.limit != 0 {
|
|
limit--
|
|
if limit == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
if iter.Error() != nil {
|
|
return iter.Error()
|
|
}
|
|
|
|
if query.limit != 0 && limit == 0 {
|
|
return nil
|
|
}
|
|
|
|
if len(query.ors) > 0 {
|
|
for i := range newKeys {
|
|
retrievedKeys.add(newKeys[i])
|
|
}
|
|
|
|
for i := range query.ors {
|
|
err := s.runQuery(source, tp, query.ors[i], retrievedKeys, skip, action)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// runQuerySort runs the query without sort, skip, or limit, then applies them to the entire result set
|
|
func (s *Store) runQuerySort(source BucketSource, dataType interface{}, query *Query, action func(r *record) error) error {
|
|
// Validate sort fields
|
|
for _, field := range query.sort {
|
|
fields := strings.Split(field, ".")
|
|
|
|
current := query.dataType
|
|
for i := range fields {
|
|
var structField reflect.StructField
|
|
found := false
|
|
if current.Kind() == reflect.Ptr {
|
|
structField, found = current.Elem().FieldByName(fields[i])
|
|
} else {
|
|
structField, found = current.FieldByName(fields[i])
|
|
}
|
|
|
|
if !found {
|
|
return fmt.Errorf("The field %s does not exist in the type %s", field, query.dataType)
|
|
}
|
|
current = structField.Type
|
|
}
|
|
}
|
|
|
|
// Run query without sort, skip or limit
|
|
// apply sort, skip and limit to entire dataset
|
|
qCopy := *query
|
|
qCopy.sort = nil
|
|
qCopy.limit = 0
|
|
qCopy.skip = 0
|
|
|
|
var records []*record
|
|
err := s.runQuery(source, dataType, &qCopy, nil, 0,
|
|
func(r *record) error {
|
|
records = append(records, r)
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sort.Slice(records, func(i, j int) bool {
|
|
for _, field := range query.sort {
|
|
value, err := fieldValue(records[i].value.Elem(), field)
|
|
if err != nil {
|
|
panic(err.Error()) // shouldn't happen due to field check above
|
|
}
|
|
|
|
other, err := fieldValue(records[j].value.Elem(), field)
|
|
if err != nil {
|
|
panic(err.Error()) // shouldn't happen due to field check above
|
|
}
|
|
|
|
if query.reverse {
|
|
value, other = other, value
|
|
}
|
|
|
|
cmp, cerr := compare(value, other)
|
|
if cerr != nil {
|
|
// if for some reason there is an error on compare, fallback to a lexicographic compare
|
|
valS := fmt.Sprintf("%s", value)
|
|
otherS := fmt.Sprintf("%s", other)
|
|
if valS < otherS {
|
|
return true
|
|
} else if valS == otherS {
|
|
continue
|
|
}
|
|
return false
|
|
}
|
|
|
|
if cmp == -1 {
|
|
return true
|
|
} else if cmp == 0 {
|
|
continue
|
|
}
|
|
return false
|
|
}
|
|
return false
|
|
})
|
|
|
|
// apply skip and limit
|
|
limit := query.limit
|
|
skip := query.skip
|
|
|
|
if skip > len(records) {
|
|
records = records[0:0]
|
|
} else {
|
|
records = records[skip:]
|
|
}
|
|
|
|
if limit > 0 && limit <= len(records) {
|
|
records = records[:limit]
|
|
}
|
|
|
|
for i := range records {
|
|
err = action(records[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
func (s *Store) findQuery(source BucketSource, result interface{}, query *Query) error {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
resultVal := reflect.ValueOf(result)
|
|
if resultVal.Kind() != reflect.Ptr || resultVal.Elem().Kind() != reflect.Slice {
|
|
panic("result argument must be a slice address")
|
|
}
|
|
|
|
sliceVal := resultVal.Elem()
|
|
|
|
elType := sliceVal.Type().Elem()
|
|
|
|
tp := elType
|
|
|
|
for tp.Kind() == reflect.Ptr {
|
|
tp = tp.Elem()
|
|
}
|
|
|
|
var keyType reflect.Type
|
|
var keyField string
|
|
|
|
for i := 0; i < tp.NumField(); i++ {
|
|
if strings.Contains(string(tp.Field(i).Tag), BoltholdKeyTag) {
|
|
keyType = tp.Field(i).Type
|
|
keyField = tp.Field(i).Name
|
|
break
|
|
}
|
|
}
|
|
|
|
val := reflect.New(tp)
|
|
|
|
err := s.runQuery(source, val.Interface(), query, nil, query.skip,
|
|
func(r *record) error {
|
|
var rowValue reflect.Value
|
|
|
|
// FIXME:
|
|
if elType.Kind() == reflect.Ptr {
|
|
rowValue = r.value
|
|
} else {
|
|
rowValue = r.value.Elem()
|
|
}
|
|
|
|
if keyType != nil {
|
|
rowKey := rowValue
|
|
for rowKey.Kind() == reflect.Ptr {
|
|
rowKey = rowKey.Elem()
|
|
}
|
|
err := s.decode(r.key, rowKey.FieldByName(keyField).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
sliceVal = reflect.Append(sliceVal, rowValue)
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
resultVal.Elem().Set(sliceVal.Slice(0, sliceVal.Len()))
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) deleteQuery(source BucketSource, dataType interface{}, query *Query) error {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
var records []*record
|
|
|
|
err := s.runQuery(source, dataType, query, nil, query.skip,
|
|
func(r *record) error {
|
|
records = append(records, r)
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
storer := s.newStorer(dataType)
|
|
|
|
b := source.Bucket([]byte(storer.Type()))
|
|
for i := range records {
|
|
err := b.Delete(records[i].key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// remove any indexes
|
|
err = s.deleteIndexes(storer, source, records[i].key, records[i].value.Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) updateQuery(source BucketSource, dataType interface{}, query *Query, update func(record interface{}) error) error {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
var records []*record
|
|
|
|
err := s.runQuery(source, dataType, query, nil, query.skip,
|
|
func(r *record) error {
|
|
records = append(records, r)
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
storer := s.newStorer(dataType)
|
|
b := source.Bucket([]byte(storer.Type()))
|
|
|
|
for i := range records {
|
|
upVal := records[i].value.Interface()
|
|
|
|
// delete any existing indexes bad on original value
|
|
err := s.deleteIndexes(storer, source, records[i].key, upVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = update(upVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
encVal, err := s.encode(upVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = b.Put(records[i].key, encVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// insert any new indexes
|
|
err = s.addIndexes(storer, source, records[i].key, upVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) aggregateQuery(source BucketSource, dataType interface{}, query *Query,
|
|
groupBy ...string) ([]*AggregateResult, error) {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
var result []*AggregateResult
|
|
|
|
if len(groupBy) == 0 {
|
|
result = append(result, &AggregateResult{})
|
|
}
|
|
|
|
err := s.runQuery(source, dataType, query, nil, query.skip,
|
|
func(r *record) error {
|
|
if len(groupBy) == 0 {
|
|
result[0].reduction = append(result[0].reduction, r.value)
|
|
return nil
|
|
}
|
|
|
|
grouping := make([]reflect.Value, len(groupBy))
|
|
|
|
for i := range groupBy {
|
|
fVal := r.value.Elem().FieldByName(groupBy[i])
|
|
if !fVal.IsValid() {
|
|
return fmt.Errorf("The field %s does not exist in the type %s", groupBy[i], r.value.Type())
|
|
}
|
|
|
|
grouping[i] = fVal
|
|
}
|
|
|
|
var err error
|
|
var c int
|
|
var allEqual bool
|
|
|
|
i := sort.Search(len(result), func(i int) bool {
|
|
for j := range grouping {
|
|
c, err = compare(result[i].group[j].Interface(), grouping[j].Interface())
|
|
if err != nil {
|
|
return true
|
|
}
|
|
if c != 0 {
|
|
return c >= 0
|
|
}
|
|
// if group part is equal, compare the next group part
|
|
}
|
|
allEqual = true
|
|
return true
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if i < len(result) {
|
|
if allEqual {
|
|
// group already exists, append results to reduction
|
|
result[i].reduction = append(result[i].reduction, r.value)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// group not found, create another grouping at i
|
|
result = append(result, nil)
|
|
copy(result[i+1:], result[i:])
|
|
result[i] = &AggregateResult{
|
|
group: grouping,
|
|
reduction: []reflect.Value{r.value},
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (s *Store) countQuery(source BucketSource, dataType interface{}, query *Query) (int, error) {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
count := 0
|
|
|
|
err := s.runQuery(source, dataType, query, nil, query.skip,
|
|
func(r *record) error {
|
|
count++
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (s *Store) findOneQuery(source BucketSource, result interface{}, query *Query) error {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
originalLimit := query.limit
|
|
|
|
query.limit = 1
|
|
|
|
resultVal := reflect.ValueOf(result)
|
|
if resultVal.Kind() != reflect.Ptr {
|
|
panic("result argument must be an address")
|
|
}
|
|
|
|
structType := resultVal.Elem().Type()
|
|
|
|
var keyType reflect.Type
|
|
var keyField string
|
|
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
if strings.Contains(string(structType.Field(i).Tag), BoltholdKeyTag) {
|
|
keyType = structType.Field(i).Type
|
|
keyField = structType.Field(i).Name
|
|
break
|
|
}
|
|
}
|
|
|
|
found := false
|
|
|
|
err := s.runQuery(source, result, query, nil, query.skip,
|
|
func(r *record) error {
|
|
found = true
|
|
|
|
if keyType != nil {
|
|
rowKey := r.value
|
|
for rowKey.Kind() == reflect.Ptr {
|
|
rowKey = rowKey.Elem()
|
|
}
|
|
err := s.decode(r.key, rowKey.FieldByName(keyField).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
resultVal.Elem().Set(r.value.Elem())
|
|
|
|
return nil
|
|
})
|
|
query.limit = originalLimit
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !found {
|
|
return ErrNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) forEach(source BucketSource, query *Query, fn interface{}) error {
|
|
if query == nil {
|
|
query = &Query{}
|
|
}
|
|
|
|
fnVal := reflect.ValueOf(fn)
|
|
argType := reflect.TypeOf(fn).In(0)
|
|
|
|
if argType.Kind() == reflect.Ptr {
|
|
argType = argType.Elem()
|
|
}
|
|
|
|
dataType := reflect.New(argType).Interface()
|
|
|
|
var keyType reflect.Type
|
|
var keyField string
|
|
|
|
for i := 0; i < argType.NumField(); i++ {
|
|
if strings.Contains(string(argType.Field(i).Tag), BoltholdKeyTag) {
|
|
keyType = argType.Field(i).Type
|
|
keyField = argType.Field(i).Name
|
|
break
|
|
}
|
|
}
|
|
|
|
return s.runQuery(source, dataType, query, nil, query.skip, func(r *record) error {
|
|
if keyType != nil {
|
|
rowKey := r.value
|
|
for rowKey.Kind() == reflect.Ptr {
|
|
rowKey = rowKey.Elem()
|
|
}
|
|
err := s.decode(r.key, rowKey.FieldByName(keyField).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
out := fnVal.Call([]reflect.Value{r.value})
|
|
if len(out) != 1 {
|
|
return fmt.Errorf("foreach function does not return an error")
|
|
}
|
|
|
|
if out[0].IsNil() {
|
|
return nil
|
|
}
|
|
|
|
return out[0].Interface().(error)
|
|
})
|
|
}
|