Files
chaisql/query/insert.go
2019-12-15 14:31:04 +01:00

169 lines
3.4 KiB
Go

package query
import (
"database/sql/driver"
"errors"
"fmt"
"github.com/asdine/genji/database"
"github.com/asdine/genji/document"
)
// InsertStmt is a DSL that allows creating a full Insert query.
type InsertStmt struct {
// TableName is the name of the table the documents are inserted into.
// Run rejects an empty name.
TableName string
// FieldNames optionally lists the fields to populate. When it is non-empty,
// every entry of Values must evaluate to a list holding one value per field
// name; when it is empty, every entry of Values must be a document.
FieldNames []string
// Values holds one expression per document to insert. Run rejects a nil list.
Values LiteralExprList
}
// IsReadOnly reports whether the statement is read-only, which is never the
// case for an insert: the result is always false. It implements the Statement
// interface.
func (InsertStmt) IsReadOnly() bool {
	const readOnly = false
	return readOnly
}
// Run executes the insert statement inside tx, using args as the query
// parameters.
//
// It fails if the table name is empty, if Values is nil, or if the table
// cannot be fetched from the transaction. When FieldNames is set, each entry
// of Values is treated as a list of values matched to the field names;
// otherwise each entry is treated as a complete document.
func (stmt InsertStmt) Run(tx *database.Transaction, args []driver.NamedValue) (Result, error) {
	var empty Result

	// Validate the statement before touching the transaction.
	switch {
	case stmt.TableName == "":
		return empty, errors.New("missing table name")
	case stmt.Values == nil:
		return empty, errors.New("values are empty")
	}

	table, err := tx.GetTable(stmt.TableName)
	if err != nil {
		return empty, err
	}

	evalStack := EvalStack{Tx: tx, Params: args}

	// Without an explicit field list, every value must itself be a document.
	if len(stmt.FieldNames) == 0 {
		return stmt.insertDocuments(table, evalStack)
	}

	return stmt.insertExprList(table, evalStack)
}
// paramExtractor is implemented by expressions whose value is taken from the
// parameters passed to the statement. insertDocuments uses it to resolve a
// parameter and expects the extracted value to be a document.Document.
type paramExtractor interface {
// Extract returns the value selected from params, or an error if it
// cannot be resolved.
Extract(params []driver.NamedValue) (interface{}, error)
}
// insertDocuments inserts every entry of stmt.Values into t as a whole
// document. It is used when the statement has no field list.
//
// Each entry must be one of:
//   - a document.Document, inserted as is;
//   - a paramExtractor whose extracted parameter is a document.Document;
//   - a LiteralValue of type document.DocumentValue;
//   - a KVPairs expression, evaluated against stack and decoded to a document.
//
// Any other entry is rejected. Insertion stops at the first error; the Result
// returned alongside the error reflects the documents inserted so far.
func (stmt InsertStmt) insertDocuments(t *database.Table, stack EvalStack) (Result, error) {
	var res Result

	for _, rec := range stmt.Values {
		var (
			d   document.Document
			err error
		)

		// The order of the cases matters if a type satisfies several of them.
		switch tp := rec.(type) {
		case document.Document:
			d = tp
		case paramExtractor:
			param, extractErr := tp.Extract(stack.Params)
			if extractErr != nil {
				return res, extractErr
			}

			doc, ok := param.(document.Document)
			if !ok {
				// %T reports the dynamic type of the offending parameter.
				return res, fmt.Errorf("unsupported parameter of type %T, expecting document.Document", param)
			}
			d = doc
		case LiteralValue:
			if tp.Value.Type != document.DocumentValue {
				return res, fmt.Errorf("values must be a list of documents if field list is empty")
			}

			d, err = tp.Value.DecodeToDocument()
			if err != nil {
				return res, err
			}
		case KVPairs:
			evaluated, evalErr := tp.Eval(stack)
			if evalErr != nil {
				return res, evalErr
			}

			d, err = evaluated.Value.Value.DecodeToDocument()
			if err != nil {
				return res, err
			}
		default:
			return res, fmt.Errorf("values must be a list of documents if field list is empty")
		}

		res.lastInsertKey, err = t.Insert(d)
		if err != nil {
			return res, err
		}

		res.rowsAffected++
	}

	return res, nil
}
// insertExprList inserts one document per entry of stmt.Values, pairing the
// values of each entry with stmt.FieldNames by position.
//
// Every entry must evaluate to a list holding exactly len(stmt.FieldNames)
// items. Each item must be either a plain value or a list wrapping exactly one
// plain value; any other list is rejected. Insertion stops at the first
// error; the Result returned alongside the error reflects the documents
// inserted so far.
func (stmt InsertStmt) insertExprList(t *database.Table, stack EvalStack) (Result, error) {
	var res Result

	// iterate over all of the documents (r1, r2, r3, ...)
	for _, e := range stmt.Values {
		var fb document.FieldBuffer

		v, err := e.Eval(stack)
		if err != nil {
			return res, err
		}

		// each document must be a list of values
		// (e1, e2, e3, ...)
		if !v.IsList {
			return res, errors.New("invalid values")
		}

		if len(stmt.FieldNames) != len(v.List) {
			return res, fmt.Errorf("%d values for %d fields", len(v.List), len(stmt.FieldNames))
		}

		// iterate over each value
		for i, item := range v.List {
			// each value must be either a LiteralValue or a list containing
			// exactly one LiteralValue
			var lv *LiteralValue
			switch {
			case !item.IsList:
				lv = &item.Value
			case len(item.List) == 1 && !item.List[0].IsList:
				// unwrap the single value instead of rejecting the list
				inner := item.List[0]
				lv = &inner.Value
			default:
				return res, errors.New("value expected, got list")
			}

			// assign the value to the matching field and add it to the document
			fb.Add(stmt.FieldNames[i], document.Value{
				Type: lv.Type,
				Data: lv.Data,
			})
		}

		res.lastInsertKey, err = t.Insert(&fb)
		if err != nil {
			return res, err
		}

		res.rowsAffected++
	}

	return res, nil
}