Files
chaisql/internal/sql/parser/expr.go
2025-09-14 11:55:27 +05:30

552 lines
14 KiB
Go

package parser
import (
"encoding/hex"
"fmt"
"math"
"strconv"
"strings"
"github.com/chaisql/chai/internal/environment"
"github.com/chaisql/chai/internal/expr"
"github.com/chaisql/chai/internal/expr/functions"
"github.com/chaisql/chai/internal/sql/scanner"
"github.com/chaisql/chai/internal/types"
"github.com/cockroachdb/errors"
)
// dummyOperator is a placeholder operator used as the root of the expression
// tree while it is being assembled. It only stores a right-hand expression:
// RightHand and SetRightHandExpr work, every other method panics.
type dummyOperator struct {
	rightHand expr.Expr
}

// Token is not supported by the placeholder and panics.
func (op *dummyOperator) Token() scanner.Token {
	panic("not implemented")
}

// Equal is not supported by the placeholder and panics.
func (op *dummyOperator) Equal(expr.Expr) bool {
	panic("not implemented")
}

// Eval is not supported by the placeholder and panics.
func (op *dummyOperator) Eval(*environment.Environment) (types.Value, error) {
	panic("not implemented")
}

// String is not supported by the placeholder and panics.
func (op *dummyOperator) String() string {
	panic("not implemented")
}

// Precedence is not supported by the placeholder and panics.
func (op *dummyOperator) Precedence() int {
	panic("not implemented")
}

// LeftHand is not supported by the placeholder and panics.
func (op *dummyOperator) LeftHand() expr.Expr {
	panic("not implemented")
}

// RightHand returns the expression currently stored under the placeholder.
func (op *dummyOperator) RightHand() expr.Expr {
	return op.rightHand
}

// SetLeftHandExpr is not supported by the placeholder and panics.
func (op *dummyOperator) SetLeftHandExpr(e expr.Expr) {
	panic("not implemented")
}

// SetRightHandExpr replaces the expression stored under the placeholder.
func (op *dummyOperator) SetRightHandExpr(e expr.Expr) {
	op.rightHand = e
}
// ParseExpr parses an expression, accepting operators of any precedence and
// placing no restriction on which tokens may appear.
func (p *Parser) ParseExpr() (e expr.Expr, err error) {
	const lowestPrecedence = 0
	return p.parseExprWithMinPrecedence(lowestPrecedence)
}
// parseExprWithMinPrecedence parses an expression made of unary expressions
// joined by binary operators whose precedence is at least the given minimum.
// When allowed is non-nil, only the listed tokens are accepted; parsing stops
// (without error) at the first operator that parseOperator declines.
//
// The tree is built under a placeholder root so that the real root can be
// replaced while operators are inserted according to their precedence.
func (p *Parser) parseExprWithMinPrecedence(precedence int, allowed ...scanner.Token) (e expr.Expr, err error) {
	var root expr.Operator = &dummyOperator{}

	// Every expression starts with a non-binary operand.
	if e, err = p.parseUnaryExpr(allowed...); err != nil {
		return nil, err
	}
	root.SetRightHandExpr(e)

	for {
		// No further operator means the expression is complete.
		build, opTok, err := p.parseOperator(precedence, allowed...)
		if err != nil {
			return nil, err
		}
		if opTok == 0 {
			return root.RightHand(), nil
		}

		operand, err := p.parseUnaryExpr(allowed...)
		if err != nil {
			return nil, err
		}

		// Walk down the right spine of the tree while the right child is an
		// operator that binds less tightly than the new one; the new operator
		// then takes that position, with the displaced subtree as its left
		// operand.
		parent := root
		for {
			child, isOperator := parent.RightHand().(expr.Operator)
			if isOperator && child.Precedence() < opTok.Precedence() {
				parent = child
				continue
			}
			parent.SetRightHandExpr(build(parent.RightHand(), operand))
			break
		}
	}
}
// parseOperator scans the next token and, if it is a supported binary
// operator, returns a constructor that builds the operator expression from
// its operands together with the operator's token.
//
// A zero token with a nil error means "no operator here": the scanned token
// has been pushed back because it is not an operator, is not in allowed, is a
// currently unused operator, or has a precedence below minPrecedence.
//
// Multi-token operators are folded into a single token: NOT IN yields NIN,
// NOT LIKE yields NLIKE and IS NOT yields ISN. For BETWEEN, the lower bound
// and the AND keyword are consumed here, leaving the upper bound to be parsed
// by the caller as the right operand.
func (p *Parser) parseOperator(minPrecedence int, allowed ...scanner.Token) (func(lhs, rhs expr.Expr) expr.Expr, scanner.Token, error) {
	tok, _, _ := p.ScanIgnoreWhitespace()

	// Tokens that cannot start an operator, are filtered out by the caller,
	// or belong to the unused regex operators are left for the caller.
	notAnOperator := !tok.IsOperator() && tok != scanner.NOT
	if notAnOperator || !tokenIsAllowed(tok, allowed...) || tok == scanner.EQREGEX || tok == scanner.NEQREGEX {
		p.Unscan()
		return nil, 0, nil
	}

	// In operator position NOT must introduce NOT IN or NOT LIKE.
	if tok == scanner.NOT {
		next, pos, lit := p.ScanIgnoreWhitespace()
		if next.Precedence() >= minPrecedence {
			switch next {
			case scanner.IN:
				return expr.NotIn, scanner.NIN, nil
			case scanner.LIKE:
				return expr.NotLike, scanner.NLIKE, nil
			}
		}
		return nil, 0, newParseError(scanner.Tokstr(next, lit), []string{"IN, LIKE"}, pos)
	}

	if tok.Precedence() < minPrecedence {
		p.Unscan()
		return nil, 0, nil
	}

	var build func(lhs, rhs expr.Expr) expr.Expr
	switch tok {
	case scanner.EQ:
		build = expr.Eq
	case scanner.NEQ:
		build = expr.Neq
	case scanner.GT:
		build = expr.Gt
	case scanner.GTE:
		build = expr.Gte
	case scanner.LT:
		build = expr.Lt
	case scanner.LTE:
		build = expr.Lte
	case scanner.AND:
		build = expr.And
	case scanner.OR:
		build = expr.Or
	case scanner.ADD:
		build = expr.Add
	case scanner.SUB:
		build = expr.Sub
	case scanner.MUL:
		build = expr.Mul
	case scanner.DIV:
		build = expr.Div
	case scanner.MOD:
		build = expr.Mod
	case scanner.BITWISEAND:
		build = expr.BitwiseAnd
	case scanner.BITWISEOR:
		build = expr.BitwiseOr
	case scanner.BITWISEXOR:
		build = expr.BitwiseXor
	case scanner.IN:
		build = expr.In
	case scanner.LIKE:
		build = expr.Like
	case scanner.CONCAT:
		build = expr.Concat
	case scanner.IS:
		// IS may be followed by NOT, which turns it into IS NOT.
		if next, _, _ := p.ScanIgnoreWhitespace(); next == scanner.NOT {
			return expr.IsNot, scanner.ISN, nil
		}
		p.Unscan()
		build = expr.Is
	case scanner.BETWEEN:
		// x BETWEEN a AND b: parse a and the AND keyword now.
		lower, err := p.parseExprWithMinPrecedence(tok.Precedence())
		if err != nil {
			return nil, tok, err
		}
		if err = p.ParseTokens(scanner.AND); err != nil {
			return nil, tok, err
		}
		build = expr.Between(lower)
	default:
		// An operator token this parser has no expression for.
		p.Unscan()
		return nil, 0, nil
	}

	return build, tok, nil
}
// parseUnaryExpr parses a non-binary expression: a CAST, a function call, a
// column, a positional parameter, a literal (string, bytea, number, integer,
// boolean, NULL), a wildcard, a parenthesized expression or expression list,
// or a NOT expression.
//
// If allowed is non-nil and the first token is not part of it, the token is
// pushed back and (nil, nil) is returned.
func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) {
	tok, pos, lit := p.ScanIgnoreWhitespace()
	if !tokenIsAllowed(tok, allowed...) {
		p.Unscan()
		return nil, nil
	}

	switch tok {
	case scanner.CAST:
		p.Unscan()
		return p.parseCastExpression()

	case scanner.IDENT:
		// Look one token ahead to tell a function call from a column, then
		// rewind to just before the identifier (skipping back over any
		// whitespace that separated the two tokens).
		next, _, _ := p.ScanIgnoreWhitespace()
		p.Unscan()
		if cur, _, _ := p.s.Curr(); cur == scanner.WS {
			p.Unscan()
		}
		p.Unscan()
		if next == scanner.LPAREN {
			return p.parseFunction()
		}
		return p.parseColumn()

	case scanner.POSITIONALPARAM:
		// The literal carries a one-character prefix before the index.
		idx, err := strconv.Atoi(lit[1:])
		if err != nil {
			return nil, errors.WithStack(&ParseError{Message: "invalid positional parameter syntax", Pos: pos})
		}
		return expr.PositionalParam(idx), nil

	case scanner.STRING:
		// Strings starting with \x are hex-encoded bytea literals; anything
		// else is plain text.
		if !strings.HasPrefix(lit, `\x`) {
			return expr.LiteralValue{Value: types.NewTextValue(lit)}, nil
		}
		raw, err := hex.DecodeString(lit[2:])
		if err == nil {
			return expr.LiteralValue{Value: types.NewByteaValue(raw)}, nil
		}
		if bad, ok := err.(hex.InvalidByteError); ok {
			return nil, fmt.Errorf("invalid hexadecimal digit: %c", bad)
		}
		return nil, err

	case scanner.NUMBER:
		f, err := strconv.ParseFloat(lit, 64)
		if err != nil {
			return nil, errors.WithStack(&ParseError{Message: "unable to parse number", Pos: pos})
		}
		return expr.LiteralValue{Value: types.NewDoubleValue(f)}, nil

	case scanner.ADD, scanner.SUB:
		// A sign must be immediately followed by a numeric literal (no
		// whitespace is skipped). The signed literal is then handled by the
		// INTEGER case below, which falls back to a double when needed.
		negative := tok == scanner.SUB
		tok, pos, lit = p.Scan()
		if tok != scanner.NUMBER && tok != scanner.INTEGER {
			return nil, errors.WithStack(&ParseError{Message: "syntax error", Pos: pos})
		}
		if negative {
			lit = "-" + lit
		}
		fallthrough

	case scanner.INTEGER:
		n, err := strconv.ParseInt(lit, 10, 64)
		if err != nil {
			// Not representable as an int64 (or not an integer at all when
			// reached from the sign case): try a double instead.
			f, ferr := strconv.ParseFloat(lit, 64)
			if ferr != nil {
				return nil, errors.WithStack(&ParseError{Message: "unable to parse integer", Pos: pos})
			}
			return expr.LiteralValue{Value: types.NewDoubleValue(f)}, nil
		}
		// Values that fit in 32 bits become integers, larger ones bigints.
		if n >= math.MinInt32 && n <= math.MaxInt32 {
			return expr.LiteralValue{Value: types.NewIntegerValue(int32(n))}, nil
		}
		return expr.LiteralValue{Value: types.NewBigintValue(n)}, nil

	case scanner.TRUE, scanner.FALSE:
		return expr.LiteralValue{Value: types.NewBooleanValue(tok == scanner.TRUE)}, nil

	case scanner.NULL:
		return expr.LiteralValue{Value: types.NewNullValue()}, nil

	case scanner.MUL:
		return expr.Wildcard{}, nil

	case scanner.LPAREN:
		inner, err := p.ParseExpr()
		if err != nil {
			return nil, err
		}
		closing, closingPos, closingLit := p.ScanIgnoreWhitespace()
		switch closing {
		case scanner.RPAREN:
			// (expr)
			return expr.Parentheses{E: inner}, nil
		case scanner.COMMA:
			// (expr, expr, ...): the first element is already parsed, the
			// remaining ones are read up to the closing parenthesis.
			rest, err := p.parseExprListUntil(scanner.RPAREN)
			if err != nil {
				return nil, err
			}
			return append(expr.LiteralExprList{inner}, rest...), nil
		}
		return nil, newParseError(scanner.Tokstr(closing, closingLit), []string{")", ","}, closingPos)

	case scanner.NOT:
		// The operand is parsed as a complete expression.
		operand, err := p.ParseExpr()
		if err != nil {
			return nil, err
		}
		return expr.Not(operand), nil

	default:
		return nil, newParseError(scanner.Tokstr(tok, lit), nil, pos)
	}
}
// parseInteger parses an optionally signed integer literal and returns it as
// an int64. A sign, when present, must be immediately followed by the digits:
// whitespace is only skipped before the sign or the unsigned literal.
func (p *Parser) parseInteger() (int64, error) {
	tok, pos, lit := p.ScanIgnoreWhitespace()

	switch tok {
	case scanner.SUB:
		tok, pos, lit = p.Scan()
		lit = "-" + lit
	case scanner.ADD:
		tok, pos, lit = p.Scan()
	}

	if tok != scanner.INTEGER {
		return 0, newParseError(scanner.Tokstr(tok, lit), []string{"integer"}, pos)
	}

	n, err := strconv.ParseInt(lit, 10, 64)
	if err != nil {
		return 0, newParseError(scanner.Tokstr(tok, lit), []string{"INT"}, pos)
	}
	return n, nil
}
// parseIdent scans the next non-whitespace token and returns its literal if
// it is an identifier; any other token yields a parse error.
func (p *Parser) parseIdent() (string, error) {
	tok, pos, lit := p.ScanIgnoreWhitespace()
	if tok == scanner.IDENT {
		return lit, nil
	}
	return "", newParseError(scanner.Tokstr(tok, lit), []string{"identifier"}, pos)
}
// parseIdentList parses a comma-delimited list of identifiers. At least one
// identifier is required, and every comma must be followed by another one.
// The first token that is not a comma ends the list and is pushed back.
func (p *Parser) parseIdentList() ([]string, error) {
	var idents []string
	for {
		name, err := p.parseIdent()
		if err != nil {
			return nil, err
		}
		idents = append(idents, name)

		// Keep going only while identifiers are separated by commas.
		if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
			p.Unscan()
			return idents, nil
		}
	}
}
// parseType parses a type name and returns the corresponding types.Type.
// Several SQL spellings map to the same type. DOUBLE may optionally be
// followed by PRECISION. VARCHAR and CHARACTER require a parenthesized
// integer length, which is validated syntactically but otherwise ignored.
func (p *Parser) parseType() (types.Type, error) {
	tok, pos, lit := p.ScanIgnoreWhitespace()

	switch tok {
	case scanner.TYPEBOOL, scanner.TYPEBOOLEAN:
		return types.TypeBoolean, nil

	case scanner.TYPEINTEGER, scanner.TYPEINT, scanner.TYPEINT2, scanner.TYPETINYINT,
		scanner.TYPEMEDIUMINT, scanner.TYPESMALLINT:
		return types.TypeInteger, nil

	case scanner.TYPEINT8, scanner.TYPEBIGINT:
		return types.TypeBigint, nil

	case scanner.TYPEREAL:
		return types.TypeDouble, nil

	case scanner.TYPEDOUBLE:
		// Consume an optional PRECISION keyword; anything else is put back.
		if next, _, _ := p.ScanIgnoreWhitespace(); next != scanner.PRECISION {
			p.Unscan()
		}
		return types.TypeDouble, nil

	case scanner.TYPETEXT:
		return types.TypeText, nil

	case scanner.TYPEBYTEA, scanner.TYPEBYTES:
		return types.TypeBytea, nil

	case scanner.TYPETIMESTAMP:
		return types.TypeTimestamp, nil

	case scanner.TYPEVARCHAR, scanner.TYPECHARACTER:
		// Expect "(", an integer and ")" in that order. The integer value
		// itself is not used.
		suffix := []struct {
			tok   scanner.Token
			label string
		}{
			{scanner.LPAREN, "("},
			{scanner.INTEGER, "integer"},
			{scanner.RPAREN, ")"},
		}
		for _, want := range suffix {
			if got, gotPos, gotLit := p.ScanIgnoreWhitespace(); got != want.tok {
				return 0, newParseError(scanner.Tokstr(got, gotLit), []string{want.label}, gotPos)
			}
		}
		return types.TypeText, nil
	}

	return 0, newParseError(scanner.Tokstr(tok, lit), []string{"type"}, pos)
}
// parseColumn parses a column reference, which consists of a single
// mandatory identifier, and returns it as an expr.Column.
func (p *Parser) parseColumn() (*expr.Column, error) {
	name, err := p.parseIdent()
	if err != nil {
		return nil, err
	}

	column := expr.Column{Name: name}
	return &column, nil
}
// parseExprListUntil parses a comma-separated list of expressions followed by
// rightToken (typically ")" or "]"). The opening token must already have been
// consumed by the caller.
//
// The list may be empty, and a trailing comma before rightToken is accepted.
// Any error raised while parsing an element is returned to the caller rather
// than being discarded, so malformed elements are never silently dropped.
func (p *Parser) parseExprListUntil(rightToken scanner.Token) (expr.LiteralExprList, error) {
	var exprList expr.LiteralExprList

	for {
		// The closing token where an element would start means the list is
		// empty or ends with a trailing comma. Put it back either way: it is
		// consumed below, or re-read as the start of the next element.
		tok, _, _ := p.ScanIgnoreWhitespace()
		p.Unscan()
		if tok == rightToken {
			break
		}

		e, err := p.ParseExpr()
		if err != nil {
			return nil, err
		}
		exprList = append(exprList, e)

		// Elements are separated by commas; anything else ends the list.
		if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
			p.Unscan()
			break
		}
	}

	// Parse the required closing token.
	if err := p.ParseTokens(rightToken); err != nil {
		return nil, err
	}

	return exprList, nil
}
// parseExprList parses an expression list delimited by leftToken and
// rightToken (for example "(" and ")", or "[" and "]"): it consumes the
// opening token, then delegates the elements and the closing token to
// parseExprListUntil.
func (p *Parser) parseExprList(leftToken, rightToken scanner.Token) (expr.LiteralExprList, error) {
	err := p.ParseTokens(leftToken)
	if err != nil {
		return nil, err
	}
	return p.parseExprListUntil(rightToken)
}
// parseFunction parses a function call: an identifier followed by an opening
// parenthesis, an optional comma-separated list of expressions and a closing
// parenthesis. The function definition is looked up by name only after the
// whole call has been parsed, and is then instantiated with the arguments.
func (p *Parser) parseFunction() (expr.Expr, error) {
	funcName, err := p.parseIdent()
	if err != nil {
		return nil, err
	}

	if err := p.ParseTokens(scanner.LPAREN); err != nil {
		return nil, err
	}

	// args stays nil when the call has no arguments.
	var args []expr.Expr
	if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.RPAREN {
		// Not an empty call: put the token back and read the arguments.
		p.Unscan()
		for {
			arg, err := p.ParseExpr()
			if err != nil {
				return nil, err
			}
			args = append(args, arg)

			if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
				p.Unscan()
				break
			}
		}

		// The argument list must be closed by a parenthesis.
		if err := p.ParseTokens(scanner.RPAREN); err != nil {
			return nil, err
		}
	}

	def, err := functions.GetFunc(funcName)
	if err != nil {
		return nil, err
	}
	return def.Function(args...)
}
// parseCastExpression parses an expression of the form CAST(expr AS type)
// and returns the corresponding expr.Cast.
func (p *Parser) parseCastExpression() (expr.Expr, error) {
	// CAST (
	if err := p.ParseTokens(scanner.CAST, scanner.LPAREN); err != nil {
		return nil, err
	}

	// The expression being converted.
	operand, err := p.ParseExpr()
	if err != nil {
		return nil, err
	}

	// AS
	if err = p.ParseTokens(scanner.AS); err != nil {
		return nil, err
	}

	// The target type.
	target, err := p.parseType()
	if err != nil {
		return nil, err
	}

	// )
	if err = p.ParseTokens(scanner.RPAREN); err != nil {
		return nil, err
	}

	return &expr.Cast{Expr: operand, CastAs: target}, nil
}
// tokenIsAllowed reports whether tok may be used given the allowed list.
// A nil list places no restriction and allows every token; otherwise tok
// must be equal to one of the listed tokens.
func tokenIsAllowed(tok scanner.Token, allowed ...scanner.Token) bool {
	if allowed != nil {
		for i := range allowed {
			if allowed[i] == tok {
				return true
			}
		}
		return false
	}
	return true
}