kubevpn/vendor/github.com/DataDog/go-sqllexer/sqllexer.go
package sqllexer

import "unicode/utf8"

type TokenType int
const (
	ERROR TokenType = iota
	EOF
	WS                     // whitespace
	STRING                 // string literal
	INCOMPLETE_STRING      // incomplete string literal so that we can obfuscate it, e.g. 'abc
	NUMBER                 // number literal
	IDENT                  // identifier
	QUOTED_IDENT           // quoted identifier
	OPERATOR               // operator
	WILDCARD               // wildcard *
	COMMENT                // comment
	MULTILINE_COMMENT      // multiline comment
	PUNCTUATION            // punctuation
	DOLLAR_QUOTED_FUNCTION // dollar quoted function
	DOLLAR_QUOTED_STRING   // dollar quoted string
	POSITIONAL_PARAMETER   // numbered parameter
	BIND_PARAMETER         // bind parameter
	FUNCTION               // function
	SYSTEM_VARIABLE        // system variable
	UNKNOWN                // unknown token
)
// Token represents a SQL token with its type and value.
type Token struct {
	Type  TokenType
	Value string
}

type LexerConfig struct {
	DBMS DBMSType `json:"dbms,omitempty"`
}

type lexerOption func(*LexerConfig)

func WithDBMS(dbms DBMSType) lexerOption {
	return func(c *LexerConfig) {
		c.DBMS = dbms
	}
}
// Lexer is a SQL lexer inspired by Rob Pike's talk on Lexical Scanning in Go.
type Lexer struct {
	src    string // the input src string
	cursor int    // the current position of the cursor
	start  int    // the start position of the current token
	config *LexerConfig
}

func New(input string, opts ...lexerOption) *Lexer {
	lexer := &Lexer{src: input, config: &LexerConfig{}}
	for _, opt := range opts {
		opt(lexer.config)
	}
	return lexer
}
// ScanAll scans the entire input string and returns a slice of tokens.
func (s *Lexer) ScanAll() []Token {
	var tokens []Token
	for {
		token := s.Scan()
		if token.Type == EOF {
			// don't include EOF token in the result
			break
		}
		tokens = append(tokens, token)
	}
	return tokens
}
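
// exampleScanAll is an illustrative sketch (not part of the upstream
// DataDog/go-sqllexer file): it shows the intended usage of New and
// ScanAll on an arbitrary sample query. SQL keywords come back as plain
// IDENT tokens because scanIdentifier does not distinguish keywords.
func exampleScanAll() []Token {
	lexer := New("SELECT id FROM users WHERE name = 'bob'")
	// Yields IDENT("SELECT"), WS(" "), IDENT("id"), ..., OPERATOR("="),
	// WS(" "), STRING("'bob'"); the terminating EOF token is omitted.
	return lexer.ScanAll()
}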
// ScanAllTokens scans the entire input string and returns a channel of tokens.
// Use this if you want to process the tokens as they are scanned.
func (s *Lexer) ScanAllTokens() <-chan Token {
	tokenCh := make(chan Token)
	go func() {
		defer close(tokenCh)
		for {
			token := s.Scan()
			if token.Type == EOF {
				// don't include EOF token in the result
				break
			}
			tokenCh <- token
		}
	}()
	return tokenCh
}
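
// exampleScanAllTokens is an illustrative sketch (not part of the upstream
// file): it consumes tokens as they are produced by ranging over the
// channel, which ends once the goroutine above closes it at EOF.
func exampleScanAllTokens() int {
	count := 0
	for token := range New("SELECT 1").ScanAllTokens() {
		if token.Type != WS {
			count++ // e.g. IDENT("SELECT") and NUMBER("1") => 2
		}
	}
	return count
}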
// Scan scans the next token and returns it.
func (s *Lexer) Scan() Token {
	ch := s.peek()
	switch {
	case isWhitespace(ch):
		return s.scanWhitespace()
	case isLetter(ch):
		return s.scanIdentifier(ch)
	case isDoubleQuote(ch):
		return s.scanDoubleQuotedIdentifier('"')
	case isSingleQuote(ch):
		return s.scanString()
	case isSingleLineComment(ch, s.lookAhead(1)):
		return s.scanSingleLineComment()
	case isMultiLineComment(ch, s.lookAhead(1)):
		return s.scanMultiLineComment()
	case isLeadingSign(ch):
		// if the leading sign is followed by a digit, then it's a number
		// although this is not strictly true, it's good enough for our purposes
		nextCh := s.lookAhead(1)
		if isDigit(nextCh) || nextCh == '.' {
			return s.scanNumberWithLeadingSign()
		}
		return s.scanOperator(ch)
	case isDigit(ch):
		return s.scanNumber(ch)
	case isWildcard(ch):
		return s.scanWildcard()
	case ch == '$':
		if isDigit(s.lookAhead(1)) {
			// if the dollar sign is followed by a digit, then it's a numbered parameter
			return s.scanPositionalParameter()
		}
		if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) {
			return s.scanIdentifier(ch)
		}
		return s.scanDollarQuotedString()
	case ch == ':':
		if s.config.DBMS == DBMSOracle && isAlphaNumeric(s.lookAhead(1)) {
			return s.scanBindParameter()
		}
		return s.scanOperator(ch)
	case ch == '`':
		if s.config.DBMS == DBMSMySQL {
			return s.scanDoubleQuotedIdentifier('`')
		}
		fallthrough
	case ch == '#':
		if s.config.DBMS == DBMSSQLServer {
			return s.scanIdentifier(ch)
		} else if s.config.DBMS == DBMSMySQL {
			// MySQL treats # as a comment
			return s.scanSingleLineComment()
		}
		fallthrough
	case ch == '@':
		if isAlphaNumeric(s.lookAhead(1)) {
			if s.config.DBMS == DBMSSnowflake {
				return s.scanIdentifier(ch)
			}
			return s.scanBindParameter()
		} else if s.lookAhead(1) == '@' {
			return s.scanSystemVariable()
		}
		fallthrough
	case isOperator(ch):
		return s.scanOperator(ch)
	case isPunctuation(ch):
		if ch == '[' && s.config.DBMS == DBMSSQLServer {
			return s.scanDoubleQuotedIdentifier('[')
		}
		return s.scanPunctuation()
	case isEOF(ch):
		return Token{EOF, ""}
	default:
		return s.scanUnknown()
	}
}
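
// exampleScan is an illustrative sketch (not part of the upstream file):
// it drives Scan token by token with a DBMS hint. With DBMSMySQL, a
// backtick-quoted name such as `db`.`users` lexes as one QUOTED_IDENT,
// and a leading # starts a single-line comment.
func exampleScan(src string) []TokenType {
	lexer := New(src, WithDBMS(DBMSMySQL))
	var types []TokenType
	for {
		token := lexer.Scan()
		if token.Type == EOF {
			break
		}
		types = append(types, token.Type)
	}
	return types
}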
// lookAhead returns the rune n positions ahead of the cursor.
func (s *Lexer) lookAhead(n int) rune {
	if s.cursor+n >= len(s.src) || s.cursor+n < 0 {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[s.cursor+n:])
	return r
}

// peek returns the rune at the cursor position.
func (s *Lexer) peek() rune {
	return s.lookAhead(0)
}

// nextBy advances the cursor by n positions and returns the rune at the cursor position.
func (s *Lexer) nextBy(n int) rune {
	if s.cursor+n > len(s.src) {
		return 0
	}
	s.cursor += n
	if s.cursor >= len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[s.cursor:])
	return r
}

// next advances the cursor by 1 position and returns the rune at the cursor position.
func (s *Lexer) next() rune {
	return s.nextBy(1)
}

// matchAt reports whether the bytes starting at the cursor match the given runes.
func (s *Lexer) matchAt(match []rune) bool {
	if s.cursor+len(match) > len(s.src) {
		return false
	}
	for i, ch := range match {
		if s.src[s.cursor+i] != byte(ch) {
			return false
		}
	}
	return true
}
func (s *Lexer) scanNumberWithLeadingSign() Token {
	s.start = s.cursor
	ch := s.next() // consume the leading sign
	return s.scanNumberic(ch)
}

func (s *Lexer) scanNumber(ch rune) Token {
	s.start = s.cursor
	return s.scanNumberic(ch)
}

func (s *Lexer) scanNumberic(ch rune) Token {
	if ch == '0' {
		nextCh := s.lookAhead(1)
		if nextCh == 'x' || nextCh == 'X' {
			return s.scanHexNumber()
		} else if nextCh >= '0' && nextCh <= '7' {
			return s.scanOctalNumber()
		}
	}
	return s.scanDecimalNumber()
}

func (s *Lexer) scanDecimalNumber() Token {
	ch := s.next()
	// scan digits, decimal points, and exponents
	for isDigit(ch) || ch == '.' || isExpontent(ch) {
		if isExpontent(ch) {
			ch = s.next()
			if isLeadingSign(ch) {
				ch = s.next()
			}
		} else {
			ch = s.next()
		}
	}
	return Token{NUMBER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanHexNumber() Token {
	ch := s.nextBy(2) // consume the leading 0x
	for isDigit(ch) || ('a' <= ch && ch <= 'f') || ('A' <= ch && ch <= 'F') {
		ch = s.next()
	}
	return Token{NUMBER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanOctalNumber() Token {
	ch := s.nextBy(2) // consume the leading 0 and the first octal digit
	for '0' <= ch && ch <= '7' {
		ch = s.next()
	}
	return Token{NUMBER, s.src[s.start:s.cursor]}
}
func (s *Lexer) scanString() Token {
	s.start = s.cursor
	ch := s.next() // consume the opening quote
	escaped := false
	for {
		if escaped {
			// encountered an escape character
			// reset the escaped flag and continue
			escaped = false
			ch = s.next()
			continue
		}
		if ch == '\\' {
			escaped = true
			ch = s.next()
			continue
		}
		if ch == '\'' {
			s.next() // consume the closing quote
			return Token{STRING, s.src[s.start:s.cursor]}
		}
		if isEOF(ch) {
			// encountered EOF before the closing quote
			// this usually happens when the string is truncated
			return Token{INCOMPLETE_STRING, s.src[s.start:s.cursor]}
		}
		ch = s.next()
	}
}
func (s *Lexer) scanIdentifier(ch rune) Token {
	// NOTE: this func does not distinguish between SQL keywords and identifiers
	s.start = s.cursor
	ch = s.nextBy(utf8.RuneLen(ch))
	for isLetter(ch) || isDigit(ch) || ch == '.' || ch == '?' || ch == '$' || ch == '#' || ch == '/' {
		ch = s.nextBy(utf8.RuneLen(ch))
	}
	if ch == '(' {
		// if the identifier is followed by a (, then it's a function
		return Token{FUNCTION, s.src[s.start:s.cursor]}
	}
	return Token{IDENT, s.src[s.start:s.cursor]}
}
func (s *Lexer) scanDoubleQuotedIdentifier(delimiter rune) Token {
	closingDelimiter := delimiter
	if delimiter == '[' {
		closingDelimiter = ']'
	}
	s.start = s.cursor
	ch := s.next() // consume the opening quote
	for {
		// encountered the closing quote
		// BUT if it's followed by .", then we should keep going
		// e.g. Postgres "foo"."bar"
		// e.g. SQL Server [foo].[bar]
		if ch == closingDelimiter {
			specialCase := []rune{closingDelimiter, '.', delimiter}
			if s.matchAt(specialCase) {
				ch = s.nextBy(3) // consume the closing delimiter, the '.', and the reopening delimiter
				continue
			}
			break
		}
		if isEOF(ch) {
			return Token{ERROR, s.src[s.start:s.cursor]}
		}
		ch = s.next()
	}
	s.next() // consume the closing quote
	return Token{QUOTED_IDENT, s.src[s.start:s.cursor]}
}
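
// exampleQuotedIdent is an illustrative sketch (not part of the upstream
// file): because of the delimiter-dot-delimiter special case above, a
// qualified name like "foo"."bar" is scanned as a single QUOTED_IDENT
// token rather than two identifiers separated by punctuation.
func exampleQuotedIdent() Token {
	return New(`"foo"."bar"`).Scan() // Token{QUOTED_IDENT, `"foo"."bar"`}
}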
func (s *Lexer) scanWhitespace() Token {
	// scan whitespace, tab, newline, carriage return
	s.start = s.cursor
	ch := s.next()
	for isWhitespace(ch) {
		ch = s.next()
	}
	return Token{WS, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanOperator(lastCh rune) Token {
	s.start = s.cursor
	ch := s.next()
	for isOperator(ch) && !(lastCh == '=' && ch == '?') {
		// hack: we don't want to treat "=?" as a single operator
		lastCh = ch
		ch = s.next()
	}
	return Token{OPERATOR, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanWildcard() Token {
	s.start = s.cursor
	s.next()
	return Token{WILDCARD, s.src[s.start:s.cursor]}
}
func (s *Lexer) scanSingleLineComment() Token {
	s.start = s.cursor
	ch := s.nextBy(2) // consume the opening dashes
	for ch != '\n' && !isEOF(ch) {
		ch = s.next()
	}
	return Token{COMMENT, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanMultiLineComment() Token {
	s.start = s.cursor
	ch := s.nextBy(2) // consume the opening slash and asterisk
	for {
		if ch == '*' && s.lookAhead(1) == '/' {
			s.nextBy(2) // consume the closing asterisk and slash
			break
		}
		if isEOF(ch) {
			// encountered EOF before the closing comment
			// this usually happens when the comment is truncated
			return Token{ERROR, s.src[s.start:s.cursor]}
		}
		ch = s.next()
	}
	return Token{MULTILINE_COMMENT, s.src[s.start:s.cursor]}
}
func (s *Lexer) scanPunctuation() Token {
	s.start = s.cursor
	s.next()
	return Token{PUNCTUATION, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanDollarQuotedString() Token {
	s.start = s.cursor
	ch := s.next() // consume the dollar sign
	tagStart := s.cursor
	for s.cursor < len(s.src) && ch != '$' {
		ch = s.next()
	}
	s.next()                            // consume the closing dollar sign of the tag
	tag := s.src[tagStart-1 : s.cursor] // include the opening and closing dollar sign e.g. $tag$
	for s.cursor < len(s.src) {
		if s.matchAt([]rune(tag)) {
			s.nextBy(len(tag)) // consume the closing tag
			if tag == "$func$" {
				return Token{DOLLAR_QUOTED_FUNCTION, s.src[s.start:s.cursor]}
			}
			return Token{DOLLAR_QUOTED_STRING, s.src[s.start:s.cursor]}
		}
		s.next()
	}
	return Token{ERROR, s.src[s.start:s.cursor]}
}
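
// exampleDollarQuoted is an illustrative sketch (not part of the upstream
// file): a Postgres-style dollar-quoted string scans as one
// DOLLAR_QUOTED_STRING token, so embedded single quotes need no escaping;
// a $func$ tag is reported as DOLLAR_QUOTED_FUNCTION instead.
func exampleDollarQuoted() Token {
	return New("$tag$ it's a string $tag$").Scan() // Token{DOLLAR_QUOTED_STRING, ...}
}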
func (s *Lexer) scanPositionalParameter() Token {
	s.start = s.cursor
	ch := s.nextBy(2) // consume the dollar sign and the first digit
	for {
		if !isDigit(ch) {
			break
		}
		ch = s.next()
	}
	return Token{POSITIONAL_PARAMETER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanBindParameter() Token {
	s.start = s.cursor
	ch := s.nextBy(2) // consume the (colon|at sign) and the first char
	for {
		if !isAlphaNumeric(ch) {
			break
		}
		ch = s.next()
	}
	return Token{BIND_PARAMETER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanSystemVariable() Token {
	s.start = s.cursor
	ch := s.nextBy(2) // consume @@
	for {
		if !isAlphaNumeric(ch) {
			break
		}
		ch = s.next()
	}
	return Token{SYSTEM_VARIABLE, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanUnknown() Token {
	// When we encounter an unknown rune, consume it as a single UNKNOWN token
	// and let the next call to Scan resume at the following rune.
	s.start = s.cursor
	s.next()
	return Token{UNKNOWN, s.src[s.start:s.cursor]}
}