From aa50d1f68cd90dbcfd6af9c6bc31eebb9b030389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=B0=8F=E5=86=9B?= <346944475@qq.com> Date: Mon, 1 Jun 2020 20:06:15 +0800 Subject: [PATCH] add middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 支持调用中间件 --- README.md | 13 ++++ README_cn.md | 12 +++ common.go | 212 +++++++++++++++++++++++++++++++++++++------------- ginrpc.go | 18 +++-- go.mod | 2 +- go.sum | 2 + middleware.go | 64 +++++++++++++++ 7 files changed, 263 insertions(+), 60 deletions(-) create mode 100644 middleware.go diff --git a/README.md b/README.md index a3a5403..a09de35 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,8 @@ func main() { ginrpc.WithBigCamel(true) : Set big camel standard (false is web mode, _, lowercase) + ginrpc.WithBeforeAfter(&ginrpc.DefaultGinBeforeAfter{}) : Before After call + [more>>](https://godoc.org/github.com/xxjwxc/ginrpc) ### 4. Execute curl to automatically bind parameters. See the results directly @@ -189,6 +191,17 @@ type ReqTest struct { ``` - [more >>>](https://github.com/xxjwxc/gmsec) +## 三. Support to call Middleware +- using `ginrpc.WithBeforeAfter(&ginrpc.DefaultGinBeforeAfter{})` +- You can also implement functions (single types) on objects +```go + // GinBeforeAfter Execute middleware before and after the object call (support adding the object separately from the object in total) + type GinBeforeAfter interface { + GinBefore(req *GinBeforeAfterInfo) bool + GinAfter(req *GinBeforeAfterInfo) bool + } +``` + ## Stargazers over time [![Stargazers over time](https://starchart.cc/xxjwxc/ginrpc.svg)](https://starchart.cc/xxjwxc/ginrpc) diff --git a/README_cn.md b/README_cn.md index 901c2df..38ce273 100644 --- a/README_cn.md +++ b/README_cn.md @@ -158,6 +158,8 @@ _ "[mod]/routers" // debug模式需要添加[mod]/routers 注册注解路由 ginrpc.WithDebug(true) : 设置debug模式 ginrpc.WithBigCamel(true) : 设置大驼峰标准(false 为web模式,_,小写) + + ginrpc.WithBeforeAfter(&ginrpc.DefaultGinBeforeAfter{}) : 设置调用前后执行中间件 [更多>>](https://godoc.org/github.com/xxjwxc/ginrpc) @@ -181,6 +183,16 @@ type ReqTest struct { - [更多 >>>](https://github.com/xxjwxc/gmsec) +## 三. 支持调用中间件 +- 可通过 `ginrpc.WithBeforeAfter(&ginrpc.DefaultGinBeforeAfter{})`设置(全局) +- 也可以在对象上实现函数(单个类型) +```go + // GinBeforeAfter 对象调用前后执行中间件(支持总的跟对象单独添加) + type GinBeforeAfter interface { + GinBefore(req *GinBeforeAfterInfo) bool + GinAfter(req *GinBeforeAfterInfo) bool + } +``` ## Stargazers over time diff --git a/common.go b/common.go index 35f441b..d4e2170 100644 --- a/common.go +++ b/common.go @@ -1,6 +1,7 @@ package ginrpc import ( + "context" "encoding/json" "fmt" "go/ast" @@ -54,7 +55,7 @@ func (b *_Base) checkHandlerFunc(typ reflect.Type, isObj bool) (int, bool) { // } // HandlerFunc Get and filter the parameters to be bound (object call type) -func (b *_Base) handlerFuncObj(tvl, obj reflect.Value) gin.HandlerFunc { // 获取并过滤要绑定的参数(obj 对象类型) +func (b *_Base) handlerFuncObj(tvl, obj reflect.Value, methodName string) gin.HandlerFunc { // 获取并过滤要绑定的参数(obj 对象类型) typ := tvl.Type() if typ.NumIn() == 2 { // Parameter checking 参数检查 ctxType := typ.In(1) @@ -73,7 +74,7 @@ func (b *_Base) handlerFuncObj(tvl, obj reflect.Value) gin.HandlerFunc { // 获 } // Custom context type with request parameters .自定义的context类型,带request 请求参数 - call, err := b.getCallFunc3(tvl, obj) + call, err := b.getCallObj3(tvl, obj, methodName) if err != nil { // Direct reporting error. panic(err) } @@ -81,16 +82,39 @@ func (b *_Base) handlerFuncObj(tvl, obj reflect.Value) gin.HandlerFunc { // 获 return call } -// Custom context type with request parameters -func (b *_Base) getCallFunc3(tvls ...reflect.Value) (func(*gin.Context), error) { - offset := 0 - if len(tvls) > 1 { - offset = 1 +func (b *_Base) beforCall(c *gin.Context, tvl, obj reflect.Value, req interface{}, methodName string) (*GinBeforeAfterInfo, bool) { + info := &GinBeforeAfterInfo{ + C: c, + FuncName: fmt.Sprintf("%v.%v", reflect.Indirect(obj).Type().Name(), methodName), // 函数名 + Req: req, // 调用前的请求参数 + Context: context.Background(), // 占位参数,可用于存储其他参数,前后连接可用 } - tvl := tvls[0] + is := true + if bfobj, ok := obj.Interface().(GinBeforeAfter); ok { // 本类型 + is = bfobj.GinBefore(info) + } + if is && b.beforeAfter != nil { + is = b.beforeAfter.GinBefore(info) + } + return info, is +} + +func (b *_Base) afterCall(info *GinBeforeAfterInfo, obj reflect.Value) bool { + is := true + if bfobj, ok := obj.Interface().(GinBeforeAfter); ok { // 本类型 + is = bfobj.GinAfter(info) + } + if is && b.beforeAfter != nil { + is = b.beforeAfter.GinAfter(info) + } + return is +} + +// Custom context type with request parameters +func (b *_Base) getCallFunc3(tvl reflect.Value) (func(*gin.Context), error) { typ := tvl.Type() - if typ.NumIn() != (2 + offset) { // Parameter checking 参数检查 + if typ.NumIn() != 2 { // Parameter checking 参数检查 return nil, errors.New("method " + runtime.FuncForPC(tvl.Pointer()).Name() + " not support!") } @@ -105,7 +129,7 @@ func (b *_Base) getCallFunc3(tvls ...reflect.Value) (func(*gin.Context), error) } } - ctxType, reqType := typ.In(0+offset), typ.In(1+offset) + ctxType, reqType := typ.In(0), typ.In(1) reqIsGinCtx := false if ctxType == reflect.TypeOf(&gin.Context{}) { @@ -129,44 +153,11 @@ func (b *_Base) getCallFunc3(tvls ...reflect.Value) (func(*gin.Context), error) return func(c *gin.Context) { req := reflect.New(reqType) - if reqIsValue { - req = reflect.New(reqType) - } else { + if !reqIsValue { req = reflect.New(reqType.Elem()) } if err := b.unmarshal(c, req.Interface()); err != nil { // Return error message.返回错误信息 - var fields []string - if _, ok := err.(validator.ValidationErrors); ok { - for _, err := range err.(validator.ValidationErrors) { - tmp := fmt.Sprintf("%v:%v", myreflect.FindTag(req.Interface(), err.Field(), "json"), err.Tag()) - if len(err.Param()) > 0 { - tmp += fmt.Sprintf("[%v](but[%v])", err.Param(), err.Value()) - } - fields = append(fields, tmp) - // fmt.Println(err.Namespace()) - // fmt.Println(err.Field()) - // fmt.Println(err.StructNamespace()) // can differ when a custom TagNameFunc is registered or - // fmt.Println(err.StructField()) // by passing alt name to ReportError like below - // fmt.Println(err.Tag()) - // fmt.Println(err.ActualTag()) - // fmt.Println(err.Kind()) - // fmt.Println(err.Type()) - // fmt.Println(err.Value()) - // fmt.Println(err.Param()) - // fmt.Println() - } - } else if _, ok := err.(*json.UnmarshalTypeError); ok { - err := err.(*json.UnmarshalTypeError) - tmp := fmt.Sprintf("%v:%v(but[%v])", err.Field, err.Type.String(), err.Value) - fields = append(fields, tmp) - - } else { - fields = append(fields, err.Error()) - } - - msg := message.GetErrorMsg(message.ParameterInvalid) - msg.Error = fmt.Sprintf("req param : %v", strings.Join(fields, ";")) - c.JSON(http.StatusBadRequest, msg) + b.handErrorString(c, req, err) return } @@ -174,11 +165,8 @@ func (b *_Base) getCallFunc3(tvls ...reflect.Value) (func(*gin.Context), error) req = req.Elem() } var returnValues []reflect.Value - if offset > 0 { - returnValues = tvl.Call([]reflect.Value{tvls[1], reflect.ValueOf(apiFun(c)), req}) - } else { - returnValues = tvl.Call([]reflect.Value{reflect.ValueOf(apiFun(c)), req}) - } + returnValues = tvl.Call([]reflect.Value{reflect.ValueOf(apiFun(c)), req}) + if returnValues != nil { obj := returnValues[0].Interface() rerr := returnValues[1].Interface() @@ -193,6 +181,122 @@ func (b *_Base) getCallFunc3(tvls ...reflect.Value) (func(*gin.Context), error) }, nil } +// Custom context type with request parameters +func (b *_Base) getCallObj3(tvl, obj reflect.Value, methodName string) (func(*gin.Context), error) { + typ := tvl.Type() + if typ.NumIn() != 3 { // Parameter checking 参数检查 + return nil, errors.New("method " + runtime.FuncForPC(tvl.Pointer()).Name() + " not support!") + } + + if typ.NumOut() != 0 { + if typ.NumOut() == 2 { // Parameter checking 参数检查 + if returnType := typ.Out(1); returnType != typeOfError { + return nil, errors.Errorf("method : %v , returns[1] %v not error", + runtime.FuncForPC(tvl.Pointer()).Name(), returnType.String()) + } + } else { + return nil, errors.Errorf("method : %v , Only 2 return values (obj, error) are supported", runtime.FuncForPC(tvl.Pointer()).Name()) + } + } + + ctxType, reqType := typ.In(1), typ.In(2) + + reqIsGinCtx := false + if ctxType == reflect.TypeOf(&gin.Context{}) { + reqIsGinCtx = true + } + + // ctxType != reflect.TypeOf(gin.Context{}) && + // ctxType != reflect.Indirect(reflect.ValueOf(b.iAPIType)).Type() + if !reqIsGinCtx && ctxType != b.apiType && !b.apiType.ConvertibleTo(ctxType) { + return nil, errors.New("method " + runtime.FuncForPC(tvl.Pointer()).Name() + " first parm not support!") + } + + reqIsValue := true + if reqType.Kind() == reflect.Ptr { + reqIsValue = false + } + apiFun := func(c *gin.Context) interface{} { return c } + if !reqIsGinCtx { + apiFun = b.apiFun + } + + return func(c *gin.Context) { + req := reflect.New(reqType) + if !reqIsValue { + req = reflect.New(reqType.Elem()) + } + if err := b.unmarshal(c, req.Interface()); err != nil { // Return error message.返回错误信息 + b.handErrorString(c, req, err) + return + } + + if reqIsValue { + req = req.Elem() + } + + bainfo, is := b.beforCall(c, tvl, obj, req.Interface(), methodName) + if !is { + c.JSON(http.StatusBadRequest, bainfo.Resp) + return + } + + var returnValues []reflect.Value + returnValues = tvl.Call([]reflect.Value{obj, reflect.ValueOf(apiFun(c)), req}) + + if returnValues != nil { + bainfo.Resp = returnValues[0].Interface() + rerr := returnValues[1].Interface() + if rerr != nil { + bainfo.Error = rerr.(error) + } + + is = b.afterCall(bainfo, obj) + if is { + c.JSON(http.StatusOK, bainfo.Resp) + } else { + c.JSON(http.StatusBadRequest, bainfo.Resp) + } + } + }, nil +} + +func (b *_Base) handErrorString(c *gin.Context, req reflect.Value, err error) { + var fields []string + if _, ok := err.(validator.ValidationErrors); ok { + for _, err := range err.(validator.ValidationErrors) { + tmp := fmt.Sprintf("%v:%v", myreflect.FindTag(req.Interface(), err.Field(), "json"), err.Tag()) + if len(err.Param()) > 0 { + tmp += fmt.Sprintf("[%v](but[%v])", err.Param(), err.Value()) + } + fields = append(fields, tmp) + // fmt.Println(err.Namespace()) + // fmt.Println(err.Field()) + // fmt.Println(err.StructNamespace()) // can differ when a custom TagNameFunc is registered or + // fmt.Println(err.StructField()) // by passing alt name to ReportError like below + // fmt.Println(err.Tag()) + // fmt.Println(err.ActualTag()) + // fmt.Println(err.Kind()) + // fmt.Println(err.Type()) + // fmt.Println(err.Value()) + // fmt.Println(err.Param()) + // fmt.Println() + } + } else if _, ok := err.(*json.UnmarshalTypeError); ok { + err := err.(*json.UnmarshalTypeError) + tmp := fmt.Sprintf("%v:%v(but[%v])", err.Field, err.Type.String(), err.Value) + fields = append(fields, tmp) + + } else { + fields = append(fields, err.Error()) + } + + msg := message.GetErrorMsg(message.ParameterInvalid) + msg.Error = fmt.Sprintf("req param : %v", strings.Join(fields, ";")) + c.JSON(http.StatusBadRequest, msg) + return +} + func (b *_Base) unmarshal(c *gin.Context, v interface{}) error { return c.ShouldBind(v) } @@ -400,11 +504,11 @@ func (b *_Base) register(router gin.IRouter, cList ...interface{}) bool { if _b { if v, ok := mp[objName+"."+method.Name]; ok { for _, v1 := range v { - b.registerHandlerObj(router, v1.GenComment.Methods, v1.GenComment.RouterPath, method.Func, refVal) + b.registerHandlerObj(router, v1.GenComment.Methods, v1.GenComment.RouterPath, method.Name, method.Func, refVal) } } else { // not find using default case routerPath, methods := b.getDefaultComments(objName, method.Name, num) - b.registerHandlerObj(router, methods, routerPath, method.Func, refVal) + b.registerHandlerObj(router, methods, routerPath, method.Name, method.Func, refVal) } } } @@ -428,8 +532,8 @@ func (b *_Base) getDefaultComments(objName, objFunc string, num int) (routerPath } // registerHandlerObj Multiple registration methods.获取并过滤要绑定的参数 -func (b *_Base) registerHandlerObj(router gin.IRouter, httpMethod []string, relativePath string, tvl, obj reflect.Value) error { - call := b.handlerFuncObj(tvl, obj) +func (b *_Base) registerHandlerObj(router gin.IRouter, httpMethod []string, relativePath, methodName string, tvl, obj reflect.Value) error { + call := b.handlerFuncObj(tvl, obj, methodName) for _, v := range httpMethod { // method := strings.ToUpper(v) diff --git a/ginrpc.go b/ginrpc.go index 72639c9..e9fdfab 100644 --- a/ginrpc.go +++ b/ginrpc.go @@ -12,11 +12,12 @@ import ( // _Base base struct type _Base struct { - isBigCamel bool // big camel style.大驼峰命名规则 - isDev bool // if is development - apiFun NewAPIFunc - apiType reflect.Type - outPath string // output path.输出目录 + isBigCamel bool // big camel style.大驼峰命名规则 + isDev bool // if is development + apiFun NewAPIFunc + apiType reflect.Type + outPath string // output path.输出目录 + beforeAfter GinBeforeAfter } // Option overrides behavior of Connect. @@ -61,6 +62,13 @@ func WithBigCamel(b bool) Option { }) } +// WithBeforeAfter set before and after call.设置对象调用前后执行中间件 +func WithBeforeAfter(beforeAfter GinBeforeAfter) Option { + return optionFunc(func(o *_Base) { + o.beforeAfter = beforeAfter + }) +} + // Default new op obj func Default() *_Base { b := new(_Base) diff --git a/go.mod b/go.mod index 42a6c13..42b3c49 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/gin-gonic/gin v1.6.3 github.com/go-playground/validator/v10 v10.2.0 - github.com/xxjwxc/public v0.0.0-20200526160023-d8d1bd6babeb + github.com/xxjwxc/public v0.0.0-20200601115915-ab2b4ce31a9c ) // replace github.com/xxjwxc/public => ../public diff --git a/go.sum b/go.sum index cf493a7..bd62256 100644 --- a/go.sum +++ b/go.sum @@ -168,6 +168,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xxjwxc/public v0.0.0-20200526160023-d8d1bd6babeb h1:tmQZkQGPqTyZ3dLg3ZMhSfXrio69EUpM6i6tJ2yH75w= github.com/xxjwxc/public v0.0.0-20200526160023-d8d1bd6babeb/go.mod h1:s1lcFEJl/8sQNC5jYCmmHqwWH/uf2EtMmYoji33ZKQQ= +github.com/xxjwxc/public v0.0.0-20200601115915-ab2b4ce31a9c h1:FC1aGStnQOk+H7nztD6LAh2ZY8SODeO0i1Agvrkw1tc= +github.com/xxjwxc/public v0.0.0-20200601115915-ab2b4ce31a9c/go.mod h1:s1lcFEJl/8sQNC5jYCmmHqwWH/uf2EtMmYoji33ZKQQ= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..9978f24 --- /dev/null +++ b/middleware.go @@ -0,0 +1,64 @@ +package ginrpc + +import ( + "context" + "fmt" + "time" + + "github.com/xxjwxc/public/message" + + "github.com/xxjwxc/public/mylog" + + "github.com/gin-gonic/gin" +) + +// GinBeforeAfterInfo 对象调用前后执行中间件参数 +type GinBeforeAfterInfo struct { + C *gin.Context + FuncName string // 函数名 + Req interface{} // 调用前的请求参数 + Resp interface{} // 调用后的返回参数 + Error error + // Other options for implementations of the interface + // can be stored in a context + Context context.Context // 占位参数,可用于存储其他参数,前后连接可用 + +} + +// GinBeforeAfter 对象调用前后执行中间件(支持总的跟对象单独添加) +type GinBeforeAfter interface { + GinBefore(req *GinBeforeAfterInfo) bool + GinAfter(req *GinBeforeAfterInfo) bool +} + +// DefaultGinBeforeAfter 创建一个默认 BeforeAfter Middleware +type DefaultGinBeforeAfter struct { +} + +type timeTrace struct{} + +// GinBefore call之前调用 +func (d *DefaultGinBeforeAfter) GinBefore(req *GinBeforeAfterInfo) bool { + req.Context = context.WithValue(req.Context, timeTrace{}, time.Now()) + return true +} + +// GinAfter call之后调用 +func (d *DefaultGinBeforeAfter) GinAfter(req *GinBeforeAfterInfo) bool { + begin := (req.Context.Value(timeTrace{})).(time.Time) + now := time.Now() + mylog.Info(fmt.Sprintf("[middleware] call[%v] [%v]", req.FuncName, now.Sub(begin))) + // 设置resp 结果 + if req.Error == nil { + msg := message.GetSuccessMsg() + msg.Data = req.Resp + req.Resp = msg + } else { + msg := message.GetErrorStrMsg(req.Error.Error()) + msg.Data = req.Resp + req.Resp = msg + } + return true +} + +// ----------------end