#165 support registering raw functions

This commit is contained in:
smallnest
2017-12-12 19:37:17 +08:00
parent 901bdef542
commit d117477c0f
3 changed files with 216 additions and 12 deletions

View File

@@ -16,6 +16,7 @@ type PluginContainer interface {
All() []Plugin All() []Plugin
DoRegister(name string, rcvr interface{}, metadata string) error DoRegister(name string, rcvr interface{}, metadata string) error
DoRegisterFunction(name string, fn interface{}, metadata string) error
DoPostConnAccept(net.Conn) (net.Conn, bool) DoPostConnAccept(net.Conn) (net.Conn, bool)
@@ -36,6 +37,11 @@ type (
Register(name string, rcvr interface{}, metadata string) error Register(name string, rcvr interface{}, metadata string) error
} }
// RegisterFunctionPlugin is .
RegisterFunctionPlugin interface {
RegisterFunction(name string, fn interface{}, metadata string) error
}
// PostConnAcceptPlugin represents connection accept plugin. // PostConnAcceptPlugin represents connection accept plugin.
// if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn // if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn
// and this conn has been closed. // and this conn has been closed.
@@ -112,6 +118,24 @@ func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata str
return nil return nil
} }
// DoRegisterFunction invokes DoRegisterFunction plugin.
func (p *pluginContainer) DoRegisterFunction(name string, fn interface{}, metadata string) error {
var es []error
for _, rp := range p.plugins {
if plugin, ok := rp.(RegisterFunctionPlugin); ok {
err := plugin.RegisterFunction(name, fn, metadata)
if err != nil {
es = append(es, err)
}
}
}
if len(es) > 0 {
return errors.NewMultiError(es)
}
return nil
}
//DoPostConnAccept handles accepted conn //DoPostConnAccept handles accepted conn
func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) { func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) {
var flag bool var flag bool

View File

@@ -347,14 +347,15 @@ func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
} }
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) { func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
// pool res? serviceName := req.ServicePath
if serviceName == "" {
return s.handleRequestForFunction(ctx, req)
}
methodName := req.ServiceMethod
res = req.Clone() res = req.Clone()
res.SetMessageType(protocol.Response) res.SetMessageType(protocol.Response)
serviceName := req.ServicePath
methodName := req.ServiceMethod
s.serviceMapMu.RLock() s.serviceMapMu.RLock()
service := s.serviceMap[serviceName] service := s.serviceMap[serviceName]
s.serviceMapMu.RUnlock() s.serviceMapMu.RUnlock()
@@ -412,6 +413,69 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res
return res, nil return res, nil
} }
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
res = req.Clone()
res.SetMessageType(protocol.Response)
methodName := req.ServiceMethod
s.serviceMapMu.RLock()
service := s.serviceMap[""]
s.serviceMapMu.RUnlock()
if service == nil {
err = errors.New("rpcx: can't find service func raw function")
return handleError(res, err)
}
mtype := service.function[methodName]
if mtype == nil {
err = errors.New("rpcx: can't find method " + methodName)
return handleError(res, err)
}
var argv, replyv reflect.Value
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
codec := share.Codecs[req.SerializeType()]
if codec == nil {
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
return handleError(res, err)
}
err = codec.Decode(req.Payload, argv.Interface())
if err != nil {
return handleError(res, err)
}
if argIsValue {
argv = argv.Elem()
}
replyv = reflect.New(mtype.ReplyType.Elem())
err = service.callForFunction(ctx, mtype, argv, replyv)
if err != nil {
return handleError(res, err)
}
if !req.IsOneway() {
data, err := codec.Encode(replyv.Interface())
if err != nil {
return handleError(res, err)
}
res.Payload = data
}
return res, nil
}
func handleError(res *protocol.Message, err error) (*protocol.Message, error) { func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
res.SetMessageStatusType(protocol.Error) res.SetMessageStatusType(protocol.Error)
if res.Metadata == nil { if res.Metadata == nil {

View File

@@ -19,6 +19,8 @@ var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
// Precompute the reflect type for context. // Precompute the reflect type for context.
var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
var emptyService = new(service)
type methodType struct { type methodType struct {
sync.Mutex // protects counters sync.Mutex // protects counters
method reflect.Method method reflect.Method
@@ -27,11 +29,20 @@ type methodType struct {
numCalls uint numCalls uint
} }
type functionType struct {
sync.Mutex // protects counters
fn reflect.Value
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint
}
type service struct { type service struct {
name string // name of service name string // name of service
rcvr reflect.Value // receiver of methods for the service rcvr reflect.Value // receiver of methods for the service
typ reflect.Type // type of the receiver typ reflect.Type // type of the receiver
method map[string]*methodType // registered methods method map[string]*methodType // registered methods
function map[string]*functionType // registered functions
} }
func isExported(name string) bool { func isExported(name string) bool {
@@ -51,7 +62,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
// Register publishes in the server the set of methods of the // Register publishes in the server the set of methods of the
// receiver value that satisfy the following conditions: // receiver value that satisfy the following conditions:
// - exported method of exported type // - exported method of exported type
// - three arguments, the first is of context.Context, both of exported type or three arguments // - three arguments, the first is of context.Context, both of exported type for three arguments
// - the third argument is a pointer // - the third argument is a pointer
// - one return value, of type error // - one return value, of type error
// It returns an error if the receiver is not an exported type or has // It returns an error if the receiver is not an exported type or has
@@ -59,6 +70,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
// The client accesses each method using a string of the form "Type.Method", // The client accesses each method using a string of the form "Type.Method",
// where Type is the receiver's concrete type. // where Type is the receiver's concrete type.
func (s *Server) Register(rcvr interface{}, metadata string) error { func (s *Server) Register(rcvr interface{}, metadata string) error {
s.Plugins.DoRegister("", rcvr, metadata)
return s.register(rcvr, "", false) return s.register(rcvr, "", false)
} }
@@ -73,6 +85,28 @@ func (s *Server) RegisterName(name string, rcvr interface{}, metadata string) er
return s.register(rcvr, name, true) return s.register(rcvr, name, true)
} }
// RegisterFunction publishes a function that satisfy the following conditions:
// - three arguments, the first is of context.Context, both of exported type for three arguments
// - the third argument is a pointer
// - one return value, of type error
// The client accesses function using a string of the form ".Method",
// where service path is empty.
func (s *Server) RegisterFunction(fn interface{}, metadata string) error {
s.Plugins.DoRegisterFunction("", fn, metadata)
return s.registerFunction(fn, "", false)
}
// RegisterFunctionName is like RegisterFunction but uses the provided name for the function
// instead of the function's concrete type.
func (s *Server) RegisterFunctionName(name string, fn interface{}, metadata string) error {
if s.Plugins == nil {
s.Plugins = &pluginContainer{}
}
s.Plugins.DoRegisterFunction(name, fn, metadata)
return s.registerFunction(fn, name, true)
}
func (s *Server) register(rcvr interface{}, name string, useName bool) error { func (s *Server) register(rcvr interface{}, name string, useName bool) error {
s.serviceMapMu.Lock() s.serviceMapMu.Lock()
defer s.serviceMapMu.Unlock() defer s.serviceMapMu.Unlock()
@@ -119,6 +153,70 @@ func (s *Server) register(rcvr interface{}, name string, useName bool) error {
return nil return nil
} }
func (s *Server) registerFunction(fn interface{}, name string, useName bool) error {
s.serviceMapMu.Lock()
defer s.serviceMapMu.Unlock()
if s.serviceMap == nil {
s.serviceMap = make(map[string]*service)
}
f, ok := fn.(reflect.Value)
if !ok {
f = reflect.ValueOf(fn)
}
if f.Kind() != reflect.Func {
return errors.New("function must be func or bound method")
}
fname := reflect.Indirect(f).Type().Name()
if useName {
fname = name
}
if fname == "" {
errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String()
log.Error(errorStr)
return errors.New(errorStr)
}
t := f.Type()
if t.NumIn() != 3 {
return fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String())
}
if t.NumOut() != 1 {
return fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String())
}
// First arg must be context.Context
ctxType := t.In(0)
if !ctxType.Implements(typeOfContext) {
return fmt.Errorf("function %s must use context as the first parameter", f.Type().String())
}
argType := t.In(1)
if !isExportedOrBuiltinType(argType) {
return fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType)
}
replyType := t.In(2)
if replyType.Kind() != reflect.Ptr {
return fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType)
}
if !isExportedOrBuiltinType(replyType) {
return fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType)
}
// The return type of the method must be error.
if returnType := t.Out(0); returnType != typeOfError {
return fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String())
}
// Install the methods
emptyService.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType}
s.serviceMap[""] = emptyService
return nil
}
// suitableMethods returns suitable Rpc methods of typ, it will report // suitableMethods returns suitable Rpc methods of typ, it will report
// error using log if reportErr is true. // error using log if reportErr is true.
func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
@@ -142,7 +240,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
ctxType := mtype.In(1) ctxType := mtype.In(1)
if !ctxType.Implements(typeOfContext) { if !ctxType.Implements(typeOfContext) {
if reportErr { if reportErr {
log.Info("method", mname, "has wrong number of ins:", mtype.NumIn()) log.Info("method", mname, " must use context.Context as the first parameter")
} }
continue continue
} }
@@ -151,7 +249,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
argType := mtype.In(2) argType := mtype.In(2)
if !isExportedOrBuiltinType(argType) { if !isExportedOrBuiltinType(argType) {
if reportErr { if reportErr {
log.Info(mname, "argument type not exported:", argType) log.Info(mname, "parameter type not exported:", argType)
} }
continue continue
} }
@@ -207,3 +305,21 @@ func (s *service) call(ctx context.Context, mtype *methodType, argv, replyv refl
return nil return nil
} }
func (s *service) callForFunction(ctx context.Context, ft *functionType, argv, replyv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("internal error: %v", r)
}
}()
// Invoke the function, providing a new value for the reply.
returnValues := ft.fn.Call([]reflect.Value{reflect.ValueOf(ctx), argv, replyv})
// The return value for the method is an error.
errInter := returnValues[0].Interface()
if errInter != nil {
return errInter.(error)
}
return nil
}