From 438f299d150ade613640bd7b0ba5d9d3f84ece28 Mon Sep 17 00:00:00 2001 From: tangpanqing Date: Fri, 16 Dec 2022 14:37:35 +0800 Subject: [PATCH] add method insert batch --- crud.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++ test/aorm_test.go | 35 ++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/crud.go b/crud.go index cf0f853..13ccfe3 100644 --- a/crud.go +++ b/crud.go @@ -2,6 +2,7 @@ package aorm import ( "database/sql" + "errors" "fmt" "reflect" "strings" @@ -77,6 +78,59 @@ func (db *Executor) Insert(dest interface{}) (int64, error) { return lastId, nil } +// InsertBatch 批量增加记录 +func (db *Executor) InsertBatch(values interface{}) (int64, error) { + + var keys []string + var paramList []any + var place []string + + valueOf := reflect.ValueOf(values).Elem() + if valueOf.Len() == 0 { + return 0, errors.New("the data list for insert batch not found") + } + typeOf := reflect.TypeOf(values).Elem().Elem() + + //如果没有设置表名 + if db.TableName == "" { + db.TableName = reflectTableName(typeOf, valueOf.Index(0)) + } + + for j := 0; j < valueOf.Len(); j++ { + var placeItem []string + + for i := 0; i < valueOf.Index(j).NumField(); i++ { + isNotNull := valueOf.Index(j).Field(i).Field(0).Field(1).Bool() + if isNotNull { + if j == 0 { + key := UnderLine(typeOf.Field(i).Name) + keys = append(keys, key) + } + + val := valueOf.Index(j).Field(i).Field(0).Field(0).Interface() + paramList = append(paramList, val) + placeItem = append(placeItem, "?") + } + } + + place = append(place, "("+strings.Join(placeItem, ",")+")") + } + + sqlStr := "INSERT INTO " + db.TableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") + + res, err := db.Exec(sqlStr, paramList...) + if err != nil { + return 0, err + } + + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return count, nil +} + // GetMany 查询记录(新) func (db *Executor) GetMany(values interface{}) error { sqlStr, paramList := db.getSqlAndParams() diff --git a/test/aorm_test.go b/test/aorm_test.go index d1d490b..9871835 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -49,6 +49,8 @@ func TestAll(t *testing.T) { testShowCreateTable(db) id := testInsert(db) + testInsertBatch(db) + testGetOne(db, id) testGetMany(db) testUpdate(db, id) @@ -150,6 +152,39 @@ func testInsert(db *sql.DB) int64 { return id } +func testInsertBatch(db *sql.DB) int64 { + fmt.Println("--- testInsertBatch ---") + + var batch []Person + batch = append(batch, Person{ + Name: aorm.StringFrom("Alice"), + Sex: aorm.BoolFrom(false), + Age: aorm.IntFrom(18), + Type: aorm.IntFrom(0), + CreateTime: aorm.TimeFrom(time.Now()), + Money: aorm.FloatFrom(100.15987654321), + Test: aorm.FloatFrom(200.15987654321987654321), + }) + + batch = append(batch, Person{ + Name: aorm.StringFrom("Bob"), + Sex: aorm.BoolFrom(true), + Age: aorm.IntFrom(18), + Type: aorm.IntFrom(0), + CreateTime: aorm.TimeFrom(time.Now()), + Money: aorm.FloatFrom(100.15987654321), + Test: aorm.FloatFrom(200.15987654321987654321), + }) + + count, errInsertBatch := aorm.Use(db).Debug(true).InsertBatch(&batch) + if errInsertBatch != nil { + fmt.Println(errInsertBatch) + } + fmt.Println(count) + + return count +} + func testGetOne(db *sql.DB, id int64) { fmt.Println("--- testGetOne ---")