diff --git a/server/plugin.go b/server/plugin.go index 2b16202..5ee56c4 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -16,6 +16,7 @@ type PluginContainer interface { All() []Plugin DoRegister(name string, rcvr interface{}, metadata string) error + DoRegisterFunction(name string, fn interface{}, metadata string) error DoPostConnAccept(net.Conn) (net.Conn, bool) @@ -36,6 +37,11 @@ type ( Register(name string, rcvr interface{}, metadata string) error } + // RegisterFunctionPlugin is . + RegisterFunctionPlugin interface { + RegisterFunction(name string, fn interface{}, metadata string) error + } + // PostConnAcceptPlugin represents connection accept plugin. // if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn // and this conn has been closed. @@ -112,6 +118,24 @@ func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata str return nil } +// DoRegisterFunction invokes DoRegisterFunction plugin. +func (p *pluginContainer) DoRegisterFunction(name string, fn interface{}, metadata string) error { + var es []error + for _, rp := range p.plugins { + if plugin, ok := rp.(RegisterFunctionPlugin); ok { + err := plugin.RegisterFunction(name, fn, metadata) + if err != nil { + es = append(es, err) + } + } + } + + if len(es) > 0 { + return errors.NewMultiError(es) + } + return nil +} + //DoPostConnAccept handles accepted conn func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) { var flag bool diff --git a/server/server.go b/server/server.go index 58dbefc..cfd9bd8 100644 --- a/server/server.go +++ b/server/server.go @@ -347,14 +347,15 @@ func (s *Server) auth(ctx context.Context, req *protocol.Message) error { } func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) { - // pool res? + serviceName := req.ServicePath + if serviceName == "" { + return s.handleRequestForFunction(ctx, req) + } + methodName := req.ServiceMethod + res = req.Clone() res.SetMessageType(protocol.Response) - - serviceName := req.ServicePath - methodName := req.ServiceMethod - s.serviceMapMu.RLock() service := s.serviceMap[serviceName] s.serviceMapMu.RUnlock() @@ -412,6 +413,69 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res return res, nil } +func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) { + res = req.Clone() + + res.SetMessageType(protocol.Response) + + methodName := req.ServiceMethod + s.serviceMapMu.RLock() + service := s.serviceMap[""] + s.serviceMapMu.RUnlock() + if service == nil { + err = errors.New("rpcx: can't find service func raw function") + return handleError(res, err) + } + mtype := service.function[methodName] + if mtype == nil { + err = errors.New("rpcx: can't find method " + methodName) + return handleError(res, err) + } + + var argv, replyv reflect.Value + + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + + codec := share.Codecs[req.SerializeType()] + if codec == nil { + err = fmt.Errorf("can not find codec for %d", req.SerializeType()) + return handleError(res, err) + } + + err = codec.Decode(req.Payload, argv.Interface()) + if err != nil { + return handleError(res, err) + } + + if argIsValue { + argv = argv.Elem() + } + + replyv = reflect.New(mtype.ReplyType.Elem()) + + err = service.callForFunction(ctx, mtype, argv, replyv) + if err != nil { + return handleError(res, err) + } + + if !req.IsOneway() { + data, err := codec.Encode(replyv.Interface()) + if err != nil { + return handleError(res, err) + + } + res.Payload = data + } + + return res, nil +} + func handleError(res *protocol.Message, err error) (*protocol.Message, error) { res.SetMessageStatusType(protocol.Error) if res.Metadata == nil { diff --git a/server/service.go b/server/service.go index 75148f1..280dc02 100644 --- a/server/service.go +++ b/server/service.go @@ -19,6 +19,8 @@ var typeOfError = reflect.TypeOf((*error)(nil)).Elem() // Precompute the reflect type for context. var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() +var emptyService = new(service) + type methodType struct { sync.Mutex // protects counters method reflect.Method @@ -27,11 +29,20 @@ type methodType struct { numCalls uint } +type functionType struct { + sync.Mutex // protects counters + fn reflect.Value + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint +} + type service struct { - name string // name of service - rcvr reflect.Value // receiver of methods for the service - typ reflect.Type // type of the receiver - method map[string]*methodType // registered methods + name string // name of service + rcvr reflect.Value // receiver of methods for the service + typ reflect.Type // type of the receiver + method map[string]*methodType // registered methods + function map[string]*functionType // registered functions } func isExported(name string) bool { @@ -51,7 +62,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method of exported type -// - three arguments, the first is of context.Context, both of exported type or three arguments +// - three arguments, the first is of context.Context, both of exported type for three arguments // - the third argument is a pointer // - one return value, of type error // It returns an error if the receiver is not an exported type or has @@ -59,6 +70,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. func (s *Server) Register(rcvr interface{}, metadata string) error { + s.Plugins.DoRegister("", rcvr, metadata) return s.register(rcvr, "", false) } @@ -73,6 +85,28 @@ func (s *Server) RegisterName(name string, rcvr interface{}, metadata string) er return s.register(rcvr, name, true) } +// RegisterFunction publishes a function that satisfy the following conditions: +// - three arguments, the first is of context.Context, both of exported type for three arguments +// - the third argument is a pointer +// - one return value, of type error +// The client accesses function using a string of the form ".Method", +// where service path is empty. +func (s *Server) RegisterFunction(fn interface{}, metadata string) error { + s.Plugins.DoRegisterFunction("", fn, metadata) + return s.registerFunction(fn, "", false) +} + +// RegisterFunctionName is like RegisterFunction but uses the provided name for the function +// instead of the function's concrete type. +func (s *Server) RegisterFunctionName(name string, fn interface{}, metadata string) error { + if s.Plugins == nil { + s.Plugins = &pluginContainer{} + } + + s.Plugins.DoRegisterFunction(name, fn, metadata) + return s.registerFunction(fn, name, true) +} + func (s *Server) register(rcvr interface{}, name string, useName bool) error { s.serviceMapMu.Lock() defer s.serviceMapMu.Unlock() @@ -119,6 +153,70 @@ func (s *Server) register(rcvr interface{}, name string, useName bool) error { return nil } +func (s *Server) registerFunction(fn interface{}, name string, useName bool) error { + s.serviceMapMu.Lock() + defer s.serviceMapMu.Unlock() + if s.serviceMap == nil { + s.serviceMap = make(map[string]*service) + } + + f, ok := fn.(reflect.Value) + if !ok { + f = reflect.ValueOf(fn) + } + if f.Kind() != reflect.Func { + return errors.New("function must be func or bound method") + } + + fname := reflect.Indirect(f).Type().Name() + if useName { + fname = name + } + if fname == "" { + errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String() + log.Error(errorStr) + return errors.New(errorStr) + } + + t := f.Type() + if t.NumIn() != 3 { + return fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String()) + } + if t.NumOut() != 1 { + return fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String()) + } + + // First arg must be context.Context + ctxType := t.In(0) + if !ctxType.Implements(typeOfContext) { + return fmt.Errorf("function %s must use context as the first parameter", f.Type().String()) + } + + argType := t.In(1) + if !isExportedOrBuiltinType(argType) { + return fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType) + } + + replyType := t.In(2) + if replyType.Kind() != reflect.Ptr { + return fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType) + } + if !isExportedOrBuiltinType(replyType) { + return fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType) + } + + // The return type of the method must be error. + if returnType := t.Out(0); returnType != typeOfError { + return fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String()) + } + + // Install the methods + emptyService.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType} + + s.serviceMap[""] = emptyService + return nil +} + // suitableMethods returns suitable Rpc methods of typ, it will report // error using log if reportErr is true. func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { @@ -142,7 +240,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { ctxType := mtype.In(1) if !ctxType.Implements(typeOfContext) { if reportErr { - log.Info("method", mname, "has wrong number of ins:", mtype.NumIn()) + log.Info("method", mname, " must use context.Context as the first parameter") } continue } @@ -151,7 +249,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { argType := mtype.In(2) if !isExportedOrBuiltinType(argType) { if reportErr { - log.Info(mname, "argument type not exported:", argType) + log.Info(mname, "parameter type not exported:", argType) } continue } @@ -207,3 +305,21 @@ func (s *service) call(ctx context.Context, mtype *methodType, argv, replyv refl return nil } + +func (s *service) callForFunction(ctx context.Context, ft *functionType, argv, replyv reflect.Value) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("internal error: %v", r) + } + }() + + // Invoke the function, providing a new value for the reply. + returnValues := ft.fn.Call([]reflect.Value{reflect.ValueOf(ctx), argv, replyv}) + // The return value for the method is an error. + errInter := returnValues[0].Interface() + if errInter != nil { + return errInter.(error) + } + + return nil +}