diff --git a/server/plugin.go b/server/plugin.go index 070e57c..c86b336 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -26,6 +26,8 @@ type PluginContainer interface { DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error DoPreHandleRequest(ctx context.Context, req *protocol.Message) error + DoPreCall(ctx context.Context, args interface{}) (interface{}, error) + DoPostCall(ctx context.Context, args, reply interface{}) (interface{}, error) DoPreWriteResponse(context.Context, *protocol.Message, *protocol.Message) error DoPostWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error @@ -77,6 +79,14 @@ type ( PreHandleRequest(ctx context.Context, r *protocol.Message) error } + PreCallPlugin interface { + PreCall(ctx context.Context, args interface{}) (interface{}, error) + } + + PostCallPlugin interface { + PostCall(ctx context.Context, args, reply interface{}) (interface{}, error) + } + //PreWriteResponsePlugin represents . PreWriteResponsePlugin interface { PreWriteResponse(context.Context, *protocol.Message, *protocol.Message) error @@ -253,6 +263,36 @@ func (p *pluginContainer) DoPreHandleRequest(ctx context.Context, r *protocol.Me return nil } +// DoPreCall invokes PreCallPlugin plugin. +func (p *pluginContainer) DoPreCall(ctx context.Context, args interface{}) (interface{}, error) { + var err error + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PreCallPlugin); ok { + args, err = plugin.PreCall(ctx, args) + if err != nil { + return args, err + } + } + } + + return args, err +} + +// DoPostCall invokes PostCallPlugin plugin. +func (p *pluginContainer) DoPostCall(ctx context.Context, args, reply interface{}) (interface{}, error) { + var err error + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PostCallPlugin); ok { + reply, err = plugin.PostCall(ctx, args, reply) + if err != nil { + return reply, err + } + } + } + + return reply, err +} + // DoPreWriteResponse invokes PreWriteResponse plugin. func (p *pluginContainer) DoPreWriteResponse(ctx context.Context, req *protocol.Message, res *protocol.Message) error { for i := range p.plugins { diff --git a/server/server.go b/server/server.go index 01d475f..8977710 100644 --- a/server/server.go +++ b/server/server.go @@ -529,12 +529,22 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res replyv := argsReplyPools.Get(mtype.ReplyType) + argv, err = s.Plugins.DoPreCall(ctx, argv) + if err != nil { + argsReplyPools.Put(mtype.ReplyType, replyv) + return handleError(res, err) + } + if mtype.ArgType.Kind() != reflect.Ptr { err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv)) } else { err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv)) } + if err == nil { + replyv, err = s.Plugins.DoPostCall(ctx, argv, replyv) + } + argsReplyPools.Put(mtype.ArgType, argv) if err != nil { argsReplyPools.Put(mtype.ReplyType, replyv)