mirror of
https://github.com/chaisql/chai.git
synced 2025-09-27 03:55:59 +08:00
304 lines
7.1 KiB
Go
304 lines
7.1 KiB
Go
package parser
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/chaisql/chai/internal/expr"
|
|
"github.com/chaisql/chai/internal/query/statement"
|
|
"github.com/chaisql/chai/internal/sql/scanner"
|
|
"github.com/chaisql/chai/internal/tree"
|
|
"github.com/cockroachdb/errors"
|
|
)
|
|
|
|
// Parser represents an Chai SQL Parser.
|
|
type Parser struct {
|
|
s *scanner.Scanner
|
|
}
|
|
|
|
// NewParser returns a new instance of Parser.
|
|
func NewParser(r io.Reader) *Parser {
|
|
return &Parser{s: scanner.NewScanner(r)}
|
|
}
|
|
|
|
// ParseQuery parses a query string and returns its AST representation.
|
|
func ParseQuery(s string) ([]statement.Statement, error) {
|
|
return NewParser(strings.NewReader(s)).ParseQuery()
|
|
}
|
|
|
|
// ParseExpr parses an expression.
|
|
func ParseExpr(s string) (expr.Expr, error) {
|
|
e, err := NewParser(strings.NewReader(s)).ParseExpr()
|
|
return e, err
|
|
}
|
|
|
|
// MustParseExpr calls ParseExpr and panics if it returns an error.
|
|
func MustParseExpr(s string) expr.Expr {
|
|
e, err := ParseExpr(s)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("%+v", err))
|
|
}
|
|
|
|
return e
|
|
}
|
|
|
|
// ParseQuery parses a Chai SQL string and returns a Query.
|
|
func (p *Parser) ParseQuery() ([]statement.Statement, error) {
|
|
var statements []statement.Statement
|
|
|
|
err := p.Parse(func(s statement.Statement) error {
|
|
statements = append(statements, s)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// ParseQuery parses a Chai SQL string and returns a Query.
|
|
func (p *Parser) Parse(fn func(statement.Statement) error) error {
|
|
for {
|
|
err := p.skipMany(scanner.SEMICOLON)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
end, err := p.parseOptional(scanner.EOF)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if end {
|
|
return nil
|
|
}
|
|
|
|
s, err := p.ParseStatement()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tok, pos, lit := p.ScanIgnoreWhitespace()
|
|
switch tok {
|
|
case scanner.EOF:
|
|
return fn(s)
|
|
case scanner.SEMICOLON:
|
|
err = fn(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
p.Unscan()
|
|
return newParseError(scanner.Tokstr(tok, lit), []string{";"}, pos)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ParseStatement parses a Chai SQL string and returns a statement.
|
|
func (p *Parser) ParseStatement() (statement.Statement, error) {
|
|
tok, pos, lit := p.ScanIgnoreWhitespace()
|
|
p.Unscan()
|
|
switch tok {
|
|
case scanner.ALTER:
|
|
return p.parseAlterStatement()
|
|
case scanner.BEGIN:
|
|
return p.parseBeginStatement()
|
|
case scanner.COMMIT:
|
|
return p.parseCommitStatement()
|
|
case scanner.SELECT:
|
|
return p.parseSelectStatement()
|
|
case scanner.DELETE:
|
|
return p.parseDeleteStatement()
|
|
case scanner.UPDATE:
|
|
return p.parseUpdateStatement()
|
|
case scanner.INSERT:
|
|
return p.parseInsertStatement()
|
|
case scanner.CREATE:
|
|
return p.parseCreateStatement()
|
|
case scanner.DROP:
|
|
return p.parseDropStatement()
|
|
case scanner.EXPLAIN:
|
|
return p.parseExplainStatement()
|
|
case scanner.REINDEX:
|
|
return p.parseReIndexStatement()
|
|
case scanner.ROLLBACK:
|
|
return p.parseRollbackStatement()
|
|
}
|
|
|
|
return nil, newParseError(scanner.Tokstr(tok, lit), []string{
|
|
"ALTER", "BEGIN", "COMMIT", "SELECT", "DELETE", "UPDATE", "INSERT", "CREATE", "DROP", "EXPLAIN", "REINDEX", "ROLLBACK",
|
|
}, pos)
|
|
}
|
|
|
|
func (p *Parser) skipMany(tok scanner.Token) error {
|
|
for {
|
|
t, _, _ := p.ScanIgnoreWhitespace()
|
|
if t != tok {
|
|
p.Unscan()
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseCondition parses the "WHERE" clause of the query, if it exists.
|
|
func (p *Parser) parseCondition() (expr.Expr, error) {
|
|
// Check if the WHERE token exists.
|
|
if ok, err := p.parseOptional(scanner.WHERE); !ok || err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Scan the identifier for the source.
|
|
expr, err := p.ParseExpr()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return expr, nil
|
|
}
|
|
|
|
// parseColumnList parses a list of columns in the form: (path, path, ...), if exists
|
|
func (p *Parser) parseColumnList() ([]string, tree.SortOrder, error) {
|
|
// Parse ( token.
|
|
if ok, err := p.parseOptional(scanner.LPAREN); !ok || err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
var columns []string
|
|
var err error
|
|
var col string
|
|
var order tree.SortOrder
|
|
|
|
// Parse first (required) column.
|
|
if col, err = p.parseIdent(); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
columns = append(columns, col)
|
|
|
|
// Parse optional ASC/DESC token.
|
|
ok, err := p.parseOptional(scanner.DESC)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if ok {
|
|
order = order.SetDesc(0)
|
|
} else {
|
|
// ignore ASC if set
|
|
_, err := p.parseOptional(scanner.ASC)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
}
|
|
|
|
// Parse remaining (optional) columns.
|
|
i := 0
|
|
for {
|
|
if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA {
|
|
p.Unscan()
|
|
break
|
|
}
|
|
|
|
c, err := p.parseIdent()
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
columns = append(columns, c)
|
|
|
|
i++
|
|
|
|
// Parse optional ASC/DESC token.
|
|
ok, err := p.parseOptional(scanner.DESC)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if ok {
|
|
order = order.SetDesc(i)
|
|
} else {
|
|
// ignore ASC if set
|
|
_, err := p.parseOptional(scanner.ASC)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse required ) token.
|
|
if err := p.ParseTokens(scanner.RPAREN); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return columns, order, nil
|
|
}
|
|
|
|
// Scan returns the next token from the underlying scanner.
|
|
func (p *Parser) Scan() (tok scanner.Token, pos scanner.Pos, lit string) { return p.s.Scan() }
|
|
|
|
// ScanIgnoreWhitespace scans the next non-whitespace and non-comment token.
|
|
func (p *Parser) ScanIgnoreWhitespace() (tok scanner.Token, pos scanner.Pos, lit string) {
|
|
for {
|
|
tok, pos, lit = p.Scan()
|
|
if tok == scanner.WS || tok == scanner.COMMENT {
|
|
continue
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// Unscan pushes the previously read token back onto the buffer.
|
|
func (p *Parser) Unscan() {
|
|
p.s.Unscan()
|
|
}
|
|
|
|
// ParseTokens parses all the given tokens one after the other.
|
|
// It returns an error if one of the token is missing.
|
|
func (p *Parser) ParseTokens(tokens ...scanner.Token) error {
|
|
for _, t := range tokens {
|
|
if tok, pos, lit := p.ScanIgnoreWhitespace(); tok != t {
|
|
return newParseError(scanner.Tokstr(tok, lit), []string{t.String()}, pos)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// parseOptional parses a list of consecutive tokens. If the first token is not
|
|
// present, it unscans and return false. If the first is present, all the others
|
|
// must be parsed otherwise an error is returned.
|
|
func (p *Parser) parseOptional(tokens ...scanner.Token) (bool, error) {
|
|
// Parse optional first token
|
|
if tok, _, _ := p.ScanIgnoreWhitespace(); tok != tokens[0] {
|
|
p.Unscan()
|
|
return false, nil
|
|
}
|
|
|
|
if len(tokens) == 1 {
|
|
return true, nil
|
|
}
|
|
|
|
err := p.ParseTokens(tokens[1:]...)
|
|
return err == nil, err
|
|
}
|
|
|
|
// ParseError represents an error that occurred during parsing.
|
|
type ParseError struct {
|
|
Message string
|
|
Found string
|
|
Expected []string
|
|
Pos scanner.Pos
|
|
}
|
|
|
|
// newParseError returns a new instance of ParseError.
|
|
func newParseError(found string, expected []string, pos scanner.Pos) error {
|
|
return errors.WithStack(&ParseError{Found: found, Expected: expected, Pos: pos})
|
|
}
|
|
|
|
// Error returns the string representation of the error.
|
|
func (e *ParseError) Error() string {
|
|
if e.Message != "" {
|
|
return fmt.Sprintf("%s at line %d, char %d", e.Message, e.Pos.Line+1, e.Pos.Char+1)
|
|
}
|
|
return fmt.Sprintf("found %s, expected %s at line %d, char %d", e.Found, strings.Join(e.Expected, ", "), e.Pos.Line+1, e.Pos.Char+1)
|
|
}
|