Files
gorose-pro/builder_oracle.go
tuuz ef8e4a4615 Fix/Change
- 修复db.Data需要手动添加""符号的问题
- 修改Data方法的Null处理,现在输入"nil"和空字符串都会让字段变成Null type,而不是原来的null的string type
2023-11-23 14:44:40 +08:00

575 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gorose
import (
"errors"
"fmt"
"github.com/gohouse/golib/structEngin"
"github.com/gohouse/t"
"reflect"
"regexp"
"strconv"
"strings"
)
const (
// DriverOracle is the driver name under which the Oracle builder is
// registered (see init) and which BuildQuery/BuildExecute pass to SetDriver.
DriverOracle = "oracle"
)
var (
// operatorOracle is the operator list assigned to the operator field of
// every builder created by NewBuilderOracle.
operatorOracle = []string{"=", ">", "<", "!=", "<>", ">=", "<=", "like", "not like",
"intersect", "minus", "union", "||", "in", "not in", "between", "not between"}
)
// BuilderOracle builds SQL statements for the Oracle driver. It embeds
// FieldQuotesDefault and BuilderDefault and overrides or delegates to their
// methods as defined in this file.
type BuilderOracle struct {
FieldQuotesDefault
BuilderDefault
}
// NewBuilderOracle returns a new BuilderOracle bound to the ORM handle o.
// A fresh instance is allocated on every call, with the Oracle operator list
// installed and the placeholder counter reset to zero.
func NewBuilderOracle(o IOrm) *BuilderOracle {
	bo := &BuilderOracle{}
	bo.IOrm = o
	bo.operator = operatorOracle
	// Reset to 0 on every use so numbered placeholders start from scratch.
	bo.placeholder = 0
	return bo
}
func init() {
var builder = &BuilderOracle{}
NewBuilderDriver().Register(DriverOracle, NewBuilderOracle(builder))
}
// Clone returns a new, zero-valued BuilderOracle; no state is copied from the
// receiver.
func (b *BuilderOracle) Clone() IBuilder {
	return new(BuilderOracle)
}
// SetDriver stores dr as the builder's driver name (used when obtaining
// placeholders) and returns the receiver so calls can be chained.
func (b *BuilderOracle) SetDriver(dr string) *BuilderOracle {
b.driver = dr
return b
}
// GetPlaceholder advances the builder's placeholder counter by one and
// returns the next bind placeholder in the form ":<n>". The update runs
// inside withLockContext.
func (b *BuilderOracle) GetPlaceholder() (phstr string) {
	withLockContext(func() {
		b.placeholder++
		phstr = fmt.Sprintf(":%v", b.placeholder)
	})
	return phstr
}
// BuildQueryOra assembles the SELECT statement for the current ORM state.
//
// It returns the SQL text, the bind values collected on the ORM
// (GetBindValues) and any error raised while building the JOIN or WHERE
// clause; such errors are also written to the engine logger.
//
// When a limit is set the statement is rebuilt:
//   - without GROUP BY, a plain SELECT (fields, table, join, where, order) is
//     emitted;
//   - with GROUP BY, the query is wrapped in two sub-selects that use ROWNUM
//     to keep rows offset+1 .. offset+limit.
func (b *BuilderOracle) BuildQueryOra() (sqlStr string, args []interface{}, err error) {
	join, err := b.BuildJoin()
	if err != nil {
		b.IOrm.GetISession().GetIEngin().GetLogger().Error(err.Error())
		return "", nil, err
	}
	where, err := b.BuildWhere()
	if err != nil {
		b.IOrm.GetISession().GetIEngin().GetLogger().Error(err.Error())
		return "", nil, err
	}
	if len(b.GetJoin()) > 0 {
		// NOTE(review): the result is discarded; presumably called for a side
		// effect on the field list when joins are present - confirm.
		b.GetFields()
	}

	// Default form, used as-is when no limit is set.
	fieldsStr := b.BuildFields()
	tableName := b.BuildTable()
	sqlStr = fmt.Sprintf("SELECT %s%s FROM %s%s%s%s%s%s%s", b.BuildDistinct(), fieldsStr,
		tableName, join, where, b.BuildLimit(), b.BuildGroup(), b.BuildHaving(), b.BuildOrder())

	// Fetching a window of rows needs a different (nested) form.
	if limit := b.GetLimit(); limit > 0 {
		if b.GetGroup() == "" {
			// NOTE(review): this statement carries no ROWNUM restriction, so the
			// limit is not expressed in the SQL on this path - confirm that the
			// row limit is enforced elsewhere.
			sqlStr = fmt.Sprintf("SELECT %s%s FROM %s%s%s%s", b.BuildDistinct(), fieldsStr,
				tableName, join, where, b.BuildOrder())
		} else {
			const (
				aliasNameA = "tabA"
				aliasNameB = "tabB"
			)
			// Window bounds come straight from offset/limit so that an offset
			// which is not a multiple of the limit is honoured instead of being
			// rounded down to a page boundary. For aligned offsets the values
			// are the same as the former page-based computation.
			startRow := b.GetOffset() + 1
			endRow := startRow + limit

			if fieldsStr == "*" {
				fieldsStr = b.AddFieldQuotesOracle(b.GetTable()) + ".*, rownum r"
			}
			// NOTE(review): BuildGroup() output is followed by a literal
			// "GROUP BY ROWNUM" after the ORDER BY part; kept unchanged, but
			// verify the generated SQL against Oracle.
			sqlStr = fmt.Sprintf("SELECT %s%s FROM %s%s%s%s%s%s GROUP BY ROWNUM", b.BuildDistinct(), fieldsStr,
				tableName, join, where, b.BuildGroup(), b.BuildHaving(), b.BuildOrder())
			sqlStr = fmt.Sprintf(
				"select * from (select %s, rownum r from (%s) %s where rownum<%s ) %s where %s.r>=%s",
				aliasNameA+".*", sqlStr, aliasNameA, strconv.Itoa(endRow), aliasNameB, aliasNameB,
				strconv.Itoa(startRow))
		}
	}

	args = b.IOrm.GetBindValues()
	return sqlStr, args, err
}
// BuildExecuteOra builds the statement for a write operation. operType is one
// of "insert", "update", "delete" or "replace"; any other value yields an
// empty statement.
//
// For every operation except "delete" the ORM must carry data (GetData), which
// is turned into SQL fragments by BuildData. "update" and "delete" refuse to
// run without a WHERE clause unless Force is set on the ORM. "replace" is
// rendered as an Oracle MERGE statement using BuildReplace.
//
// It returns the SQL text, the bind values collected on the ORM and an error.
func (b *BuilderOracle) BuildExecuteOra(operType string) (sqlStr string, args []interface{}, err error) {
	// insert : {"name":"fizz, "website":"fizzday.net"} or a slice of such maps
	// update : {"name":"fizz", "website":"fizzday.net"}
	// delete : ...
	logErr := func(e error) {
		b.IOrm.GetISession().GetIEngin().GetLogger().Error(e.Error())
	}

	var setClause, insertKeys, insertVals string
	if operType != "delete" {
		if b.IOrm.GetData() == nil {
			err = errors.New("insert,update请传入数据操作")
			logErr(err)
			return "", nil, err
		}
		setClause, insertKeys, insertVals = b.BuildData(operType)
	}

	var where string
	switch operType {
	case "insert":
		sqlStr = fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", b.BuildTable(), insertKeys, insertVals)

	case "update":
		if where, err = b.BuildWhere(); err != nil {
			logErr(err)
			return "", nil, err
		}
		// Safety net: never update every row unless explicitly forced.
		if where == "" && !b.IOrm.GetForce() {
			err = errors.New("出于安全考虑, update时where条件不能为空, 如果真的不需要where条件, 请使用Force()(如: db.xxx.Force().Update())")
			logErr(err)
			return "", nil, err
		}
		sqlStr = fmt.Sprintf("UPDATE %s SET %s%s", b.BuildTable(), setClause, where)

	case "delete":
		if where, err = b.BuildWhere(); err != nil {
			logErr(err)
			return "", nil, err
		}
		// Safety net: never delete every row unless explicitly forced.
		if where == "" && !b.IOrm.GetForce() {
			err = errors.New("出于安全考虑, delete时where条件不能为空, 如果真的不需要where条件, 请使用Force()(如: db.xxx.Force().Delete())")
			logErr(err)
			return "", nil, err
		}
		sqlStr = fmt.Sprintf("DELETE FROM %s%s", b.BuildTable(), where)

	case "replace":
		if where, err = b.BuildWithoutWhere(); err != nil {
			logErr(err)
			return "", nil, err
		}
		selectSQL, onSQL, updateSQL, insertCols, insertExprs, replaceErr := b.BuildReplace(setClause, where)
		if replaceErr != nil {
			return "", nil, replaceErr
		}
		sqlStr = "MERGE INTO " + b.BuildTable() + " t "
		sqlStr += "USING (select " + selectSQL + " FROM dual) d on (" + onSQL + ") "
		// The UPDATE arm is omitted when there is nothing left to update.
		if updateSQL != "" {
			sqlStr += "WHEN matched THEN UPDATE SET " + updateSQL + " "
		}
		sqlStr += "WHEN NOT matched THEN INSERT (" + insertCols + ") VALUES (" + insertExprs + ")"
	}

	args = b.IOrm.GetBindValues()
	return sqlStr, args, err
}
// replaceWhereGroupRe matches one parenthesised group, e.g. "(id)", inside the
// where string handed to BuildReplace. Compiled once at package scope.
var replaceWhereGroupRe = regexp.MustCompile(`\(([^\)]+)\)`)

// BuildReplace derives the fragments of an Oracle MERGE statement.
//
// update is a comma-separated list of "column=value" fragments and where is
// the where string whose parenthesised groups name the match columns.
//
// Results:
//   - select_sql:  "value column" pairs for the "select ... FROM dual" source
//   - on_sql:      " t.col=d.col" conditions joined with " and"
//   - update_sql:  " t.col=d.col" assignments for columns that are not match
//     columns (empty when there are no match columns)
//   - insert_keys: the quoted column list
//   - insert_vals: the matching "d.col" list
//
// When where names no columns, every data column is used in on_sql instead.
// An error is returned (and all strings are empty) when a fragment of update
// is not of the form "column=value".
func (b *BuilderOracle) BuildReplace(update, where string) (select_sql, on_sql, update_sql, insert_keys, insert_vals string, err error) {
	// Columns named in the where groups form the MERGE match key.
	matchCols := []string{}
	var onParts []string
	for _, group := range replaceWhereGroupRe.FindAllString(where, -1) {
		col := strings.ReplaceAll(group, "(", "")
		col = strings.ReplaceAll(col, ")", "")
		quoted := b.AddFieldQuotesOracle(col)
		matchCols = append(matchCols, quoted)
		onParts = append(onParts, " t."+quoted+"=d."+quoted)
	}
	hasMatchCols := len(matchCols) > 0

	fragments := strings.Split(update, ",")
	var (
		updateParts []string
		selectParts = make([]string, 0, len(fragments))
		keyParts    = make([]string, 0, len(fragments))
		valParts    = make([]string, 0, len(fragments))
	)
	for _, fragment := range fragments {
		kv := strings.SplitN(fragment, "=", 2)
		if len(kv) != 2 {
			// Previously this indexed kv[1] unconditionally and panicked.
			return "", "", "", "", "", fmt.Errorf("replace data fragment %q is not of the form column=value", fragment)
		}
		quoted := b.AddFieldQuotesOracle(kv[0])
		pair := " t." + quoted + "=d." + quoted
		if !inArray(quoted, matchCols) {
			if hasMatchCols {
				updateParts = append(updateParts, pair)
			} else {
				// No explicit match columns: match on every data column.
				onParts = append(onParts, pair)
			}
		}
		keyParts = append(keyParts, quoted)
		valParts = append(valParts, "d."+quoted)
		selectParts = append(selectParts, kv[1]+" "+quoted)
	}

	select_sql = strings.Join(selectParts, ",")
	on_sql = strings.Join(onParts, " and")
	update_sql = strings.Join(updateParts, ",")
	insert_keys = strings.Join(keyParts, ",")
	insert_vals = strings.Join(valParts, ",")
	return select_sql, on_sql, update_sql, insert_keys, insert_vals, nil
}
// BuildData converts the ORM's data (GetData) into SQL fragments by handing
// it to parseData. Supported shapes, after dereferencing a pointer, are a
// struct, a map, a slice of structs and a slice of maps; any other shape
// yields three empty strings.
func (b *BuilderOracle) BuildData(operType string) (string, string, string) {
	data := b.IOrm.GetData()
	val := reflect.Indirect(reflect.ValueOf(data))

	// For slices the decision is made on the element kind.
	kind := val.Kind()
	isSlice := kind == reflect.Slice
	if isSlice {
		kind = val.Type().Elem().Kind()
	}

	switch {
	case kind == reflect.Struct:
		rows := structEngin.New().SetExtraCols(b.IOrm.GetExtraCols()).StructContent2Map(data)
		return b.parseData(operType, rows)
	case kind == reflect.Map && isSlice:
		return b.parseData(operType, t.New(data).SliceMapStringInterface())
	case kind == reflect.Map:
		single := []map[string]interface{}{t.New(data).MapStringInterface()}
		return b.parseData(operType, single)
	}
	return "", "", ""
}
// BuildData2 delegates unchanged to BuilderDefault.BuildData2.
func (b *BuilderOracle) BuildData2(operType string) (string, string, string) {
	first, second, third := b.BuilderDefault.BuildData2(operType)
	return first, second, third
}
// parseData turns rows of column/value maps into SQL fragments and registers
// every value as a bind value on the ORM.
//
// The column list is taken from the keys of the first row. The results are:
//  1. "col=:n" assignments joined by "," (columns quoted with
//     AddFieldQuotesOracle),
//  2. the column names joined by ",",
//  3. one "(:n,:m,...)" tuple per row, joined by ",".
//
// An empty data slice yields three empty strings.
func (b *BuilderOracle) parseData(operType string, data []map[string]interface{}) (string, string, string) {
	// Guard: indexing data[0] on an empty slice used to panic.
	if len(data) == 0 {
		return "", "", ""
	}

	// Map keys are unique, so no de-duplication is needed.
	dataFields := make([]string, 0, len(data[0]))
	for key := range data[0] {
		dataFields = append(dataFields, key)
	}

	var (
		dataValues []string // insert: one value tuple per row
		dataObj    []string // update: col=placeholder assignments
	)
	for _, item := range data {
		dataValuesSub := make([]string, 0, len(dataFields))
		for _, key := range dataFields {
			// Take exactly one placeholder per bound value and reuse it for
			// both the insert tuple and the update assignment, so placeholder
			// numbers line up with the sequence of bind values.
			ph := b.GetPlaceholder()
			b.IOrm.SetBindValues(item[key])
			dataValuesSub = append(dataValuesSub, ph)
			dataObj = append(dataObj, fmt.Sprintf("%s=%s", b.AddFieldQuotesOracle(key), ph))
		}
		dataValues = append(dataValues, "("+strings.Join(dataValuesSub, ",")+")")
	}
	return strings.Join(dataObj, ","), strings.Join(dataFields, ","), strings.Join(dataValues, ",")
}
// BuildJoin delegates unchanged to BuilderDefault.BuildJoin.
func (b *BuilderOracle) BuildJoin() (s string, err error) {
	s, err = b.BuilderDefault.BuildJoin()
	return s, err
}
// BuildWithoutWhere parses the ORM's where conditions and returns them
// prefixed with a single space but without the "WHERE" keyword (empty string
// when there are no conditions). The ORM's where list is saved before parsing
// and restored afterwards.
func (b *BuilderOracle) BuildWithoutWhere() (where string, err error) {
	saved := b.IOrm.GetWhere()
	where, err = b.parseWhere(b.IOrm)
	b.IOrm.SetWhere(saved)

	clause := If(where == "", "", " "+where).(string)
	return clause, err
}
// BuildWhere parses the ORM's where conditions and returns them as a
// " WHERE ..." clause (empty string when there are no conditions). The ORM's
// where list is saved before parsing and restored afterwards.
func (b *BuilderOracle) BuildWhere() (where string, err error) {
	saved := b.IOrm.GetWhere()
	where, err = b.parseWhere(b.IOrm)
	b.IOrm.SetWhere(saved)

	clause := If(where == "", "", " WHERE "+where).(string)
	return clause, err
}
// BuildDistinct delegates unchanged to BuilderDefault.BuildDistinct.
func (b *BuilderOracle) BuildDistinct() (dis string) {
	dis = b.BuilderDefault.BuildDistinct()
	return dis
}
// BuildFields delegates unchanged to BuilderDefault.BuildFields.
func (b *BuilderOracle) BuildFields() string {
	fields := b.BuilderDefault.BuildFields()
	return fields
}
// BuildTable returns the ORM's table name passed through
// AddFieldQuotesOracle.
func (b *BuilderOracle) BuildTable() string {
	table := b.BuilderDefault.IOrm.GetTable()
	return b.AddFieldQuotesOracle(table)
}
// BuildGroup delegates unchanged to BuilderDefault.BuildGroup.
func (b *BuilderOracle) BuildGroup() string {
	group := b.BuilderDefault.BuildGroup()
	return group
}
// BuildHaving delegates unchanged to BuilderDefault.BuildHaving.
func (b *BuilderOracle) BuildHaving() string {
	having := b.BuilderDefault.BuildHaving()
	return having
}
// BuildOrder renders the ORM's order expression (GetOrder) as an
// " ORDER BY ..." clause, or returns "" when no order is set.
//
// The expression is split on "," into items and each item on whitespace into
// tokens. The direction keywords asc/desc (matched case-insensitively) are
// emitted as written, preceded by a space; every other token is passed
// through AddFieldQuotesOracle.
func (b *BuilderOracle) BuildOrder() string {
	raw := b.IOrm.GetOrder()
	if raw == "" {
		return ""
	}

	var order strings.Builder
	for i, item := range strings.Split(raw, ",") {
		if i > 0 {
			order.WriteString(",")
		}
		// Fields (rather than splitting on a single space) drops the empty
		// tokens produced by "a desc, b asc" or repeated spaces.
		for _, tok := range strings.Fields(item) {
			switch strings.ToLower(tok) {
			case "desc", "asc":
				// Upper-case DESC/ASC used to be treated as a column name.
				order.WriteString(" " + tok)
			default:
				order.WriteString(b.AddFieldQuotesOracle(tok))
			}
		}
	}
	return " ORDER BY " + order.String()
}
// BuildLimit returns the ROWNUM upper-bound condition for the current
// offset/limit, or "" when no limit is set.
//
// The condition is " and rownum < N" when the ORM already has where
// conditions and " where rownum < N" otherwise, with N = offset + limit + 1.
func (b *BuilderOracle) BuildLimit() string {
	limit := b.GetLimit()
	if limit == 0 {
		return ""
	}
	// Derive the bound from the offset itself; the former page-based form
	// (offset/limit+1)*limit+1 gives the same value for aligned offsets but
	// rounded unaligned offsets down to a page boundary.
	endRow := b.GetOffset() + limit + 1

	var limitStr string
	if len(b.IOrm.GetWhere()) > 0 {
		limitStr = fmt.Sprintf(" and rownum < %d", endRow)
	} else {
		limitStr = fmt.Sprintf(" where rownum < %d", endRow)
	}
	return If(b.IOrm.GetLimit() == 0, "", limitStr).(string)
}
// BuildOffset always returns an empty string; this builder emits no separate
// offset clause.
func (b *BuilderOracle) BuildOffset() string {
	const noOffsetClause = ""
	return noOffsetClause
}
// parseWhere renders the where entries of ormApi into a single condition
// string (without the "WHERE" keyword).
//
// Each entry is {condition, params}, where condition is the joining keyword
// and params holds:
//   - 2 or 3 values: a single comparison, rendered by parseParams;
//   - 1 string: emitted verbatim inside parentheses;
//   - 1 map[string]interface{}: "key=placeholder" pairs joined with " and ";
//   - 1 [][]interface{}: a group of 2- or 3-value comparisons joined with
//     " and ";
//   - 1 func(): a nested where; the ORM's where list is cleared, the function
//     is run and the conditions it added are parsed recursively.
//
// Entries with any other number of params are skipped. The leading "and"/"or"
// of the joined result is stripped before returning.
func (b *BuilderOracle) parseWhere(ormApi IOrm) (string, error) {
	var clauses []string

	for _, entry := range ormApi.GetWhere() {
		condition := entry[0].(string)
		params := entry[1].([]interface{})

		switch len(params) {
		case 2, 3: // {"id", 1} or {"id", ">", 1}
			res, err := b.parseParams(params, ormApi)
			if err != nil {
				return res, err
			}
			clauses = append(clauses, condition+" "+res)

		case 1:
			switch param := params[0].(type) {
			case string:
				clauses = append(clauses, condition+" ("+param+")")

			case map[string]interface{}:
				var pairs []string
				for key, val := range param {
					pairs = append(pairs, key+"="+b.GetPlaceholder())
					b.IOrm.SetBindValues(val)
				}
				clauses = append(clauses, condition+" ("+strings.Join(pairs, " and ")+")")

			case [][]interface{}: // {{"a", 1}, {"id", ">", 1}}
				var group []string
				for _, arr := range param {
					if n := len(arr); n != 2 && n != 3 {
						return "", errors.New("where data format is wrong")
					}
					res, err := b.parseParams(arr, ormApi)
					if err != nil {
						return res, err
					}
					group = append(group, res)
				}
				clauses = append(clauses, condition+" ("+strings.Join(group, " and ")+")")

			case func():
				// Clear the where list so the nested conditions can reuse it.
				ormApi.SetWhere([][]interface{}{})
				// Let the callback register its nested conditions.
				param()
				// Parse what the callback added.
				nested, err := b.parseWhere(ormApi)
				if err != nil {
					b.IOrm.GetISession().GetIEngin().GetLogger().Error(err.Error())
					return "", err
				}
				// Nested conditions go inside one pair of parentheses.
				clauses = append(clauses, condition+" ("+nested+")")

			default:
				return "", errors.New("where data format is wrong")
			}
		}
	}

	// Join, then strip surrounding spaces and the leading and/or keyword.
	joined := strings.Trim(strings.Join(clauses, " "), " ")
	joined = strings.TrimPrefix(joined, "and")
	joined = strings.TrimPrefix(joined, "or")
	return strings.TrimLeft(joined, " "), nil
}
// parseParams renders one comparison and registers its values as bind values
// on the ORM.
//
// args is either {field, value} (rendered as "field = :n") or
// {field, operator, value}. The operator must be contained in GetOperator().
// For "in"/"not in" the value is expanded to one placeholder per element; for
// "between"/"not between" it must provide at least two elements, of which the
// first two are used. Any other length of args yields an empty string.
//
// An error is returned (and logged) when the operator is not allowed, when
// the field name or operator is not a string, or when a between value has
// fewer than two elements.
func (b *BuilderOracle) parseParams(args []interface{}, ormApi IOrm) (s string, err error) {
	// fail logs msg through the engine logger and returns it as the error.
	fail := func(msg string) (string, error) {
		e := errors.New(msg)
		b.IOrm.GetISession().GetIEngin().GetLogger().Error(e.Error())
		return "", e
	}

	var parts []string
	switch len(args) {
	case 3: // {"id", ">", 1}
		if !inArray(args[1], b.GetOperator()) {
			return fail("where parameter is wrong")
		}
		// Checked assertions: a non-string field or operator used to panic.
		field, ok := args[0].(string)
		if !ok {
			return fail("where field name must be a string")
		}
		op, ok := args[1].(string)
		if !ok {
			return fail("where parameter is wrong")
		}
		parts = append(parts, b.AddFieldQuotesOracle(field), op)

		switch op {
		case "in", "not in":
			var phs []string
			for _, item := range t.New(args[2]).Slice() {
				phs = append(phs, b.GetPlaceholder())
				b.IOrm.SetBindValues(t.New(item).Interface())
			}
			parts = append(parts, "("+strings.Join(phs, ",")+")")
		case "between", "not between":
			bounds := t.New(args[2]).Slice()
			// Indexing bounds[1] on a shorter value used to panic.
			if len(bounds) < 2 {
				return fail("where between requires two boundary values")
			}
			parts = append(parts, b.GetPlaceholder()+" and "+b.GetPlaceholder())
			b.IOrm.SetBindValues(bounds[0].Interface())
			b.IOrm.SetBindValues(bounds[1].Interface())
		default: // includes "like" / "not like"
			parts = append(parts, b.GetPlaceholder())
			b.IOrm.SetBindValues(args[2])
		}

	case 2: // {"id", 1}
		field, ok := args[0].(string)
		if !ok {
			return fail("where field name must be a string")
		}
		parts = append(parts, b.AddFieldQuotesOracle(field), "=", b.GetPlaceholder())
		b.IOrm.SetBindValues(args[1])
	}
	return strings.Join(parts, " "), nil
}
// GetOperator delegates unchanged to BuilderDefault.GetOperator.
func (b *BuilderOracle) GetOperator() []string {
	ops := b.BuilderDefault.GetOperator()
	return ops
}
// Interface implementation.

// BuildQuery builds the query SQL string for o. It creates a fresh Oracle
// builder bound to o, sets its driver to DriverOracle and returns the result
// of BuildQueryOra.
func (b *BuilderOracle) BuildQuery(o IOrm) (sqlStr string, args []interface{}, err error) {
	builder := NewBuilderOracle(o).SetDriver(DriverOracle)
	return builder.BuildQueryOra()
}
// BuildExecute builds the execute SQL string for o and operType. It creates a
// fresh Oracle builder bound to o, sets its driver to DriverOracle and
// returns the result of BuildExecuteOra.
func (b *BuilderOracle) BuildExecute(o IOrm, operType string) (sqlStr string, args []interface{}, err error) {
	builder := NewBuilderOracle(o).SetDriver(DriverOracle)
	return builder.BuildExecuteOra(operType)
}