mirror of
https://github.com/smallnest/rpcx.git
synced 2025-10-25 00:50:24 +08:00
#165 support registering raw functions
This commit is contained in:
@@ -347,14 +347,15 @@ func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||
// pool res?
|
||||
serviceName := req.ServicePath
|
||||
if serviceName == "" {
|
||||
return s.handleRequestForFunction(ctx, req)
|
||||
}
|
||||
methodName := req.ServiceMethod
|
||||
|
||||
res = req.Clone()
|
||||
|
||||
res.SetMessageType(protocol.Response)
|
||||
|
||||
serviceName := req.ServicePath
|
||||
methodName := req.ServiceMethod
|
||||
|
||||
s.serviceMapMu.RLock()
|
||||
service := s.serviceMap[serviceName]
|
||||
s.serviceMapMu.RUnlock()
|
||||
@@ -412,6 +413,69 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||||
res = req.Clone()
|
||||
|
||||
res.SetMessageType(protocol.Response)
|
||||
|
||||
methodName := req.ServiceMethod
|
||||
s.serviceMapMu.RLock()
|
||||
service := s.serviceMap[""]
|
||||
s.serviceMapMu.RUnlock()
|
||||
if service == nil {
|
||||
err = errors.New("rpcx: can't find service func raw function")
|
||||
return handleError(res, err)
|
||||
}
|
||||
mtype := service.function[methodName]
|
||||
if mtype == nil {
|
||||
err = errors.New("rpcx: can't find method " + methodName)
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
var argv, replyv reflect.Value
|
||||
|
||||
argIsValue := false // if true, need to indirect before calling.
|
||||
if mtype.ArgType.Kind() == reflect.Ptr {
|
||||
argv = reflect.New(mtype.ArgType.Elem())
|
||||
} else {
|
||||
argv = reflect.New(mtype.ArgType)
|
||||
argIsValue = true
|
||||
}
|
||||
|
||||
codec := share.Codecs[req.SerializeType()]
|
||||
if codec == nil {
|
||||
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
err = codec.Decode(req.Payload, argv.Interface())
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
if argIsValue {
|
||||
argv = argv.Elem()
|
||||
}
|
||||
|
||||
replyv = reflect.New(mtype.ReplyType.Elem())
|
||||
|
||||
err = service.callForFunction(ctx, mtype, argv, replyv)
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
}
|
||||
|
||||
if !req.IsOneway() {
|
||||
data, err := codec.Encode(replyv.Interface())
|
||||
if err != nil {
|
||||
return handleError(res, err)
|
||||
|
||||
}
|
||||
res.Payload = data
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
||||
res.SetMessageStatusType(protocol.Error)
|
||||
if res.Metadata == nil {
|
||||
|
||||
Reference in New Issue
Block a user