Files
chaisql/internal/sql/parser/create.go
2024-02-17 14:27:02 +04:00

666 lines
15 KiB
Go

package parser
import (
"fmt"
"math"
"github.com/chaisql/chai/internal/database"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/query/statement"
"github.com/chaisql/chai/internal/sql/scanner"
"github.com/chaisql/chai/internal/tree"
)
// parseCreateStatement parses a CREATE statement and returns the matching
// Statement AST node. After consuming the CREATE keyword it dispatches on the
// next token:
//
//   - TABLE          -> parseCreateTableStatement
//   - UNIQUE INDEX   -> parseCreateIndexStatement(true)
//   - INDEX          -> parseCreateIndexStatement(false)
//   - SEQUENCE       -> parseCreateSequenceStatement
//
// Any other token yields a parse error. On error the returned Statement is
// always a literal nil, never an interface wrapping a nil pointer.
func (p *Parser) parseCreateStatement() (statement.Statement, error) {
	// Parse "CREATE".
	if err := p.ParseTokens(scanner.CREATE); err != nil {
		return nil, err
	}

	var (
		stmt statement.Statement
		err  error
	)

	tok, pos, lit := p.ScanIgnoreWhitespace()
	switch tok {
	case scanner.TABLE:
		stmt, err = p.parseCreateTableStatement()
	case scanner.UNIQUE:
		// UNIQUE must be immediately followed by INDEX.
		if tok, pos, lit := p.ScanIgnoreWhitespace(); tok != scanner.INDEX {
			return nil, newParseError(scanner.Tokstr(tok, lit), []string{"INDEX"}, pos)
		}
		stmt, err = p.parseCreateIndexStatement(true)
	case scanner.INDEX:
		stmt, err = p.parseCreateIndexStatement(false)
	case scanner.SEQUENCE:
		stmt, err = p.parseCreateSequenceStatement()
	default:
		return nil, newParseError(scanner.Tokstr(tok, lit), []string{"TABLE", "INDEX", "SEQUENCE"}, pos)
	}

	// The sub-parsers return typed pointers; returning them directly on
	// failure would produce a non-nil interface holding a nil pointer.
	if err != nil {
		return nil, err
	}

	return stmt, nil
}
// parseCreateTableStatement builds a CreateTableStmt from the tokens that
// follow CREATE TABLE: an optional IF NOT EXISTS clause, the table name and
// the parenthesized list of column definitions and table constraints.
// The CREATE TABLE tokens themselves must already have been consumed.
func (p *Parser) parseCreateTableStatement() (*statement.CreateTableStmt, error) {
	stmt := new(statement.CreateTableStmt)

	// Optional IF NOT EXISTS clause.
	ifNotExists, err := p.parseOptional(scanner.IF, scanner.NOT, scanner.EXISTS)
	if err != nil {
		return nil, err
	}
	stmt.IfNotExists = ifNotExists

	// Table name.
	tableName, err := p.parseIdent()
	if err != nil {
		return nil, err
	}
	stmt.Info.TableName = tableName

	// Column definitions and table constraints.
	if err := p.parseConstraints(stmt); err != nil {
		return nil, err
	}

	return stmt, nil
}
// parseConstraints parses the parenthesized body of a CREATE TABLE statement
// and stores the result in stmt.Info.
//
// The body is a comma separated list of column definitions, optionally
// followed by table constraints. Once the first table constraint has been
// seen, only table constraints are accepted. Table constraints (both the
// standalone ones and those produced by column definitions) are collected and
// added to stmt.Info only after the closing parenthesis has been parsed.
func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error {
	// Parse ( token.
	tok, pos, lit := p.ScanIgnoreWhitespace()
	if tok != scanner.LPAREN {
		return newParseError(scanner.Tokstr(tok, lit), []string{"("}, pos)
	}

	// Start from a fresh set of column constraints; surface any failure
	// instead of silently discarding it.
	columnConstraints, err := database.NewColumnConstraints()
	if err != nil {
		return err
	}
	stmt.Info.ColumnConstraints = columnConstraints

	// if set to true, the parser must no longer
	// expect column definitions, but only table constraints.
	var parsingTableConstraints bool

	var allTableConstraints []*database.TableConstraint

	// Parse constraints.
	for {
		// check if it is a table constraint,
		// as it's easier to determine
		tc, err := p.parseTableConstraint(stmt)
		if err != nil {
			return err
		}

		// no table constraint found while one was required
		if tc == nil && parsingTableConstraints {
			tok, pos, lit := p.ScanIgnoreWhitespace()
			return newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos)
		}

		if tc != nil {
			parsingTableConstraints = true
			allTableConstraints = append(allTableConstraints, tc)
		}

		// if set to false, we are still parsing column definitions
		if !parsingTableConstraints {
			cc, tcs, err := p.parseColumnDefinition()
			if err != nil {
				return err
			}

			if err := stmt.Info.AddColumnConstraint(cc); err != nil {
				return err
			}

			allTableConstraints = append(allTableConstraints, tcs...)
		}

		// A comma introduces another item; anything else ends the list.
		if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
			p.Unscan()
			break
		}
	}

	// Parse required ) token.
	if err := p.ParseTokens(scanner.RPAREN); err != nil {
		return err
	}

	// add all table constraints to the table info
	for _, tc := range allTableConstraints {
		if err := stmt.Info.AddTableConstraint(tc); err != nil {
			return err
		}
	}

	return nil
}
// parseColumnDefinition parses one column definition: a column name, its type
// and any number of inline constraints (PRIMARY KEY [ASC|DESC], NOT NULL,
// DEFAULT, UNIQUE, CHECK).
//
// It returns the column constraint describing the column itself, plus the
// table constraints generated by inline PRIMARY KEY, UNIQUE and CHECK clauses.
// Parsing stops at the first token that does not start an inline constraint;
// that token is left unconsumed.
func (p *Parser) parseColumnDefinition() (*database.ColumnConstraint, []*database.TableConstraint, error) {
	var err error
	var cc database.ColumnConstraint

	// Column name.
	cc.Column, err = p.parseIdent()
	if err != nil {
		return nil, nil, err
	}

	// Column type.
	cc.Type, err = p.parseType()
	if err != nil {
		return nil, nil, err
	}

	var tcs []*database.TableConstraint

LOOP:
	for {
		tok, pos, lit := p.ScanIgnoreWhitespace()
		switch tok {
		case scanner.PRIMARY:
			// Parse "KEY"
			if err := p.ParseTokens(scanner.KEY); err != nil {
				return nil, nil, err
			}

			tc := database.TableConstraint{
				PrimaryKey: true,
				Columns:    []string{cc.Column},
			}

			// if ASC is set, we ignore it, otherwise we check for DESC
			ok, err := p.parseOptional(scanner.ASC)
			if err != nil {
				return nil, nil, err
			}
			if !ok {
				ok, err = p.parseOptional(scanner.DESC)
				if err != nil {
					return nil, nil, err
				}
				if ok {
					tc.SortOrder = tree.SortOrder(0).SetDesc(0)
				}
			}

			tcs = append(tcs, &tc)
		case scanner.NOT:
			// Parse "NULL"
			if err := p.ParseTokens(scanner.NULL); err != nil {
				return nil, nil, err
			}

			// if it's already not null we return an error
			if cc.IsNotNull {
				return nil, nil, newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos)
			}
			cc.IsNotNull = true
		case scanner.DEFAULT:
			// if it has already a default value we return an error
			if cc.DefaultValue != nil {
				return nil, nil, newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos)
			}

			withParentheses, err := p.parseOptional(scanner.LPAREN)
			if err != nil {
				return nil, nil, err
			}

			// Parse default value expression.
			// Only a few tokens are allowed.
			e, err := p.parseExprWithMinPrecedence(scanner.EQ.Precedence(),
				scanner.EQ,
				scanner.NEQ,
				scanner.BITWISEOR,
				scanner.BITWISEXOR,
				scanner.BITWISEAND,
				scanner.LT,
				scanner.LTE,
				scanner.GT,
				scanner.GTE,
				scanner.ADD,
				scanner.SUB,
				scanner.MUL,
				scanner.DIV,
				scanner.MOD,
				scanner.CONCAT,
				scanner.INTEGER,
				scanner.NUMBER,
				scanner.STRING,
				scanner.TRUE,
				scanner.FALSE,
				scanner.NULL,
				scanner.LPAREN,   // only opening parenthesis are necessary
				scanner.LBRACKET, // only opening brackets are necessary
				scanner.NEXT,
			)
			if err != nil {
				return nil, nil, err
			}

			cc.DefaultValue = expr.Constraint(e)

			// The opening parenthesis was consumed above, so the matching
			// closing one is mandatory: an unbalanced DEFAULT expression
			// must be reported rather than silently accepted.
			if withParentheses {
				if err := p.ParseTokens(scanner.RPAREN); err != nil {
					return nil, nil, err
				}
			}
		case scanner.UNIQUE:
			tcs = append(tcs, &database.TableConstraint{
				Unique:  true,
				Columns: []string{cc.Column},
			})
		case scanner.CHECK:
			e, cols, err := p.parseCheckConstraint()
			if err != nil {
				return nil, nil, err
			}

			tcs = append(tcs, &database.TableConstraint{
				Check:   expr.Constraint(e),
				Columns: cols,
			})
		default:
			// Not an inline constraint: hand the token back and stop.
			p.Unscan()
			break LOOP
		}
	}

	return &cc, tcs, nil
}
// parseTableConstraint tries to parse a table-level constraint:
//
//	[CONSTRAINT name] PRIMARY KEY (columns)
//	[CONSTRAINT name] UNIQUE (columns)
//	[CONSTRAINT name] CHECK (expr)
//
// When the upcoming tokens do not start a table constraint and no CONSTRAINT
// prefix was given, the scanned token is put back and (nil, nil) is returned.
// When a CONSTRAINT prefix was given, a constraint body is mandatory.
// The stmt argument is not used by this function.
func (p *Parser) parseTableConstraint(stmt *statement.CreateTableStmt) (*database.TableConstraint, error) {
	constraint := new(database.TableConstraint)

	// Optional "CONSTRAINT <name>" prefix.
	named, _ := p.parseOptional(scanner.CONSTRAINT)
	if named {
		nameTok, namePos, nameLit := p.ScanIgnoreWhitespace()
		if nameTok != scanner.IDENT && nameTok != scanner.STRING {
			return nil, newParseError(scanner.Tokstr(nameTok, nameLit), []string{"IDENT", "STRING"}, namePos)
		}
		constraint.Name = nameLit
	}

	tok, pos, lit := p.ScanIgnoreWhitespace()
	switch tok {
	case scanner.PRIMARY, scanner.UNIQUE:
		if tok == scanner.PRIMARY {
			// Parse "KEY"
			if err := p.ParseTokens(scanner.KEY); err != nil {
				return nil, err
			}
			constraint.PrimaryKey = true
		} else {
			constraint.Unique = true
		}

		// Both forms are followed by a mandatory, non-empty column list.
		columns, order, err := p.parseColumnList()
		if err != nil {
			return nil, err
		}
		if len(columns) == 0 {
			badTok, badPos, badLit := p.ScanIgnoreWhitespace()
			return nil, newParseError(scanner.Tokstr(badTok, badLit), []string{"PATHS"}, badPos)
		}

		constraint.Columns = columns
		constraint.SortOrder = order
	case scanner.CHECK:
		check, columns, err := p.parseCheckConstraint()
		if err != nil {
			return nil, err
		}

		constraint.Check = expr.Constraint(check)
		constraint.Columns = columns
	default:
		// A named constraint must be followed by an actual constraint.
		if named {
			return nil, newParseError(scanner.Tokstr(tok, lit), []string{"PRIMARY", "UNIQUE", "CHECK"}, pos)
		}

		// Not a table constraint: give the token back to the caller.
		p.Unscan()
		return nil, nil
	}

	return constraint, nil
}
// parseCreateIndexStatement builds a CreateIndexStmt from the tokens that
// follow CREATE INDEX or CREATE UNIQUE INDEX (already consumed by the caller):
//
//	[IF NOT EXISTS] [index_name] ON table_name (columns)
//
// The unique argument is recorded in stmt.Info.Unique. The index name may be
// omitted, except when IF NOT EXISTS is present.
func (p *Parser) parseCreateIndexStatement(unique bool) (*statement.CreateIndexStmt, error) {
	stmt := &statement.CreateIndexStmt{}
	stmt.Info.Unique = unique

	// Optional IF NOT EXISTS clause.
	ifNotExists, err := p.parseOptional(scanner.IF, scanner.NOT, scanner.EXISTS)
	stmt.IfNotExists = ifNotExists
	if err != nil {
		return nil, err
	}

	// Optional index name.
	indexName, err := p.parseIdent()
	stmt.Info.IndexName = indexName
	if err != nil {
		// with IF NOT EXISTS, the index name is mandatory
		if ifNotExists {
			return nil, err
		}
		// otherwise put the token back and carry on without a name
		p.Unscan()
	}

	// Mandatory "ON".
	if err := p.ParseTokens(scanner.ON); err != nil {
		return nil, err
	}

	// Table the index belongs to.
	tableName, err := p.parseIdent()
	if err != nil {
		return nil, err
	}
	stmt.Info.Owner.TableName = tableName

	// Indexed columns; at least one is required.
	cols, sortOrder, err := p.parseColumnList()
	if err != nil {
		return nil, err
	}
	if len(cols) == 0 {
		tok, pos, lit := p.ScanIgnoreWhitespace()
		return nil, newParseError(scanner.Tokstr(tok, lit), []string{"("}, pos)
	}

	stmt.Info.Columns = cols
	stmt.Info.KeySortOrder = sortOrder

	return stmt, nil
}
// parseCreateSequenceStatement parses a create sequence string and returns a
// Statement AST row.
// This function assumes the CREATE SEQUENCE tokens have already been consumed.
//
// Accepted options, in any order and each at most once:
//
//	AS <integer type>
//	INCREMENT [BY] n      (n must not be zero; defaults to 1)
//	MINVALUE n | NO MINVALUE
//	MAXVALUE n | NO MAXVALUE
//	START [WITH] n
//	CACHE n               (n must not be negative; defaults to 1)
//	CYCLE | NO CYCLE
//
// An option that repeats or contradicts a previous one (for instance
// MINVALUE together with NO MINVALUE, in either order) is rejected with a
// "conflicting or redundant options" error.
func (p *Parser) parseCreateSequenceStatement() (*statement.CreateSequenceStmt, error) {
	var stmt statement.CreateSequenceStmt
	var err error

	// Parse IF NOT EXISTS
	stmt.IfNotExists, err = p.parseOptional(scanner.IF, scanner.NOT, scanner.EXISTS)
	if err != nil {
		return nil, err
	}

	// Parse sequence name
	stmt.Info.Name, err = p.parseIdent()
	if err != nil {
		return nil, err
	}

	// Flags and pointers record which options were explicitly provided,
	// so duplicates and contradictions can be detected.
	var hasAsInt, hasNoMin, hasNoMax, hasNoCycle bool
	var minValue, maxValue, incrementBy, start, cache *int64

	for {
		// Parse AS [any int type]
		// Only integers are supported
		if ok, _ := p.parseOptional(scanner.AS); ok {
			tok, pos, lit := p.ScanIgnoreWhitespace()
			switch tok {
			case scanner.TYPEINTEGER, scanner.TYPEINT, scanner.TYPEINT2, scanner.TYPEINT8, scanner.TYPETINYINT,
				scanner.TYPEBIGINT, scanner.TYPEMEDIUMINT, scanner.TYPESMALLINT:
			default:
				return nil, newParseError(scanner.Tokstr(tok, lit), []string{"INT"}, pos)
			}

			if hasAsInt {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			hasAsInt = true
			continue
		}

		// Parse INCREMENT [BY] integer
		if ok, _ := p.parseOptional(scanner.INCREMENT); ok {
			// parse optional BY token
			_, _ = p.parseOptional(scanner.BY)

			if incrementBy != nil {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			i, err := p.parseInteger()
			if err != nil {
				return nil, err
			}
			if i == 0 {
				return nil, &ParseError{Message: "INCREMENT must not be zero"}
			}

			incrementBy = &i
			continue
		}

		// Parse NO [MINVALUE | MAXVALUE | CYCLE]
		if ok, _ := p.parseOptional(scanner.NO); ok {
			tok, pos, lit := p.ScanIgnoreWhitespace()
			switch tok {
			case scanner.MINVALUE:
				// also reject NO MINVALUE after an explicit MINVALUE
				if hasNoMin || minValue != nil {
					return nil, &ParseError{Message: "conflicting or redundant options"}
				}
				hasNoMin = true
			case scanner.MAXVALUE:
				// also reject NO MAXVALUE after an explicit MAXVALUE
				if hasNoMax || maxValue != nil {
					return nil, &ParseError{Message: "conflicting or redundant options"}
				}
				hasNoMax = true
			case scanner.CYCLE:
				// also reject NO CYCLE after an explicit CYCLE
				if hasNoCycle || stmt.Info.Cycle {
					return nil, &ParseError{Message: "conflicting or redundant options"}
				}
				hasNoCycle = true
			default:
				return nil, newParseError(scanner.Tokstr(tok, lit), []string{"MINVALUE", "MAXVALUE", "CYCLE"}, pos)
			}

			continue
		}

		// Parse MINVALUE integer
		if ok, _ := p.parseOptional(scanner.MINVALUE); ok {
			if hasNoMin || minValue != nil {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			i, err := p.parseInteger()
			if err != nil {
				return nil, err
			}

			minValue = &i
			continue
		}

		// Parse MAXVALUE integer
		if ok, _ := p.parseOptional(scanner.MAXVALUE); ok {
			if hasNoMax || maxValue != nil {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			i, err := p.parseInteger()
			if err != nil {
				return nil, err
			}

			maxValue = &i
			continue
		}

		// Parse START [WITH] integer
		if ok, _ := p.parseOptional(scanner.START); ok {
			// parse optional WITH token
			_, _ = p.parseOptional(scanner.WITH)

			if start != nil {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			i, err := p.parseInteger()
			if err != nil {
				return nil, err
			}

			start = &i
			continue
		}

		// Parse CACHE integer
		if ok, _ := p.parseOptional(scanner.CACHE); ok {
			if cache != nil {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			v, err := p.parseInteger()
			if err != nil {
				return nil, err
			}
			if v < 0 {
				return nil, &ParseError{Message: "cache value must be positive"}
			}

			cache = &v
			continue
		}

		// Parse CYCLE
		if ok, _ := p.parseOptional(scanner.CYCLE); ok {
			if hasNoCycle || stmt.Info.Cycle {
				return nil, &ParseError{Message: "conflicting or redundant options"}
			}

			stmt.Info.Cycle = true
			continue
		}

		// no more known options
		break
	}

	// default value for increment is 1
	if incrementBy != nil {
		stmt.Info.IncrementBy = *incrementBy
	} else {
		stmt.Info.IncrementBy = 1
	}

	// determine if the sequence is ascending or descending
	asc := stmt.Info.IncrementBy > 0

	// default value for min is 1 if ascending
	// or the minimum value of ints if descending
	if minValue != nil {
		stmt.Info.Min = *minValue
	} else if asc {
		stmt.Info.Min = 1
	} else {
		stmt.Info.Min = math.MinInt64
	}

	// default value for max is the maximum value of ints if ascending
	// or -1 if descending
	if maxValue != nil {
		stmt.Info.Max = *maxValue
	} else if asc {
		stmt.Info.Max = math.MaxInt64
	} else {
		stmt.Info.Max = -1
	}

	// check if min > max
	if stmt.Info.Min > stmt.Info.Max {
		return nil, &ParseError{Message: fmt.Sprintf("MINVALUE (%d) must be less than MAXVALUE (%d)", stmt.Info.Min, stmt.Info.Max)}
	}

	// default value for start is min if ascending
	// or max if descending
	if start != nil {
		stmt.Info.Start = *start
	} else if asc {
		stmt.Info.Start = stmt.Info.Min
	} else {
		stmt.Info.Start = stmt.Info.Max
	}

	// check that min <= start <= max
	if stmt.Info.Start < stmt.Info.Min {
		return nil, &ParseError{Message: fmt.Sprintf("START value (%d) cannot be less than MINVALUE (%d)", stmt.Info.Start, stmt.Info.Min)}
	}
	if stmt.Info.Start > stmt.Info.Max {
		return nil, &ParseError{Message: fmt.Sprintf("START value (%d) cannot be greater than MAXVALUE (%d)", stmt.Info.Start, stmt.Info.Max)}
	}

	// default for cache is 1
	if cache != nil {
		stmt.Info.Cache = uint64(*cache)
	} else {
		stmt.Info.Cache = 1
	}

	return &stmt, nil
}
// parseCheckConstraint parses the parenthesized expression of a CHECK
// constraint; the CHECK token itself must already have been consumed.
// It returns the parsed expression together with the distinct column names
// referenced by it, in the order the walk first encounters them.
func (p *Parser) parseCheckConstraint() (expr.Expr, []string, error) {
	// Opening parenthesis.
	if err := p.ParseTokens(scanner.LPAREN); err != nil {
		return nil, nil, err
	}

	check, err := p.ParseExpr()
	if err != nil {
		return nil, nil, err
	}

	// Collect every column referenced by the expression, without duplicates.
	var columns []string
	expr.Walk(check, func(node expr.Expr) bool {
		col, isColumn := node.(expr.Column)
		if !isColumn {
			return true
		}

		name := string(col)
		for _, known := range columns {
			if known == name {
				// already recorded, keep walking
				return true
			}
		}

		columns = append(columns, name)
		return true
	})

	// Closing parenthesis.
	if err := p.ParseTokens(scanner.RPAREN); err != nil {
		return nil, nil, err
	}

	return check, columns, nil
}