mirror of
https://github.com/smallnest/rpcx.git
synced 2025-10-26 09:30:29 +08:00
#165 support registering raw functions
This commit is contained in:
@@ -16,6 +16,7 @@ type PluginContainer interface {
|
||||
All() []Plugin
|
||||
|
||||
DoRegister(name string, rcvr interface{}, metadata string) error
|
||||
DoRegisterFunction(name string, fn interface{}, metadata string) error
|
||||
|
||||
DoPostConnAccept(net.Conn) (net.Conn, bool)
|
||||
|
||||
@@ -36,6 +37,11 @@ type (
|
||||
Register(name string, rcvr interface{}, metadata string) error
|
||||
}
|
||||
|
||||
// RegisterFunctionPlugin is .
|
||||
RegisterFunctionPlugin interface {
|
||||
RegisterFunction(name string, fn interface{}, metadata string) error
|
||||
}
|
||||
|
||||
// PostConnAcceptPlugin represents connection accept plugin.
|
||||
// if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn
|
||||
// and this conn has been closed.
|
||||
@@ -112,6 +118,24 @@ func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata str
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoRegisterFunction invokes DoRegisterFunction plugin.
|
||||
func (p *pluginContainer) DoRegisterFunction(name string, fn interface{}, metadata string) error {
|
||||
var es []error
|
||||
for _, rp := range p.plugins {
|
||||
if plugin, ok := rp.(RegisterFunctionPlugin); ok {
|
||||
err := plugin.RegisterFunction(name, fn, metadata)
|
||||
if err != nil {
|
||||
es = append(es, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(es) > 0 {
|
||||
return errors.NewMultiError(es)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//DoPostConnAccept handles accepted conn
|
||||
func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) {
|
||||
var flag bool
|
||||
|
||||
@@ -347,14 +347,15 @@ func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||
// pool res?
|
||||
serviceName := req.ServicePath
|
||||
if serviceName == "" {
|
||||
return s.handleRequestForFunction(ctx, req)
|
||||
}
|
||||
methodName := req.ServiceMethod
|
||||
|
||||
res = req.Clone()
|
||||
|
||||
res.SetMessageType(protocol.Response)
|
||||
|
||||
serviceName := req.ServicePath
|
||||
methodName := req.ServiceMethod
|
||||
|
||||
s.serviceMapMu.RLock()
|
||||
service := s.serviceMap[serviceName]
|
||||
s.serviceMapMu.RUnlock()
|
||||
@@ -412,6 +413,69 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||
res = req.Clone()
|
||||
|
||||
res.SetMessageType(protocol.Response)
|
||||
|
||||
methodName := req.ServiceMethod
|
||||
s.serviceMapMu.RLock()
|
||||
service := s.serviceMap[""]
|
||||
s.serviceMapMu.RUnlock()
|
||||
if service == nil {
|
||||
err = errors.New("rpcx: can't find service func raw function")
|
||||
return handleError(res, err)
|
||||
}
|
||||
mtype := service.function[methodName]
|
||||
if mtype == nil {
|
||||
err = errors.New("rpcx: can't find method " + methodName)
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
var argv, replyv reflect.Value
|
||||
|
||||
argIsValue := false // if true, need to indirect before calling.
|
||||
if mtype.ArgType.Kind() == reflect.Ptr {
|
||||
argv = reflect.New(mtype.ArgType.Elem())
|
||||
} else {
|
||||
argv = reflect.New(mtype.ArgType)
|
||||
argIsValue = true
|
||||
}
|
||||
|
||||
codec := share.Codecs[req.SerializeType()]
|
||||
if codec == nil {
|
||||
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
err = codec.Decode(req.Payload, argv.Interface())
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
if argIsValue {
|
||||
argv = argv.Elem()
|
||||
}
|
||||
|
||||
replyv = reflect.New(mtype.ReplyType.Elem())
|
||||
|
||||
err = service.callForFunction(ctx, mtype, argv, replyv)
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
if !req.IsOneway() {
|
||||
data, err := codec.Encode(replyv.Interface())
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
|
||||
}
|
||||
res.Payload = data
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
||||
res.SetMessageStatusType(protocol.Error)
|
||||
if res.Metadata == nil {
|
||||
|
||||
@@ -19,6 +19,8 @@ var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
|
||||
// Precompute the reflect type for context.
|
||||
var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
|
||||
|
||||
var emptyService = new(service)
|
||||
|
||||
type methodType struct {
|
||||
sync.Mutex // protects counters
|
||||
method reflect.Method
|
||||
@@ -27,11 +29,20 @@ type methodType struct {
|
||||
numCalls uint
|
||||
}
|
||||
|
||||
type functionType struct {
|
||||
sync.Mutex // protects counters
|
||||
fn reflect.Value
|
||||
ArgType reflect.Type
|
||||
ReplyType reflect.Type
|
||||
numCalls uint
|
||||
}
|
||||
|
||||
type service struct {
|
||||
name string // name of service
|
||||
rcvr reflect.Value // receiver of methods for the service
|
||||
typ reflect.Type // type of the receiver
|
||||
method map[string]*methodType // registered methods
|
||||
function map[string]*functionType // registered functions
|
||||
}
|
||||
|
||||
func isExported(name string) bool {
|
||||
@@ -51,7 +62,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
|
||||
// Register publishes in the server the set of methods of the
|
||||
// receiver value that satisfy the following conditions:
|
||||
// - exported method of exported type
|
||||
// - three arguments, the first is of context.Context, both of exported type or three arguments
|
||||
// - three arguments, the first is of context.Context, both of exported type for three arguments
|
||||
// - the third argument is a pointer
|
||||
// - one return value, of type error
|
||||
// It returns an error if the receiver is not an exported type or has
|
||||
@@ -59,6 +70,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
|
||||
// The client accesses each method using a string of the form "Type.Method",
|
||||
// where Type is the receiver's concrete type.
|
||||
func (s *Server) Register(rcvr interface{}, metadata string) error {
|
||||
s.Plugins.DoRegister("", rcvr, metadata)
|
||||
return s.register(rcvr, "", false)
|
||||
}
|
||||
|
||||
@@ -73,6 +85,28 @@ func (s *Server) RegisterName(name string, rcvr interface{}, metadata string) er
|
||||
return s.register(rcvr, name, true)
|
||||
}
|
||||
|
||||
// RegisterFunction publishes a function that satisfy the following conditions:
|
||||
// - three arguments, the first is of context.Context, both of exported type for three arguments
|
||||
// - the third argument is a pointer
|
||||
// - one return value, of type error
|
||||
// The client accesses function using a string of the form ".Method",
|
||||
// where service path is empty.
|
||||
func (s *Server) RegisterFunction(fn interface{}, metadata string) error {
|
||||
s.Plugins.DoRegisterFunction("", fn, metadata)
|
||||
return s.registerFunction(fn, "", false)
|
||||
}
|
||||
|
||||
// RegisterFunctionName is like RegisterFunction but uses the provided name for the function
|
||||
// instead of the function's concrete type.
|
||||
func (s *Server) RegisterFunctionName(name string, fn interface{}, metadata string) error {
|
||||
if s.Plugins == nil {
|
||||
s.Plugins = &pluginContainer{}
|
||||
}
|
||||
|
||||
s.Plugins.DoRegisterFunction(name, fn, metadata)
|
||||
return s.registerFunction(fn, name, true)
|
||||
}
|
||||
|
||||
func (s *Server) register(rcvr interface{}, name string, useName bool) error {
|
||||
s.serviceMapMu.Lock()
|
||||
defer s.serviceMapMu.Unlock()
|
||||
@@ -119,6 +153,70 @@ func (s *Server) register(rcvr interface{}, name string, useName bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) registerFunction(fn interface{}, name string, useName bool) error {
|
||||
s.serviceMapMu.Lock()
|
||||
defer s.serviceMapMu.Unlock()
|
||||
if s.serviceMap == nil {
|
||||
s.serviceMap = make(map[string]*service)
|
||||
}
|
||||
|
||||
f, ok := fn.(reflect.Value)
|
||||
if !ok {
|
||||
f = reflect.ValueOf(fn)
|
||||
}
|
||||
if f.Kind() != reflect.Func {
|
||||
return errors.New("function must be func or bound method")
|
||||
}
|
||||
|
||||
fname := reflect.Indirect(f).Type().Name()
|
||||
if useName {
|
||||
fname = name
|
||||
}
|
||||
if fname == "" {
|
||||
errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String()
|
||||
log.Error(errorStr)
|
||||
return errors.New(errorStr)
|
||||
}
|
||||
|
||||
t := f.Type()
|
||||
if t.NumIn() != 3 {
|
||||
return fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String())
|
||||
}
|
||||
if t.NumOut() != 1 {
|
||||
return fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String())
|
||||
}
|
||||
|
||||
// First arg must be context.Context
|
||||
ctxType := t.In(0)
|
||||
if !ctxType.Implements(typeOfContext) {
|
||||
return fmt.Errorf("function %s must use context as the first parameter", f.Type().String())
|
||||
}
|
||||
|
||||
argType := t.In(1)
|
||||
if !isExportedOrBuiltinType(argType) {
|
||||
return fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType)
|
||||
}
|
||||
|
||||
replyType := t.In(2)
|
||||
if replyType.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType)
|
||||
}
|
||||
if !isExportedOrBuiltinType(replyType) {
|
||||
return fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType)
|
||||
}
|
||||
|
||||
// The return type of the method must be error.
|
||||
if returnType := t.Out(0); returnType != typeOfError {
|
||||
return fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String())
|
||||
}
|
||||
|
||||
// Install the methods
|
||||
emptyService.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType}
|
||||
|
||||
s.serviceMap[""] = emptyService
|
||||
return nil
|
||||
}
|
||||
|
||||
// suitableMethods returns suitable Rpc methods of typ, it will report
|
||||
// error using log if reportErr is true.
|
||||
func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
||||
@@ -142,7 +240,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
||||
ctxType := mtype.In(1)
|
||||
if !ctxType.Implements(typeOfContext) {
|
||||
if reportErr {
|
||||
log.Info("method", mname, "has wrong number of ins:", mtype.NumIn())
|
||||
log.Info("method", mname, " must use context.Context as the first parameter")
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -151,7 +249,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
||||
argType := mtype.In(2)
|
||||
if !isExportedOrBuiltinType(argType) {
|
||||
if reportErr {
|
||||
log.Info(mname, "argument type not exported:", argType)
|
||||
log.Info(mname, "parameter type not exported:", argType)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -207,3 +305,21 @@ func (s *service) call(ctx context.Context, mtype *methodType, argv, replyv refl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) callForFunction(ctx context.Context, ft *functionType, argv, replyv reflect.Value) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("internal error: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Invoke the function, providing a new value for the reply.
|
||||
returnValues := ft.fn.Call([]reflect.Value{reflect.ValueOf(ctx), argv, replyv})
|
||||
// The return value for the method is an error.
|
||||
errInter := returnValues[0].Interface()
|
||||
if errInter != nil {
|
||||
return errInter.(error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user