From 20e332420acfbd75992d1365d657630c623ac5c5 Mon Sep 17 00:00:00 2001 From: smallnest Date: Mon, 2 Aug 2021 18:18:59 +0800 Subject: [PATCH] add handler implementation --- server/context.go | 159 ++++++++++++++++++++++++++++++++++++++++++ server/server.go | 21 ++++++ server/server_test.go | 68 +++++++++++++++++- 3 files changed, 245 insertions(+), 3 deletions(-) create mode 100644 server/context.go diff --git a/server/context.go b/server/context.go new file mode 100644 index 0000000..ca31720 --- /dev/null +++ b/server/context.go @@ -0,0 +1,159 @@ +package server + +import ( + "fmt" + "net" + + "github.com/smallnest/rpcx/protocol" + "github.com/smallnest/rpcx/share" +) + +// Context represents a rpcx FastCall context. +type Context struct { + conn net.Conn + req *protocol.Message + ctx *share.Context +} + +// NewContext creates a server.Context for Handler. +func NewContext(ctx *share.Context, conn net.Conn, req *protocol.Message) *Context { + return &Context{conn: conn, req: req, ctx: ctx} +} + +// Get returns value for key. +func (ctx *Context) Get(key interface{}) interface{} { + return ctx.ctx.Value(key) +} + +// SetValue sets the kv pair. +func (ctx *Context) SetValue(key, val interface{}) { + if key == nil || val == nil { + return + } + ctx.ctx.SetValue(key, val) +} + +// Payload returns the payload. +func (ctx *Context) Payload() []byte { + return ctx.req.Payload +} + +// Metadata returns the metadata. +func (ctx *Context) Metadata() map[string]string { + return ctx.req.Metadata +} + +// ServicePath returns the ServicePath. +func (ctx *Context) ServicePath() string { + return ctx.req.ServicePath +} + +// ServiceMethod returns the ServiceMethod. +func (ctx *Context) ServiceMethod() string { + return ctx.req.ServiceMethod +} + +// Bind parses the body data and stores the result to v. +func (ctx *Context) Bind(v interface{}) error { + req := ctx.req + if v != nil { + codec := share.Codecs[req.SerializeType()] + if codec == nil { + return fmt.Errorf("can not find codec for %d", req.SerializeType()) + } + + err := codec.Decode(req.Payload, v) + if err != nil { + return err + } + } + return nil +} + +func (ctx *Context) Write(v interface{}) error { + req := ctx.req + + if req.IsOneway() { // no need to send response + return nil + } + + codec := share.Codecs[req.SerializeType()] + if codec == nil { + return fmt.Errorf("can not find codec for %d", req.SerializeType()) + } + + res := req.Clone() + res.SetMessageType(protocol.Response) + + if v != nil { + data, err := codec.Encode(v) + if err != nil { + return err + } + res.Payload = data + } + + resMetadata := ctx.Get(share.ResMetaDataKey) + if resMetadata != nil { + resMetaInCtx := resMetadata.(map[string]string) + meta := res.Metadata + if meta == nil { + res.Metadata = resMetaInCtx + } else { + for k, v := range resMetaInCtx { + if meta[k] == "" { + meta[k] = v + } + } + } + } + + if len(res.Payload) > 1024 && req.CompressType() != protocol.None { + res.SetCompressType(req.CompressType()) + } + respData := res.EncodeSlicePointer() + _, err := ctx.conn.Write(*respData) + protocol.PutData(respData) + + return err +} + +func (ctx *Context) WriteError(err error) error { + req := ctx.req + + if req.IsOneway() { // no need to send response + return nil + } + + codec := share.Codecs[req.SerializeType()] + if codec == nil { + return fmt.Errorf("can not find codec for %d", req.SerializeType()) + } + + res := req.Clone() + res.SetMessageType(protocol.Response) + + resMetadata := ctx.Get(share.ResMetaDataKey) + if resMetadata != nil { + resMetaInCtx := resMetadata.(map[string]string) + meta := res.Metadata + if meta == nil { + res.Metadata = resMetaInCtx + } else { + for k, v := range resMetaInCtx { + if meta[k] == "" { + meta[k] = v + } + } + } + } + + res.SetMessageStatusType(protocol.Error) + res.Metadata[protocol.ServiceError] = err.Error() + + respData := res.EncodeSlicePointer() + ctx.conn.Write(*respData) + protocol.PutData(respData) + + return nil +} diff --git a/server/server.go b/server/server.go index c3b1fca..c48dd04 100644 --- a/server/server.go +++ b/server/server.go @@ -61,6 +61,8 @@ var ( HttpConnContextKey = &contextKey{"http-conn"} ) +type Handler func(ctx *Context) error + // Server is rpcx server that use TCP or UDP. type Server struct { ln net.Listener @@ -73,6 +75,8 @@ type Server struct { serviceMapMu sync.RWMutex serviceMap map[string]*service + router map[string]Handler + mu sync.RWMutex activeConn map[net.Conn]struct{} doneChan chan struct{} @@ -130,6 +134,10 @@ func (s *Server) Address() net.Addr { return s.ln.Addr() } +func (s *Server) AddHandler(servicePath, serviceMethod string, handler func(*Context) error) { + s.router[servicePath+"."+serviceMethod] = handler +} + // ActiveClientConn returns active connections. func (s *Server) ActiveClientConn() []net.Conn { s.mu.RLock() @@ -479,6 +487,19 @@ func (s *Server) serveConn(conn net.Conn) { if share.Trace { log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String()) } + + // first use handler + if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok { + sctx := NewContext(ctx, conn, req) + err := handler(sctx) + if err != nil { + log.Errorf("[handler internal error]: servicepath: %s, servicemethod, err: %v", req.ServicePath, req.ServiceMethod, err) + } + + return + } + + // res, err := s.handleRequest(ctx, req) if err != nil { if s.HandleServiceError != nil { diff --git a/server/server_test.go b/server/server_test.go index c072787..0fbfa82 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,15 +1,18 @@ package server import ( + "bytes" "context" "encoding/json" + "io/ioutil" + "net" "testing" - "time" testutils "github.com/smallnest/rpcx/_testutils" "github.com/smallnest/rpcx/protocol" "github.com/smallnest/rpcx/share" + "github.com/stretchr/testify/assert" ) type Args struct { @@ -58,11 +61,10 @@ func TestShutdownHook(t *testing.T) { if cancel1 != nil { cancel1() } - } func TestHandleRequest(t *testing.T) { - //use jsoncodec + // use jsoncodec req := protocol.NewMessage() req.SetVersion(0) @@ -116,3 +118,63 @@ func TestHandleRequest(t *testing.T) { t.Fatalf("expect 200 but got %d", reply.C) } } + +func TestHandler(t *testing.T) { + // use jsoncodec + + req := protocol.NewMessage() + req.SetVersion(0) + req.SetMessageType(protocol.Request) + req.SetHeartbeat(false) + req.SetOneway(false) + req.SetCompressType(protocol.None) + req.SetMessageStatusType(protocol.Normal) + req.SetSerializeType(protocol.JSON) + req.SetSeq(1234567890) + + req.ServicePath = "Arith" + req.ServiceMethod = "Mul" + + argv := &Args{ + A: 10, + B: 20, + } + + data, err := json.Marshal(argv) + if err != nil { + t.Fatal(err) + } + + req.Payload = data + + serverConn, clientConn := net.Pipe() + + handler := func(ctx *Context) error { + req := &Args{} + res := &Reply{} + ctx.Bind(req) + res.C = req.A * req.B + + return ctx.Write(res) + } + + go func() { + ctx := NewContext(share.NewContext(context.Background()), serverConn, req) + err = handler(ctx) + assert.NoError(t, err) + + serverConn.Close() + }() + + data, err = ioutil.ReadAll(clientConn) + assert.NoError(t, err) + + resp, err := protocol.Read(bytes.NewReader(data)) + assert.NoError(t, err) + + assert.Equal(t, "Arith", resp.ServicePath) + assert.Equal(t, "Mul", resp.ServiceMethod) + assert.Equal(t, req.Seq(), resp.Seq()) + + assert.Equal(t, "{\"C\":200}", string(resp.Payload)) +}