mirror of
https://github.com/chaisql/chai.git
synced 2025-10-07 08:30:52 +08:00
101 lines
2.2 KiB
Go
101 lines
2.2 KiB
Go
package statement
|
|
|
|
import (
|
|
"github.com/genjidb/genji/internal/database"
|
|
"github.com/genjidb/genji/internal/errors"
|
|
"github.com/genjidb/genji/internal/expr"
|
|
"github.com/genjidb/genji/internal/stream"
|
|
)
|
|
|
|
// InsertStmt holds INSERT configuration.
|
|
type InsertStmt struct {
|
|
basePreparedStatement
|
|
|
|
TableName string
|
|
Values []expr.Expr
|
|
Fields []string
|
|
SelectStmt Preparer
|
|
Returning []expr.Expr
|
|
OnConflict database.OnConflictAction
|
|
}
|
|
|
|
func NewInsertStatement() *InsertStmt {
|
|
var p InsertStmt
|
|
|
|
p.basePreparedStatement = basePreparedStatement{
|
|
Preparer: &p,
|
|
ReadOnly: false,
|
|
}
|
|
|
|
return &p
|
|
}
|
|
|
|
func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) {
|
|
var s *stream.Stream
|
|
|
|
if stmt.Values != nil {
|
|
s = stream.New(stream.Expressions(stmt.Values...))
|
|
} else {
|
|
selectStream, err := stmt.SelectStmt.Prepare(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s = selectStream.(*PreparedStreamStmt).Stream
|
|
|
|
// ensure we are not reading and writing to the same table.
|
|
// TODO(asdine): if same table, write content to a temp table.
|
|
if tableScan, ok := s.First().(*stream.TableScanOperator); ok && tableScan.TableName == stmt.TableName {
|
|
return nil, errors.New("cannot read and write to the same table")
|
|
}
|
|
|
|
if len(stmt.Fields) > 0 {
|
|
s = s.Pipe(stream.IterRename(stmt.Fields...))
|
|
}
|
|
}
|
|
|
|
// validate document
|
|
s = s.Pipe(stream.TableValidate(stmt.TableName))
|
|
|
|
if stmt.OnConflict != 0 {
|
|
switch stmt.OnConflict {
|
|
case database.OnConflictDoNothing:
|
|
s = s.Pipe(stream.HandleConflict(stream.New(stream.NoOp())))
|
|
case database.OnConflictDoReplace:
|
|
s = s.Pipe(stream.HandleConflict(stream.New(stream.TableReplace(stmt.TableName))))
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
}
|
|
|
|
// check unique constraints
|
|
indexNames := c.Catalog.ListIndexes(stmt.TableName)
|
|
for _, indexName := range indexNames {
|
|
info, err := c.Catalog.GetIndexInfo(indexName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if info.Unique {
|
|
s = s.Pipe(stream.IndexValidate(indexName))
|
|
}
|
|
}
|
|
|
|
s = s.Pipe(stream.TableInsert(stmt.TableName))
|
|
|
|
for _, indexName := range indexNames {
|
|
s = s.Pipe(stream.IndexInsert(indexName))
|
|
}
|
|
|
|
if len(stmt.Returning) > 0 {
|
|
s = s.Pipe(stream.Project(stmt.Returning...))
|
|
}
|
|
|
|
st := StreamStmt{
|
|
Stream: s,
|
|
ReadOnly: false,
|
|
}
|
|
|
|
return st.Prepare(c)
|
|
}
|