mirror of
https://github.com/chaisql/chai.git
synced 2025-10-12 19:10:08 +08:00
178 lines
3.5 KiB
Go
178 lines
3.5 KiB
Go
package query
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/asdine/genji/database"
|
|
"github.com/asdine/genji/record"
|
|
"github.com/asdine/genji/value"
|
|
)
|
|
|
|
// insertStmt is a DSL that allows creating a full Insert query.
|
|
type insertStmt struct {
|
|
tableName string
|
|
fieldNames []string
|
|
values litteralExprList
|
|
records []interface{}
|
|
}
|
|
|
|
// IsReadOnly always returns false. It implements the Statement interface.
|
|
func (stmt insertStmt) IsReadOnly() bool {
|
|
return false
|
|
}
|
|
|
|
type kvPair struct {
|
|
K string
|
|
V expr
|
|
}
|
|
|
|
func (stmt insertStmt) Run(tx *database.Transaction, args []driver.NamedValue) (Result, error) {
|
|
var res Result
|
|
|
|
if stmt.tableName == "" {
|
|
return res, errors.New("missing table name")
|
|
}
|
|
|
|
if stmt.values == nil && stmt.records == nil {
|
|
return res, errors.New("values and records are empty")
|
|
}
|
|
|
|
t, err := tx.GetTable(stmt.tableName)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
stack := evalStack{
|
|
Tx: tx,
|
|
Params: args,
|
|
}
|
|
|
|
if len(stmt.records) > 0 {
|
|
return stmt.insertRecords(t, stack)
|
|
}
|
|
|
|
return stmt.insertValues(t, stack)
|
|
}
|
|
|
|
// paramExtractor is implemented by values that have to be resolved against
// the statement parameters before they can be used.
type paramExtractor interface {
	// Extract derives a value from params, or returns an error if it cannot.
	Extract(params []driver.NamedValue) (interface{}, error)
}
|
|
|
|
func (stmt insertStmt) insertRecords(t *database.Table, stack evalStack) (Result, error) {
|
|
var res Result
|
|
var err error
|
|
|
|
if len(stmt.fieldNames) > 0 {
|
|
return res, errors.New("can't provide a field list with RECORDS clause")
|
|
}
|
|
|
|
for _, rec := range stmt.records {
|
|
var r record.Record
|
|
|
|
switch tp := rec.(type) {
|
|
case record.Record:
|
|
r = tp
|
|
case paramExtractor:
|
|
v, err := tp.Extract(stack.Params)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
var ok bool
|
|
r, ok = v.(record.Record)
|
|
if !ok {
|
|
return res, fmt.Errorf("unsupported parameter of type %t, expecting record.Record", v)
|
|
}
|
|
case []kvPair:
|
|
var fb record.FieldBuffer
|
|
for _, pair := range tp {
|
|
v, err := pair.V.Eval(stack)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
if v.IsList {
|
|
return res, errors.New("invalid values")
|
|
}
|
|
|
|
fb.Add(record.Field{Name: pair.K, Value: v.Value.Value})
|
|
}
|
|
r = &fb
|
|
}
|
|
|
|
res.lastInsertKey, err = t.Insert(r)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
res.rowsAffected++
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (stmt insertStmt) insertValues(t *database.Table, stack evalStack) (Result, error) {
|
|
var res Result
|
|
|
|
// iterate over all of the records (r1, r2, r3, ...)
|
|
for _, e := range stmt.values {
|
|
var fb record.FieldBuffer
|
|
|
|
v, err := e.Eval(stack)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
// each record must be a list of values
|
|
// (e1, e2, e3, ...)
|
|
if !v.IsList {
|
|
return res, errors.New("invalid values")
|
|
}
|
|
|
|
if len(stmt.fieldNames) != len(v.List) {
|
|
return res, fmt.Errorf("%d values for %d fields", len(v.List), len(stmt.fieldNames))
|
|
}
|
|
|
|
// iterate over each value
|
|
for i, v := range v.List {
|
|
// get the field name
|
|
fieldName := stmt.fieldNames[i]
|
|
|
|
var lv *litteralValue
|
|
|
|
// each value must be either a LitteralValue or a LitteralValueList with exactly
|
|
// one value
|
|
if !v.IsList {
|
|
lv = &v.Value
|
|
} else {
|
|
if len(v.List) == 1 {
|
|
if val := v.List[0]; !val.IsList {
|
|
lv = &val.Value
|
|
}
|
|
}
|
|
return res, fmt.Errorf("value expected, got list")
|
|
}
|
|
|
|
// Assign the value to the field and add it to the record
|
|
fb.Add(record.Field{
|
|
Name: fieldName,
|
|
Value: value.Value{
|
|
Type: lv.Type,
|
|
Data: lv.Data,
|
|
},
|
|
})
|
|
}
|
|
|
|
res.lastInsertKey, err = t.Insert(&fb)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
res.rowsAffected++
|
|
}
|
|
|
|
return res, nil
|
|
}
|