diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4170155 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/go.sum diff --git a/README.md b/README.md index ba5a19f..1a50eb2 100644 Binary files a/README.md and b/README.md differ diff --git a/aorm.go b/aorm.go new file mode 100644 index 0000000..ec128ed --- /dev/null +++ b/aorm.go @@ -0,0 +1,53 @@ +package aorm + +import ( + "database/sql" //只需导入你需要的驱动即可 +) + +// LinkCommon database/sql提供的库连接与事务,二者有很多方法是一致的,为了通用,抽象为该interface +type LinkCommon interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// Executor 查询记录所需要的条件 +type Executor struct { + LinkCommon LinkCommon + TableName string + FiledList []string + GroupList []string + WhereList []WhereItem + JoinList []string + HavingList []WhereItem + OrderList []string + Offset int + PageSize int + IsDebug bool + OpinionList []OpinionItem +} + +// Use 使用数据库连接,或者事务 +func Use(linkCommon LinkCommon) *Executor { + executor := &Executor{ + LinkCommon: linkCommon, + } + + return executor +} + +//清空查询条件,复用对象 +func (db *Executor) clear() { + db.TableName = "" + db.FiledList = make([]string, 0) + db.GroupList = make([]string, 0) + db.WhereList = make([]WhereItem, 0) + db.JoinList = make([]string, 0) + db.HavingList = make([]WhereItem, 0) + db.OrderList = make([]string, 0) + db.Offset = 0 + db.PageSize = 0 + db.IsDebug = false + db.OpinionList = make([]OpinionItem, 0) +} diff --git a/crud.go b/crud.go new file mode 100644 index 0000000..17546b4 --- /dev/null +++ b/crud.go @@ -0,0 +1,637 @@ +package aorm + +import ( + "database/sql" + "errors" + "fmt" + "gopkg.in/guregu/null.v4" + "reflect" + "strconv" + "strings" + "time" +) + +const Desc = "DESC" +const Asc = "ASC" + +const Eq = "=" +const Ne = "!=" +const Gt = ">" +const Ge = ">=" +const Lt = "<" +const Le = "<=" + +const In = "IN" +const NotIn = "NOT IN" +const Like = "LIKE" +const NotLike = "NOT LIKE" +const Between = "BETWEEN" +const NotBetween = "NOT BETWEEN" + +type WhereItem struct { + Field string + Opt string + Val any +} + +type CountStruct struct { + C null.Int +} + +// Insert 增加记录 +func (db *Executor) Insert(dest interface{}) (int64, error) { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if db.TableName == "" { + arr := strings.Split(typeOf.String(), ".") + db.TableName = UnderLine(arr[len(arr)-1]) + } + + var keys []string + var paramList []any + var place []string + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + keys = append(keys, key) + paramList = append(paramList, val) + place = append(place, "?") + } + } + + sqlStr := "INSERT INTO " + db.TableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" + + res, err := db.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + lastId, err := res.LastInsertId() + if err != nil { + return 0, err + } + + return lastId, nil +} + +// Select 查询记录 +func (db *Executor) Select(values interface{}) error { + destSlice := reflect.Indirect(reflect.ValueOf(values)) + destType := destSlice.Type().Elem() + + res := db.selectArr() + for i := 0; i < len(res); i++ { + dest := reflect.New(destType).Elem() + + for k, v := range res[i] { + fieldName := CamelString(k) + if dest.FieldByName(fieldName).CanSet() { + filedType := dest.FieldByName(fieldName).Type().String() + x := transToNullType(v, filedType) + dest.FieldByName(fieldName).Set(x) + } + } + destSlice.Set(reflect.Append(destSlice, dest)) + } + + return nil +} + +// Find 查询某一条记录 +func (db *Executor) Find(obj interface{}) error { + + dest := reflect.ValueOf(obj).Elem() + res := db.Limit(0, 1).selectArr() + if len(res[0]) == 0 { + return errors.New("找不到相关信息") + } + + for k, v := range res[0] { + fieldName := CamelString(k) + if dest.FieldByName(fieldName).CanSet() { + filedType := dest.FieldByName(fieldName).Type().String() + x := transToNullType(v, filedType) + dest.FieldByName(fieldName).Set(x) + } + } + + return nil +} + +func (db *Executor) selectArr() []map[string]interface{} { + var paramList []any + fieldStr := handleField(db.FiledList) + whereStr, paramList := handleWhere(db.WhereList, paramList) + joinStr := handleJoin(db.JoinList) + groupStr := handleGroup(db.GroupList) + havingStr, paramList := handleHaving(db.HavingList, paramList) + orderStr := handleOrder(db.OrderList) + limitStr, paramList := handleLimit(db.Offset, db.PageSize, paramList) + + sqlStr := "SELECT " + fieldStr + " FROM " + db.TableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + res, _ := db.Query(sqlStr, paramList...) + + return res +} + +func transToNullType(v interface{}, filedType string) reflect.Value { + x := reflect.ValueOf("") + if "null.String" == filedType { + if nil == v { + x = reflect.ValueOf(null.String{}) + } else { + x = reflect.ValueOf(null.StringFrom(fmt.Sprintf("%v", v))) + } + } else if "null.Int" == filedType { + if nil == v { + x = reflect.ValueOf(null.Int{}) + } else { + int64Val, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64) + x = reflect.ValueOf(null.IntFrom(int64Val)) + } + } else if "null.Time" == filedType { + if nil == v { + x = reflect.ValueOf(null.Time{}) + } else { + timeStr := fmt.Sprintf("%v", v) + timeArr := strings.Split(timeStr, " ") + timeArr1 := strings.Split(timeArr[0], "-") + timeArr2 := strings.Split(timeArr[1], ":") + + a := time.Date( + str2Int(timeArr1[0]), time.Month(str2Int(timeArr1[1])), str2Int(timeArr1[2]), + str2Int(timeArr2[0]), + str2Int(timeArr2[1]), + str2Int(timeArr2[2]), + 0, + time.Local, + ) + x = reflect.ValueOf(null.TimeFrom(a)) + } + } else if "null.Bool" == filedType { + if nil == v { + x = reflect.ValueOf(null.Bool{}) + } else { + boolVal, _ := strconv.ParseBool(fmt.Sprintf("%v", v)) + x = reflect.ValueOf(null.BoolFrom(boolVal)) + } + } else if "null.Float" == filedType { + if nil == v { + x = reflect.ValueOf(null.Float{}) + } else { + float64Val, _ := strconv.ParseFloat(fmt.Sprintf("%v", v), 64) + x = reflect.ValueOf(null.FloatFrom(float64Val)) + } + } else { + panic("不受支持的类型转换" + filedType) + } + + return x +} + +// Update 更新记录 +func (db *Executor) Update(dest interface{}) (int64, error) { + var paramList []any + setStr, paramList := db.handleSet(dest, paramList) + whereStr, paramList := handleWhere(db.WhereList, paramList) + sqlStr := "UPDATE " + db.TableName + setStr + whereStr + + res, err := db.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + +// Delete 删除记录 +func (db *Executor) Delete() (int64, error) { + var paramList []any + whereStr, paramList := handleWhere(db.WhereList, paramList) + sqlStr := "DELETE FROM " + db.TableName + whereStr + + res, err := db.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + +// Count 聚合函数-数量 +func (db *Executor) Count(fieldName string) int64 { + var obj []CountStruct + err := db.Field("count(" + fieldName + ") as c").Select(&obj) + if err != nil { + return 0 + } + + return obj[0].C.Int64 +} + +// Sum 聚合函数-合计 +func (db *Executor) Sum(fieldName string) int64 { + var obj []CountStruct + err := db.Field("sum(" + fieldName + ") as c").Select(&obj) + if err != nil { + return 0 + } + + return obj[0].C.Int64 +} + +// Avg 聚合函数-平均值 +func (db *Executor) Avg(fieldName string) int64 { + var obj []CountStruct + err := db.Field("avg(" + fieldName + ") as c").Select(&obj) + if err != nil { + return 0 + } + + return obj[0].C.Int64 +} + +// Max 聚合函数-最大值 +func (db *Executor) Max(fieldName string) int64 { + var obj []CountStruct + err := db.Field("avg(" + fieldName + ") as c").Select(&obj) + if err != nil { + return 0 + } + + return obj[0].C.Int64 +} + +// Min 聚合函数-最小值 +func (db *Executor) Min(fieldName string) int64 { + var obj []CountStruct + err := db.Field("avg(" + fieldName + ") as c").Select(&obj) + if err != nil { + return 0 + } + + return obj[0].C.Int64 +} + +// Query 通用查询 +func (db *Executor) Query(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) { + if db.IsDebug { + fmt.Println(sqlStr) + fmt.Println(args) + } + + var listData []map[string]interface{} + + smt, err1 := db.LinkCommon.Prepare(sqlStr) + if err1 != nil { + return listData, err1 + } + + rows, err2 := smt.Query(args...) + if err2 != nil { + return listData, err2 + } + + fieldsTypes, _ := rows.ColumnTypes() + fields, _ := rows.Columns() + + for rows.Next() { + + data := make(map[string]interface{}) + + scans := make([]interface{}, len(fields)) + for i := range scans { + scans[i] = &scans[i] + } + err := rows.Scan(scans...) + if err != nil { + return nil, err + } + + for i, v := range scans { + if v == nil { + data[fields[i]] = v + } else { + if fieldsTypes[i].DatabaseTypeName() == "VARCHAR" || fieldsTypes[i].DatabaseTypeName() == "TEXT" { + data[fields[i]] = fmt.Sprintf("%s", v) + } else if fieldsTypes[i].DatabaseTypeName() == "INT" || fieldsTypes[i].DatabaseTypeName() == "BIGINT" { + data[fields[i]] = fmt.Sprintf("%v", v) + } else { + data[fields[i]] = v + } + } + } + + listData = append(listData, data) + } + + db.clear() + return listData, nil +} + +// Exec 通用执行-新增,更新,删除 +func (db *Executor) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { + if db.IsDebug { + fmt.Println(sqlStr) + fmt.Println(args) + } + + smt, err1 := db.LinkCommon.Prepare(sqlStr) + if err1 != nil { + return nil, err1 + } + + res, err2 := smt.Exec(args...) + if err2 != nil { + return nil, err2 + } + + db.clear() + return res, nil +} + +// Debug 链式操作-是否开启调试,打印sql +func (db *Executor) Debug(isDebug bool) *Executor { + db.IsDebug = isDebug + return db +} + +// Field 链式操作-查询哪些字段,默认 * +func (db *Executor) Field(f string) *Executor { + db.FiledList = append(db.FiledList, f) + return db +} + +// Table 链式操作-从哪个表查询,允许直接写别名,例如 person p +func (db *Executor) Table(tableName string) *Executor { + db.TableName = tableName + return db +} + +// LeftJoin 链式操作,左联查询,例如 LeftJoin("project p", "p.project_id=o.project_id") +func (db *Executor) LeftJoin(tableName string, condition string) *Executor { + db.JoinList = append(db.JoinList, "LEFT JOIN "+tableName+" ON "+condition) + return db +} + +// RightJoin 链式操作,右联查询,例如 RightJoin("project p", "p.project_id=o.project_id") +func (db *Executor) RightJoin(tableName string, condition string) *Executor { + db.JoinList = append(db.JoinList, "RIGHT JOIN "+tableName+" ON "+condition) + return db +} + +// Join 链式操作,内联查询,例如 Join("project p", "p.project_id=o.project_id") +func (db *Executor) Join(tableName string, condition string) *Executor { + db.JoinList = append(db.JoinList, "INNER JOIN "+tableName+" ON "+condition) + return db +} + +// WhereArr 链式操作,以数组作为查询条件 +func (db *Executor) WhereArr(whereList []WhereItem) *Executor { + db.WhereList = append(db.WhereList, whereList...) + return db +} + +// Where 链式操作,以对象作为查询条件 +func (db *Executor) Where(dest interface{}) *Executor { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if db.TableName == "" { + arr := strings.Split(typeOf.String(), ".") + db.TableName = UnderLine(arr[len(arr)-1]) + } + + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + db.WhereList = append(db.WhereList, WhereItem{Field: key, Opt: Eq, Val: val}) + } + } + + return db +} + +// Group 链式操作,以某字段进行分组 +func (db *Executor) Group(f string) *Executor { + db.GroupList = append(db.GroupList, f) + return db +} + +// Having 链式操作,以数组作为筛选条件 +func (db *Executor) Having(havingList []WhereItem) *Executor { + db.HavingList = havingList + return db +} + +// Order 链式操作,以某字段进行排序 +func (db *Executor) Order(field string, orderType string) *Executor { + db.OrderList = append(db.OrderList, field+" "+orderType) + return db +} + +// Limit 链式操作,分页 +func (db *Executor) Limit(offset int, pageSize int) *Executor { + db.Offset = offset + db.PageSize = pageSize + return db +} + +// Page 链式操作,分页 +func (db *Executor) Page(pageNum int, pageSize int) *Executor { + db.Offset = (pageNum - 1) * pageSize + db.PageSize = pageSize + return db +} + +//拼接SQL,字段相关 +func handleField(filedList []string) string { + if len(filedList) == 0 { + return "*" + } + + return strings.Join(filedList, ",") +} + +//拼接SQL,查询条件 +func handleWhere(where []WhereItem, paramList []any) (string, []any) { + if len(where) == 0 { + return "", paramList + } + + whereList, paramList := whereAndHaving(where, paramList) + + return " WHERE " + strings.Join(whereList, " AND "), paramList +} + +//拼接SQL,更新信息 +func (db *Executor) handleSet(dest interface{}, paramList []any) (string, []any) { + typeOf := reflect.TypeOf(dest) + valueOf := reflect.ValueOf(dest) + + //如果没有设置表名 + if db.TableName == "" { + arr := strings.Split(typeOf.String(), ".") + db.TableName = UnderLine(arr[len(arr)-1]) + } + + var keys []string + for i := 0; i < typeOf.Elem().NumField(); i++ { + isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() + if isNotNull { + key := UnderLine(typeOf.Elem().Field(i).Name) + val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() + + keys = append(keys, key+"=?") + paramList = append(paramList, val) + } + } + + return " SET " + strings.Join(keys, ","), paramList +} + +//拼接SQL,关联查询 +func handleJoin(joinList []string) string { + if len(joinList) == 0 { + return "" + } + + return " " + strings.Join(joinList, " ") +} + +//拼接SQL,结果分组 +func handleGroup(groupList []string) string { + if len(groupList) == 0 { + return "" + } + + return " GROUP BY " + strings.Join(groupList, ",") +} + +//拼接SQL,结果筛选 +func handleHaving(having []WhereItem, paramList []any) (string, []any) { + if len(having) == 0 { + return "", paramList + } + + whereList, paramList := whereAndHaving(having, paramList) + + return " Having " + strings.Join(whereList, " AND "), paramList +} + +//拼接SQL,结果排序 +func handleOrder(orderList []string) string { + if len(orderList) == 0 { + return "" + } + + return " Order BY " + strings.Join(orderList, ",") +} + +//拼接SQL,分页相关 +func handleLimit(offset int, pageSize int, paramList []any) (string, []any) { + if 0 == pageSize { + return "", paramList + } + + paramList = append(paramList, offset) + paramList = append(paramList, pageSize) + + return " Limit ?,? ", paramList +} + +//拼接SQL,查询与筛选通用操作 +func whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) { + var whereList []string + for i := 0; i < len(where); i++ { + if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") + paramList = append(paramList, toStr(where[i].Val)) + } + + if where[i].Opt == Between || where[i].Opt == NotBetween { + values := toAnyArr(where[i].Val) + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"(?) AND (?)") + paramList = append(paramList, values...) + } + + if where[i].Opt == Like || where[i].Opt == NotLike { + values := toAnyArr(where[i].Val) + var valueStr []string + for j := 0; j < len(values); j++ { + str := fmt.Sprintf("%v", values[j]) + valueStr = append(valueStr, str) + + if "%" != str { + paramList = append(paramList, str) + values[j] = "?" + } + } + + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+strings.Join(valueStr, "")) + } + + if where[i].Opt == In || where[i].Opt == NotIn { + values := toAnyArr(where[i].Val) + var placeholder []string + for j := 0; j < len(values); j++ { + placeholder = append(placeholder, "?") + } + + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") + paramList = append(paramList, values...) + } + } + + return whereList, paramList +} + +//将一个interface抽取成数组 +func toAnyArr(val any) []any { + var values []any + switch val.(type) { + case []int: + for _, value := range val.([]int) { + values = append(values, strconv.Itoa(value)) + } + case []string: + values = val.([]any) + } + + return values +} + +//将一个interface抽取成字符串 +func toStr(val any) string { + switch val.(type) { + case int: + return strconv.Itoa(val.(int)) + case int64: + return strconv.FormatInt(val.(int64), 10) + case string: + return val.(string) + } + return "" +} + +func str2Int(str string) int { + num, _ := strconv.Atoi(str) + return num +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..27539e4 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/tangpanqing/aorm + +go 1.18 + +require gopkg.in/guregu/null.v4 v4.0.0 // indirect diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..bc67955 --- /dev/null +++ b/migrate.go @@ -0,0 +1,527 @@ +package aorm + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type Table struct { + TableName string + Engine string + Comment string +} + +type Column struct { + ColumnName string + ColumnDefault string + IsNullable string + DataType string //数据类型 varchar,bigint,int + MaxLength int //数据最大长度 20 + ColumnType string //列类型 varchar(20) + ColumnComment string + Extra string //扩展信息 auto_increment + DefaultVal string //默认值 +} + +type Index struct { + NonUnique int + ColumnName string + KeyName string +} + +type OpinionItem struct { + Key string + Val string +} + +func (db *Executor) Opinion(key string, val string) *Executor { + if key == "COMMENT" { + val = "'" + val + "'" + } + + db.OpinionList = append(db.OpinionList, OpinionItem{Key: key, Val: val}) + + return db +} + +func (db *Executor) ShowCreateTable(tableName string) string { + list, _ := db.Query("show create table " + tableName) + return list[0]["Create Table"].(string) +} + +// Migrate 迁移数据库结构,需要输入数据库名,表名自动获取 +func (db *Executor) AutoMigrate(dest interface{}) { + typeOf := reflect.TypeOf(dest) + arr := strings.Split(typeOf.String(), ".") + tableName := UnderLine(arr[len(arr)-1]) + + db.migrateCommon(tableName, typeOf) +} + +// AutoMigrate 自动迁移数据库结构,需要输入数据库名,表名 +func (db *Executor) Migrate(tableName string, dest interface{}) { + typeOf := reflect.TypeOf(dest) + db.migrateCommon(tableName, typeOf) +} + +func (db *Executor) migrateCommon(tableName string, typeOf reflect.Type) { + tableFromCode := db.getTableFromCode(tableName) + columnsFromCode := db.getColumnsFromCode(typeOf) + indexsFromCode := db.getIndexsFromCode(typeOf, tableFromCode) + + //获取数据库名称 + dbNameRows, _ := db.Query("SELECT DATABASE()") + dbName := dbNameRows[0]["DATABASE()"].(string) + + //查询表信息,如果找不到就新建 + sql := "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" + dataList, _ := db.Query(sql) + if len(dataList) != 0 { + tableFromDb := getTableFromDb(dataList) + columnsFromDb := db.getColumnsFromDb(dbName, tableName) + indexsFromDb := db.getIndexsFromDb(tableName) + + db.modifyTable(tableFromCode, columnsFromCode, indexsFromCode, tableFromDb, columnsFromDb, indexsFromDb) + } else { + db.createTable(tableFromCode, columnsFromCode, indexsFromCode) + } +} + +func (db *Executor) getTableFromCode(tableName string) Table { + var tableFromCode Table + tableFromCode.TableName = tableName + tableFromCode.Engine = db.getValFromOpinion("ENGINE", "MyISAM") + tableFromCode.Comment = db.getValFromOpinion("COMMENT", "") + + return tableFromCode +} + +func (db *Executor) getColumnsFromCode(typeOf reflect.Type) []Column { + var columnsFromCode []Column + for i := 0; i < typeOf.Elem().NumField(); i++ { + fieldName := UnderLine(typeOf.Elem().Field(i).Name) + fieldType := typeOf.Elem().Field(i).Type.Name() + fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) + columnsFromCode = append(columnsFromCode, getColumnFromCode(fieldName, fieldType, fieldMap)) + } + + return columnsFromCode +} + +func (db *Executor) getIndexsFromCode(typeOf reflect.Type, tableFromCode Table) []Index { + var indexsFromCode []Index + for i := 0; i < typeOf.Elem().NumField(); i++ { + fieldName := UnderLine(typeOf.Elem().Field(i).Name) + fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) + + _, primaryIs := fieldMap["primary"] + if primaryIs { + indexsFromCode = append(indexsFromCode, Index{ + NonUnique: 0, + ColumnName: fieldName, + KeyName: "PRIMARY", + }) + } + + _, uniqueIndexIs := fieldMap["unique"] + if uniqueIndexIs { + indexsFromCode = append(indexsFromCode, Index{ + NonUnique: 0, + ColumnName: fieldName, + KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName, + }) + } + + _, indexIs := fieldMap["index"] + if indexIs { + indexsFromCode = append(indexsFromCode, Index{ + NonUnique: 1, + ColumnName: fieldName, + KeyName: "idx_" + tableFromCode.TableName + "_" + fieldName, + }) + } + } + + return indexsFromCode +} + +func getTableFromDb(dataList []map[string]interface{}) Table { + var tableFromDb Table + tableFromDb.TableName = fmt.Sprintf("%v", dataList[0]["TABLE_NAME"]) + tableFromDb.Engine = fmt.Sprintf("%v", dataList[0]["ENGINE"]) + tableFromDb.Comment = "'" + fmt.Sprintf("%v", dataList[0]["TABLE_COMMENT"]) + "'" + + return tableFromDb +} + +func (db *Executor) getColumnsFromDb(dbName string, tableName string) []Column { + var columnsFromDb []Column + + sqlColumn := "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" + dataColumn, _ := db.Query(sqlColumn) + + for j := 0; j < len(dataColumn); j++ { + maxLength, _ := strconv.Atoi(fmt.Sprintf("%v", dataColumn[j]["CHARACTER_MAXIMUM_LENGTH"])) + defaultVal := "" + if dataColumn[j]["COLUMN_DEFAULT"] != nil { + defaultVal = dataColumn[j]["COLUMN_DEFAULT"].(string) + } + + columnsFromDb = append(columnsFromDb, Column{ + ColumnName: dataColumn[j]["COLUMN_NAME"].(string), + DataType: dataColumn[j]["DATA_TYPE"].(string), + IsNullable: dataColumn[j]["IS_NULLABLE"].(string), + MaxLength: maxLength, + ColumnType: dataColumn[j]["COLUMN_TYPE"].(string), + ColumnComment: dataColumn[j]["COLUMN_COMMENT"].(string), + Extra: dataColumn[j]["EXTRA"].(string), + DefaultVal: defaultVal, + }) + } + + return columnsFromDb +} + +func (db *Executor) getIndexsFromDb(tableName string) []Index { + sqlIndex := "SHOW INDEXES FROM " + tableName + dataIndex, _ := db.Query(sqlIndex) + + var indexsFromDb []Index + for j := 0; j < len(dataIndex); j++ { + nonUnique, _ := strconv.Atoi(fmt.Sprintf("%v", dataIndex[j]["Non_unique"])) + indexsFromDb = append(indexsFromDb, Index{ + ColumnName: fmt.Sprintf("%v", dataIndex[j]["Column_name"]), + KeyName: fmt.Sprintf("%v", dataIndex[j]["Key_name"]), + NonUnique: nonUnique, + }) + } + + return indexsFromDb +} + +// 修改表 +func (db *Executor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexsFromDb []Index) { + //fmt.Println("正在修改表" + tableFromCode.TableName) + //fmt.Println(columnsFromCode) + //fmt.Println(columnsFromDb) + + //fmt.Println(indexsFromCode) + //fmt.Println(indexsFromDb) + + if tableFromCode.Engine != tableFromDb.Engine { + sql := "ALTER TABLE " + tableFromCode.TableName + " Engine " + tableFromCode.Engine + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("修改表:" + sql) + } + } + + if tableFromCode.Comment != tableFromDb.Comment { + sql := "ALTER TABLE " + tableFromCode.TableName + " Comment " + tableFromCode.Comment + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("修改表:" + sql) + } + } + + for i := 0; i < len(columnsFromCode); i++ { + isFind := 0 + columnCode := columnsFromCode[i] + + for j := 0; j < len(columnsFromDb); j++ { + columnDb := columnsFromDb[j] + if columnCode.ColumnName == columnDb.ColumnName { + isFind = 1 + + if columnCode.ColumnType != columnDb.ColumnType || columnCode.ColumnComment != columnDb.ColumnComment || columnCode.Extra != columnDb.Extra || columnCode.DefaultVal != columnDb.DefaultVal { + sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getColumnStr(columnCode) + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("修改属性:" + sql) + } + } + } + } + + if isFind == 0 { + sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getColumnStr(columnCode) + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("增加属性:" + sql) + } + } + } + + for i := 0; i < len(indexsFromCode); i++ { + isFind := 0 + indexCode := indexsFromCode[i] + + for j := 0; j < len(indexsFromDb); j++ { + indexDb := indexsFromDb[j] + if indexCode.ColumnName == indexDb.ColumnName { + isFind = 1 + if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { + sql := "ALTER TABLE " + tableFromCode.TableName + " MODIFY " + getIndexStr(indexCode) + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("修改索引:" + sql) + } + } + } + } + + if isFind == 0 { + sql := "ALTER TABLE " + tableFromCode.TableName + " ADD " + getIndexStr(indexCode) + _, err := db.Exec(sql) + if err != nil { + fmt.Println(err) + } else { + fmt.Println("增加索引:" + sql) + } + } + } +} + +// 创建表 +func (db *Executor) createTable(tableFromCode Table, columnsFromCode []Column, indexsFromCode []Index) { + var fieldArr []string + + for i := 0; i < len(columnsFromCode); i++ { + column := columnsFromCode[i] + fieldArr = append(fieldArr, getColumnStr(column)) + } + + for i := 0; i < len(indexsFromCode); i++ { + index := indexsFromCode[i] + fieldArr = append(fieldArr, getIndexStr(index)) + } + + sqlStr := "CREATE TABLE `" + tableFromCode.TableName + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + getTableInfoFromCode(tableFromCode) + ";" + + res, err := db.Exec(sqlStr) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(res.RowsAffected()) + } +} + +// +func (db *Executor) getValFromOpinion(key string, def string) string { + for i := 0; i < len(db.OpinionList); i++ { + opinionItem := db.OpinionList[i] + if opinionItem.Key == key { + def = opinionItem.Val + } + } + return def +} + +func getTableInfoFromCode(tableFromCode Table) string { + return " ENGINE " + tableFromCode.Engine + " COMMENT " + tableFromCode.Comment +} + +// 获得某列的结构 +func getColumnFromCode(fieldName string, fieldType string, fieldMap map[string]string) Column { + var column Column + + //字段名 + column.ColumnName = fieldName + //字段数据类型 + column.DataType = getDataType(fieldType, fieldMap) + //字段数据长度 + maxLength := getMaxLength(column.DataType, fieldMap) + columnType := column.DataType + if maxLength > 0 { + columnType = columnType + "(" + strconv.Itoa(maxLength) + ")" + } + column.MaxLength = maxLength + //字段是否可以为空 + column.IsNullable = getNullAble(fieldMap) + //字段注释 + column.ColumnComment = getComment(fieldMap) + //字段类型 + column.ColumnType = columnType + //扩展信息 + column.Extra = getExtra(fieldMap) + //默认信息 + column.DefaultVal = getDefaultVal(fieldMap) + + return column +} + +// 转换tag成map +func getTagMap(fieldTag string) map[string]string { + var fieldMap = make(map[string]string) + if "" != fieldTag { + tagArr := strings.Split(fieldTag, ";") + for j := 0; j < len(tagArr); j++ { + tagArrArr := strings.Split(tagArr[j], ":") + fieldMap[tagArrArr[0]] = "" + if len(tagArrArr) > 1 { + fieldMap[tagArrArr[0]] = tagArrArr[1] + } + } + } + return fieldMap +} + +func getColumnStr(column Column) string { + var strArr []string + strArr = append(strArr, column.ColumnName) + strArr = append(strArr, column.DataType+"("+strconv.Itoa(column.MaxLength)+")") + + if column.DefaultVal != "" { + strArr = append(strArr, "DEFAULT '"+column.DefaultVal+"'") + } + + if column.IsNullable == "NO" { + strArr = append(strArr, "NOT NULL") + } + + if column.ColumnComment != "" { + strArr = append(strArr, "COMMENT '"+column.ColumnComment+"'") + } + + if column.Extra != "" { + strArr = append(strArr, column.Extra) + } + + return strings.Join(strArr, " ") +} + +func getIndexStr(index Index) string { + var strArr []string + + if "PRIMARY" == index.KeyName { + strArr = append(strArr, index.KeyName) + strArr = append(strArr, "KEY") + strArr = append(strArr, "(`"+index.ColumnName+"`)") + } else { + if 0 == index.NonUnique { + strArr = append(strArr, "Unique") + strArr = append(strArr, index.KeyName) + strArr = append(strArr, "(`"+index.ColumnName+"`)") + } else { + strArr = append(strArr, "Index") + strArr = append(strArr, index.KeyName) + strArr = append(strArr, "(`"+index.ColumnName+"`)") + } + } + + return strings.Join(strArr, " ") +} + +//将对象属性类型转换数据库字段数据类型 +func getDataType(fieldType string, fieldMap map[string]string) string { + var DataType string + + dataTypeVal, dataTypeOk := fieldMap["type"] + if dataTypeOk { + DataType = dataTypeVal + } else { + if "Int" == fieldType { + DataType = "int" + } + if "String" == fieldType { + DataType = "varchar" + } + if "Bool" == fieldType { + DataType = "tinyint" + } + if "Time" == fieldType { + DataType = "datetime" + } + if "Float" == fieldType { + DataType = "float" + } + } + + return DataType +} + +func getMaxLength(DataType string, fieldMap map[string]string) int { + var MaxLength int + + maxLengthVal, maxLengthOk := fieldMap["size"] + if maxLengthOk { + num, _ := strconv.Atoi(maxLengthVal) + MaxLength = num + } else { + if "bigint" == DataType { + MaxLength = 20 + } + if "int" == DataType { + MaxLength = 11 + } + if "tinyint" == DataType { + MaxLength = 4 + } + if "varchar" == DataType { + MaxLength = 255 + } + if "text" == DataType { + MaxLength = 0 + } + if "datetime" == DataType { + MaxLength = 0 + } + if "float" == DataType { + MaxLength = 0 + } + } + + return MaxLength +} + +func getNullAble(fieldMap map[string]string) string { + var IsNullable string + + _, ok := fieldMap["not null"] + if ok { + IsNullable = "NO" + } else { + IsNullable = "YES" + } + + return IsNullable +} + +func getComment(fieldMap map[string]string) string { + commentVal, commentIs := fieldMap["comment"] + if commentIs { + return commentVal + } + + return "" +} + +func getExtra(fieldMap map[string]string) string { + _, commentIs := fieldMap["auto_increment"] + if commentIs { + return "auto_increment" + } + + return "" +} + +func getDefaultVal(fieldMap map[string]string) string { + defaultVal, defaultIs := fieldMap["default"] + if defaultIs { + return defaultVal + } + + return "" +} diff --git a/str.go b/str.go new file mode 100644 index 0000000..e3b06a8 --- /dev/null +++ b/str.go @@ -0,0 +1,51 @@ +package aorm + +import ( + "unicode" +) + +// CamelString 将某字符串转成驼峰写法 +func CamelString(s string) string { + data := make([]byte, 0, len(s)) + j := false + k := false + num := len(s) - 1 + for i := 0; i <= num; i++ { + d := s[i] + if k == false && d >= 'A' && d <= 'Z' { + k = true + } + if d >= 'a' && d <= 'z' && (j || k == false) { + d = d - 32 + j = false + k = true + } + if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { + j = true + continue + } + data = append(data, d) + } + return string(data[:]) +} + +// UnderLine 将某字符串转成下划线写法 +func UnderLine(s string) string { + var output []rune + for i, r := range s { + if i == 0 { + output = append(output, unicode.ToLower(r)) + continue + } + if unicode.IsUpper(r) { + output = append(output, '_') + } + output = append(output, unicode.ToLower(r)) + } + return string(output) +} + +// Ul 将某字符串转成下划线写法-简化用法 +func Ul(s string) string { + return UnderLine(s) +}