mirror of
https://github.com/chaisql/chai.git
synced 2025-10-05 07:36:56 +08:00
169 lines
3.4 KiB
Go
169 lines
3.4 KiB
Go
package query
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/asdine/genji/database"
|
|
"github.com/asdine/genji/document"
|
|
)
|
|
|
|
// InsertStmt is a DSL that allows creating a full Insert query.
|
|
type InsertStmt struct {
|
|
TableName string
|
|
FieldNames []string
|
|
Values LiteralExprList
|
|
}
|
|
|
|
// IsReadOnly always returns false. It implements the Statement interface.
|
|
func (stmt InsertStmt) IsReadOnly() bool {
|
|
return false
|
|
}
|
|
|
|
func (stmt InsertStmt) Run(tx *database.Transaction, args []driver.NamedValue) (Result, error) {
|
|
var res Result
|
|
|
|
if stmt.TableName == "" {
|
|
return res, errors.New("missing table name")
|
|
}
|
|
|
|
if stmt.Values == nil {
|
|
return res, errors.New("values are empty")
|
|
}
|
|
|
|
t, err := tx.GetTable(stmt.TableName)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
stack := EvalStack{
|
|
Tx: tx,
|
|
Params: args,
|
|
}
|
|
|
|
if len(stmt.FieldNames) > 0 {
|
|
return stmt.insertExprList(t, stack)
|
|
}
|
|
|
|
return stmt.insertDocuments(t, stack)
|
|
}
|
|
|
|
// paramExtractor is implemented by expressions whose value is obtained from
// the parameters passed to the statement. insertDocuments relies on it to
// resolve such an expression into a document.
type paramExtractor interface {
	// Extract returns the value this expression resolves to given args,
	// or an error if it cannot be resolved.
	Extract(args []driver.NamedValue) (interface{}, error)
}
|
|
|
|
func (stmt InsertStmt) insertDocuments(t *database.Table, stack EvalStack) (Result, error) {
|
|
var res Result
|
|
var err error
|
|
|
|
for _, rec := range stmt.Values {
|
|
var d document.Document
|
|
|
|
switch tp := rec.(type) {
|
|
case document.Document:
|
|
d = tp
|
|
case paramExtractor:
|
|
v, err := tp.Extract(stack.Params)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
var ok bool
|
|
d, ok = v.(document.Document)
|
|
if !ok {
|
|
return res, fmt.Errorf("unsupported parameter of type %t, expecting document.Document", v)
|
|
}
|
|
case LiteralValue:
|
|
if tp.Value.Type != document.DocumentValue {
|
|
return res, fmt.Errorf("values must be a list of documents if field list is empty")
|
|
}
|
|
|
|
d, err = tp.Value.DecodeToDocument()
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
case KVPairs:
|
|
v, err := tp.Eval(stack)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
d, err = v.Value.Value.DecodeToDocument()
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
default:
|
|
return res, fmt.Errorf("values must be a list of documents if field list is empty")
|
|
}
|
|
|
|
res.lastInsertKey, err = t.Insert(d)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
res.rowsAffected++
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (stmt InsertStmt) insertExprList(t *database.Table, stack EvalStack) (Result, error) {
|
|
var res Result
|
|
|
|
// iterate over all of the documents (r1, r2, r3, ...)
|
|
for _, e := range stmt.Values {
|
|
var fb document.FieldBuffer
|
|
|
|
v, err := e.Eval(stack)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
// each record must be a list of values
|
|
// (e1, e2, e3, ...)
|
|
if !v.IsList {
|
|
return res, errors.New("invalid values")
|
|
}
|
|
|
|
if len(stmt.FieldNames) != len(v.List) {
|
|
return res, fmt.Errorf("%d values for %d fields", len(v.List), len(stmt.FieldNames))
|
|
}
|
|
|
|
// iterate over each value
|
|
for i, v := range v.List {
|
|
// get the field name
|
|
fieldName := stmt.FieldNames[i]
|
|
|
|
var lv *LiteralValue
|
|
|
|
// each value must be either a LitteralValue or a LitteralValueList with exactly
|
|
// one value
|
|
if !v.IsList {
|
|
lv = &v.Value
|
|
} else {
|
|
if len(v.List) == 1 {
|
|
if val := v.List[0]; !val.IsList {
|
|
lv = &val.Value
|
|
}
|
|
}
|
|
return res, fmt.Errorf("value expected, got list")
|
|
}
|
|
|
|
// Assign the value to the field and add it to the record
|
|
fb.Add(fieldName, document.Value{
|
|
Type: lv.Type,
|
|
Data: lv.Data,
|
|
})
|
|
}
|
|
|
|
res.lastInsertKey, err = t.Insert(&fb)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
res.rowsAffected++
|
|
}
|
|
|
|
return res, nil
|
|
}
|