mirror of
https://github.com/smallnest/rpcx.git
synced 2025-10-25 17:10:21 +08:00
#165 support registering raw functions
This commit is contained in:
@@ -16,6 +16,7 @@ type PluginContainer interface {
|
|||||||
All() []Plugin
|
All() []Plugin
|
||||||
|
|
||||||
DoRegister(name string, rcvr interface{}, metadata string) error
|
DoRegister(name string, rcvr interface{}, metadata string) error
|
||||||
|
DoRegisterFunction(name string, fn interface{}, metadata string) error
|
||||||
|
|
||||||
DoPostConnAccept(net.Conn) (net.Conn, bool)
|
DoPostConnAccept(net.Conn) (net.Conn, bool)
|
||||||
|
|
||||||
@@ -36,6 +37,11 @@ type (
|
|||||||
Register(name string, rcvr interface{}, metadata string) error
|
Register(name string, rcvr interface{}, metadata string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterFunctionPlugin is .
|
||||||
|
RegisterFunctionPlugin interface {
|
||||||
|
RegisterFunction(name string, fn interface{}, metadata string) error
|
||||||
|
}
|
||||||
|
|
||||||
// PostConnAcceptPlugin represents connection accept plugin.
|
// PostConnAcceptPlugin represents connection accept plugin.
|
||||||
// if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn
|
// if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn
|
||||||
// and this conn has been closed.
|
// and this conn has been closed.
|
||||||
@@ -112,6 +118,24 @@ func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DoRegisterFunction invokes DoRegisterFunction plugin.
|
||||||
|
func (p *pluginContainer) DoRegisterFunction(name string, fn interface{}, metadata string) error {
|
||||||
|
var es []error
|
||||||
|
for _, rp := range p.plugins {
|
||||||
|
if plugin, ok := rp.(RegisterFunctionPlugin); ok {
|
||||||
|
err := plugin.RegisterFunction(name, fn, metadata)
|
||||||
|
if err != nil {
|
||||||
|
es = append(es, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(es) > 0 {
|
||||||
|
return errors.NewMultiError(es)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
//DoPostConnAccept handles accepted conn
|
//DoPostConnAccept handles accepted conn
|
||||||
func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) {
|
func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) {
|
||||||
var flag bool
|
var flag bool
|
||||||
|
|||||||
@@ -347,14 +347,15 @@ func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||||
// pool res?
|
serviceName := req.ServicePath
|
||||||
|
if serviceName == "" {
|
||||||
|
return s.handleRequestForFunction(ctx, req)
|
||||||
|
}
|
||||||
|
methodName := req.ServiceMethod
|
||||||
|
|
||||||
res = req.Clone()
|
res = req.Clone()
|
||||||
|
|
||||||
res.SetMessageType(protocol.Response)
|
res.SetMessageType(protocol.Response)
|
||||||
|
|
||||||
serviceName := req.ServicePath
|
|
||||||
methodName := req.ServiceMethod
|
|
||||||
|
|
||||||
s.serviceMapMu.RLock()
|
s.serviceMapMu.RLock()
|
||||||
service := s.serviceMap[serviceName]
|
service := s.serviceMap[serviceName]
|
||||||
s.serviceMapMu.RUnlock()
|
s.serviceMapMu.RUnlock()
|
||||||
@@ -412,6 +413,69 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||||
|
res = req.Clone()
|
||||||
|
|
||||||
|
res.SetMessageType(protocol.Response)
|
||||||
|
|
||||||
|
methodName := req.ServiceMethod
|
||||||
|
s.serviceMapMu.RLock()
|
||||||
|
service := s.serviceMap[""]
|
||||||
|
s.serviceMapMu.RUnlock()
|
||||||
|
if service == nil {
|
||||||
|
err = errors.New("rpcx: can't find service func raw function")
|
||||||
|
return handleError(res, err)
|
||||||
|
}
|
||||||
|
mtype := service.function[methodName]
|
||||||
|
if mtype == nil {
|
||||||
|
err = errors.New("rpcx: can't find method " + methodName)
|
||||||
|
return handleError(res, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var argv, replyv reflect.Value
|
||||||
|
|
||||||
|
argIsValue := false // if true, need to indirect before calling.
|
||||||
|
if mtype.ArgType.Kind() == reflect.Ptr {
|
||||||
|
argv = reflect.New(mtype.ArgType.Elem())
|
||||||
|
} else {
|
||||||
|
argv = reflect.New(mtype.ArgType)
|
||||||
|
argIsValue = true
|
||||||
|
}
|
||||||
|
|
||||||
|
codec := share.Codecs[req.SerializeType()]
|
||||||
|
if codec == nil {
|
||||||
|
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
||||||
|
return handleError(res, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = codec.Decode(req.Payload, argv.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return handleError(res, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if argIsValue {
|
||||||
|
argv = argv.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
replyv = reflect.New(mtype.ReplyType.Elem())
|
||||||
|
|
||||||
|
err = service.callForFunction(ctx, mtype, argv, replyv)
|
||||||
|
if err != nil {
|
||||||
|
return handleError(res, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.IsOneway() {
|
||||||
|
data, err := codec.Encode(replyv.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return handleError(res, err)
|
||||||
|
|
||||||
|
}
|
||||||
|
res.Payload = data
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
||||||
res.SetMessageStatusType(protocol.Error)
|
res.SetMessageStatusType(protocol.Error)
|
||||||
if res.Metadata == nil {
|
if res.Metadata == nil {
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
|
|||||||
// Precompute the reflect type for context.
|
// Precompute the reflect type for context.
|
||||||
var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
|
var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
|
||||||
|
|
||||||
|
var emptyService = new(service)
|
||||||
|
|
||||||
type methodType struct {
|
type methodType struct {
|
||||||
sync.Mutex // protects counters
|
sync.Mutex // protects counters
|
||||||
method reflect.Method
|
method reflect.Method
|
||||||
@@ -27,11 +29,20 @@ type methodType struct {
|
|||||||
numCalls uint
|
numCalls uint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type functionType struct {
|
||||||
|
sync.Mutex // protects counters
|
||||||
|
fn reflect.Value
|
||||||
|
ArgType reflect.Type
|
||||||
|
ReplyType reflect.Type
|
||||||
|
numCalls uint
|
||||||
|
}
|
||||||
|
|
||||||
type service struct {
|
type service struct {
|
||||||
name string // name of service
|
name string // name of service
|
||||||
rcvr reflect.Value // receiver of methods for the service
|
rcvr reflect.Value // receiver of methods for the service
|
||||||
typ reflect.Type // type of the receiver
|
typ reflect.Type // type of the receiver
|
||||||
method map[string]*methodType // registered methods
|
method map[string]*methodType // registered methods
|
||||||
|
function map[string]*functionType // registered functions
|
||||||
}
|
}
|
||||||
|
|
||||||
func isExported(name string) bool {
|
func isExported(name string) bool {
|
||||||
@@ -51,7 +62,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
|
|||||||
// Register publishes in the server the set of methods of the
|
// Register publishes in the server the set of methods of the
|
||||||
// receiver value that satisfy the following conditions:
|
// receiver value that satisfy the following conditions:
|
||||||
// - exported method of exported type
|
// - exported method of exported type
|
||||||
// - three arguments, the first is of context.Context, both of exported type or three arguments
|
// - three arguments, the first is of context.Context, both of exported type for three arguments
|
||||||
// - the third argument is a pointer
|
// - the third argument is a pointer
|
||||||
// - one return value, of type error
|
// - one return value, of type error
|
||||||
// It returns an error if the receiver is not an exported type or has
|
// It returns an error if the receiver is not an exported type or has
|
||||||
@@ -59,6 +70,7 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
|
|||||||
// The client accesses each method using a string of the form "Type.Method",
|
// The client accesses each method using a string of the form "Type.Method",
|
||||||
// where Type is the receiver's concrete type.
|
// where Type is the receiver's concrete type.
|
||||||
func (s *Server) Register(rcvr interface{}, metadata string) error {
|
func (s *Server) Register(rcvr interface{}, metadata string) error {
|
||||||
|
s.Plugins.DoRegister("", rcvr, metadata)
|
||||||
return s.register(rcvr, "", false)
|
return s.register(rcvr, "", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,6 +85,28 @@ func (s *Server) RegisterName(name string, rcvr interface{}, metadata string) er
|
|||||||
return s.register(rcvr, name, true)
|
return s.register(rcvr, name, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterFunction publishes a function that satisfy the following conditions:
|
||||||
|
// - three arguments, the first is of context.Context, both of exported type for three arguments
|
||||||
|
// - the third argument is a pointer
|
||||||
|
// - one return value, of type error
|
||||||
|
// The client accesses function using a string of the form ".Method",
|
||||||
|
// where service path is empty.
|
||||||
|
func (s *Server) RegisterFunction(fn interface{}, metadata string) error {
|
||||||
|
s.Plugins.DoRegisterFunction("", fn, metadata)
|
||||||
|
return s.registerFunction(fn, "", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterFunctionName is like RegisterFunction but uses the provided name for the function
|
||||||
|
// instead of the function's concrete type.
|
||||||
|
func (s *Server) RegisterFunctionName(name string, fn interface{}, metadata string) error {
|
||||||
|
if s.Plugins == nil {
|
||||||
|
s.Plugins = &pluginContainer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Plugins.DoRegisterFunction(name, fn, metadata)
|
||||||
|
return s.registerFunction(fn, name, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) register(rcvr interface{}, name string, useName bool) error {
|
func (s *Server) register(rcvr interface{}, name string, useName bool) error {
|
||||||
s.serviceMapMu.Lock()
|
s.serviceMapMu.Lock()
|
||||||
defer s.serviceMapMu.Unlock()
|
defer s.serviceMapMu.Unlock()
|
||||||
@@ -119,6 +153,70 @@ func (s *Server) register(rcvr interface{}, name string, useName bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) registerFunction(fn interface{}, name string, useName bool) error {
|
||||||
|
s.serviceMapMu.Lock()
|
||||||
|
defer s.serviceMapMu.Unlock()
|
||||||
|
if s.serviceMap == nil {
|
||||||
|
s.serviceMap = make(map[string]*service)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, ok := fn.(reflect.Value)
|
||||||
|
if !ok {
|
||||||
|
f = reflect.ValueOf(fn)
|
||||||
|
}
|
||||||
|
if f.Kind() != reflect.Func {
|
||||||
|
return errors.New("function must be func or bound method")
|
||||||
|
}
|
||||||
|
|
||||||
|
fname := reflect.Indirect(f).Type().Name()
|
||||||
|
if useName {
|
||||||
|
fname = name
|
||||||
|
}
|
||||||
|
if fname == "" {
|
||||||
|
errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String()
|
||||||
|
log.Error(errorStr)
|
||||||
|
return errors.New(errorStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
t := f.Type()
|
||||||
|
if t.NumIn() != 3 {
|
||||||
|
return fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String())
|
||||||
|
}
|
||||||
|
if t.NumOut() != 1 {
|
||||||
|
return fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// First arg must be context.Context
|
||||||
|
ctxType := t.In(0)
|
||||||
|
if !ctxType.Implements(typeOfContext) {
|
||||||
|
return fmt.Errorf("function %s must use context as the first parameter", f.Type().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
argType := t.In(1)
|
||||||
|
if !isExportedOrBuiltinType(argType) {
|
||||||
|
return fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType)
|
||||||
|
}
|
||||||
|
|
||||||
|
replyType := t.In(2)
|
||||||
|
if replyType.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType)
|
||||||
|
}
|
||||||
|
if !isExportedOrBuiltinType(replyType) {
|
||||||
|
return fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The return type of the method must be error.
|
||||||
|
if returnType := t.Out(0); returnType != typeOfError {
|
||||||
|
return fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Install the methods
|
||||||
|
emptyService.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType}
|
||||||
|
|
||||||
|
s.serviceMap[""] = emptyService
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// suitableMethods returns suitable Rpc methods of typ, it will report
|
// suitableMethods returns suitable Rpc methods of typ, it will report
|
||||||
// error using log if reportErr is true.
|
// error using log if reportErr is true.
|
||||||
func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
||||||
@@ -142,7 +240,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
|||||||
ctxType := mtype.In(1)
|
ctxType := mtype.In(1)
|
||||||
if !ctxType.Implements(typeOfContext) {
|
if !ctxType.Implements(typeOfContext) {
|
||||||
if reportErr {
|
if reportErr {
|
||||||
log.Info("method", mname, "has wrong number of ins:", mtype.NumIn())
|
log.Info("method", mname, " must use context.Context as the first parameter")
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -151,7 +249,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
|
|||||||
argType := mtype.In(2)
|
argType := mtype.In(2)
|
||||||
if !isExportedOrBuiltinType(argType) {
|
if !isExportedOrBuiltinType(argType) {
|
||||||
if reportErr {
|
if reportErr {
|
||||||
log.Info(mname, "argument type not exported:", argType)
|
log.Info(mname, "parameter type not exported:", argType)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -207,3 +305,21 @@ func (s *service) call(ctx context.Context, mtype *methodType, argv, replyv refl
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *service) callForFunction(ctx context.Context, ft *functionType, argv, replyv reflect.Value) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("internal error: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Invoke the function, providing a new value for the reply.
|
||||||
|
returnValues := ft.fn.Call([]reflect.Value{reflect.ValueOf(ctx), argv, replyv})
|
||||||
|
// The return value for the method is an error.
|
||||||
|
errInter := returnValues[0].Interface()
|
||||||
|
if errInter != nil {
|
||||||
|
return errInter.(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user