diff --git a/client/client.go b/client/client.go index 2fa1590..0386863 100644 --- a/client/client.go +++ b/client/client.go @@ -13,6 +13,7 @@ import ( "sync" "time" + opentracing "github.com/opentracing/opentracing-go" circuit "github.com/rubyist/circuitbreaker" "github.com/smallnest/rpcx/log" "github.com/smallnest/rpcx/protocol" @@ -208,6 +209,12 @@ func (client *Client) Go(ctx context.Context, servicePath, serviceMethod string, if meta != nil { //copy meta in context to meta in requests call.Metadata = meta.(map[string]string) } + + if _, ok := ctx.(*share.Context); !ok { + ctx = share.NewContext(ctx) + } + client.injectSpan(ctx, call) + call.Args = args call.Reply = reply if done == nil { @@ -226,6 +233,32 @@ func (client *Client) Go(ctx context.Context, servicePath, serviceMethod string, return call } +func (client *Client) injectSpan(ctx context.Context, call *Call) { + var rpcxContext *share.Context + var ok bool + if rpcxContext, ok = ctx.(*share.Context); !ok { + return + } + sp := rpcxContext.Value(share.OpentracingSpanClientKey) + if sp == nil { // have not config opentracing plugin + return + } + + span := sp.(opentracing.Span) + if call.Metadata == nil { + call.Metadata = make(map[string]string) + } + meta := call.Metadata + + err := opentracing.GlobalTracer().Inject( + span.Context(), + opentracing.TextMap, + opentracing.TextMapCarrier(meta)) + if err != nil { + log.Errorf("failed to inject span: %v", err) + } +} + // Call invokes the named function, waits for it to complete, and returns its error status. func (client *Client) Call(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}) error { return client.call(ctx, servicePath, serviceMethod, args, reply) diff --git a/client/opentracing.go b/client/opentracing.go new file mode 100644 index 0000000..fbaa7cc --- /dev/null +++ b/client/opentracing.go @@ -0,0 +1,47 @@ +package client + +import ( + "context" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + "github.com/smallnest/rpcx/share" +) + +type OpenTracingPlugin struct{} + +func (p *OpenTracingPlugin) DoPreCall(ctx context.Context, servicePath, serviceMethod string, args interface{}) error { + var span1 opentracing.Span + + // if it is called in rpc service in case that a service calls antoher service, + // we uses the span in the service context as the parent span. + parentSpan := ctx.Value(share.OpentracingSpanServerKey) + if parentSpan != nil { + span1 = opentracing.StartSpan( + "rpcx.client."+servicePath+"."+serviceMethod, + opentracing.ChildOf(parentSpan.(opentracing.Span).Context())) + } else { + wireContext, err := share.GetSpanContextFromContext(ctx) + if err == nil && wireContext != nil { //try to parse span from request + span1 = opentracing.StartSpan( + "rpcx.client."+servicePath+"."+serviceMethod, + ext.RPCServerOption(wireContext)) + } else { // parse span from context or create root context + span1, _ = opentracing.StartSpanFromContext(ctx, "rpcx.client."+servicePath+"."+serviceMethod) + } + } + + if rpcxContext, ok := ctx.(*share.Context); ok { + rpcxContext.SetValue(share.OpentracingSpanClientKey, span1) + } + return nil +} +func (p *OpenTracingPlugin) DoPostCall(ctx context.Context, servicePath, serviceMethod string, args interface{}, reply interface{}, err error) error { + if rpcxContext, ok := ctx.(*share.Context); ok { + span1 := rpcxContext.Value(share.OpentracingSpanClientKey) + if span1 != nil { + span1.(opentracing.Span).Finish() + } + } + return nil +} diff --git a/client/xclient.go b/client/xclient.go index 19acae5..dee9618 100644 --- a/client/xclient.go +++ b/client/xclient.go @@ -299,7 +299,7 @@ func (c *xClient) getCachedClientWithoutLock(k string) (RPCClient, error) { //double check client = c.cachedClient[k] - if client == nil || client.IsShutdown(){ + if client == nil || client.IsShutdown() { network, addr := splitNetworkAndAddress(k) if network == "inprocess" { client = InprocessClient @@ -356,7 +356,7 @@ func (c *xClient) Go(ctx context.Context, serviceMethod string, args interface{} metadata := ctx.Value(share.ReqMetaDataKey) if metadata == nil { metadata = map[string]string{} - ctx = context.WithValue(ctx,share.ReqMetaDataKey,metadata) + ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata) } m := metadata.(map[string]string) m[share.AuthKey] = c.auth @@ -380,7 +380,7 @@ func (c *xClient) Call(ctx context.Context, serviceMethod string, args interface metadata := ctx.Value(share.ReqMetaDataKey) if metadata == nil { metadata = map[string]string{} - ctx = context.WithValue(ctx,share.ReqMetaDataKey,metadata) + ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata) } m := metadata.(map[string]string) m[share.AuthKey] = c.auth @@ -517,7 +517,7 @@ func (c *xClient) SendRaw(ctx context.Context, r *protocol.Message) (map[string] metadata := ctx.Value(share.ReqMetaDataKey) if metadata == nil { metadata = map[string]string{} - ctx = context.WithValue(ctx,share.ReqMetaDataKey,metadata) + ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata) } m := metadata.(map[string]string) m[share.AuthKey] = c.auth @@ -600,6 +600,8 @@ func (c *xClient) wrapCall(ctx context.Context, client RPCClient, serviceMethod if client == nil { return ErrServerUnavailable } + + ctx = share.NewContext(ctx) c.Plugins.DoPreCall(ctx, c.servicePath, serviceMethod, args) err := client.Call(ctx, c.servicePath, serviceMethod, args, reply) c.Plugins.DoPostCall(ctx, c.servicePath, serviceMethod, args, reply, err) @@ -619,7 +621,7 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte metadata := ctx.Value(share.ReqMetaDataKey) if metadata == nil { metadata = map[string]string{} - ctx = context.WithValue(ctx,share.ReqMetaDataKey,metadata) + ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata) } m := metadata.(map[string]string) m[share.AuthKey] = c.auth @@ -688,7 +690,7 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface metadata := ctx.Value(share.ReqMetaDataKey) if metadata == nil { metadata = map[string]string{} - ctx = context.WithValue(ctx,share.ReqMetaDataKey,metadata) + ctx = context.WithValue(ctx, share.ReqMetaDataKey, metadata) } m := metadata.(map[string]string) m[share.AuthKey] = c.auth diff --git a/server/plugin.go b/server/plugin.go index 046083e..c2374b0 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -25,6 +25,8 @@ type PluginContainer interface { DoPreReadRequest(ctx context.Context) error DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error + DoPreHandleRequest(ctx context.Context, req *protocol.Message) error + DoPreWriteResponse(context.Context, *protocol.Message, *protocol.Message) error DoPostWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error @@ -70,6 +72,11 @@ type ( PostReadRequest(ctx context.Context, r *protocol.Message, e error) error } + //PreHandleRequestPlugin represents . + PreHandleRequestPlugin interface { + PreHandleRequest(ctx context.Context, r *protocol.Message) error + } + //PreWriteResponsePlugin represents . PreWriteResponsePlugin interface { PreWriteResponse(context.Context, *protocol.Message, *protocol.Message) error @@ -232,6 +239,20 @@ func (p *pluginContainer) DoPostReadRequest(ctx context.Context, r *protocol.Mes return nil } +// DoPreHandleRequest invokes PreHandleRequest plugin. +func (p *pluginContainer) DoPreHandleRequest(ctx context.Context, r *protocol.Message) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PreHandleRequestPlugin); ok { + err := plugin.PreHandleRequest(ctx, r) + if err != nil { + return err + } + } + } + + return nil +} + // DoPreWriteResponse invokes PreWriteResponse plugin. func (p *pluginContainer) DoPreWriteResponse(ctx context.Context, req *protocol.Message, res *protocol.Message) error { for i := range p.plugins { diff --git a/server/server.go b/server/server.go index c7d0b66..c99b1e4 100644 --- a/server/server.go +++ b/server/server.go @@ -53,6 +53,8 @@ var ( StartRequestContextKey = &contextKey{"start-parse-request"} // StartSendRequestContextKey records the start time StartSendRequestContextKey = &contextKey{"start-send-request"} + // TagContextKey is used to record extra info in handling services. Its value is a map[string]interface{} + TagContextKey = &contextKey{"service-tag"} ) // Server is rpcx server that use TCP or UDP. @@ -132,7 +134,7 @@ func (s *Server) ActiveClientConn() []net.Conn { // // servicePath, serviceMethod, metadata can be set to zero values. func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error { - ctx := context.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano()) + ctx := share.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano()) s.Plugins.DoPreWriteRequest(ctx) req := protocol.GetPooledMsg() @@ -345,7 +347,8 @@ func (s *Server) serveConn(conn net.Conn) { conn.SetReadDeadline(t0.Add(s.readTimeout)) } - ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn) + ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn) + req, err := s.readRequest(ctx, r) if err != nil { if err == io.EOF { @@ -362,7 +365,7 @@ func (s *Server) serveConn(conn net.Conn) { conn.SetWriteDeadline(t0.Add(s.writeTimeout)) } - ctx = context.WithValue(ctx, StartRequestContextKey, time.Now().UnixNano()) + ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano()) if !req.IsHeartbeat() { err = s.auth(ctx, req) } @@ -400,9 +403,11 @@ func (s *Server) serveConn(conn net.Conn) { } resMetadata := make(map[string]string) - newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata), + newCtx := share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata), share.ResMetaDataKey, resMetadata) + s.Plugins.DoPreHandleRequest(newCtx, req) + res, err := s.handleRequest(newCtx, req) if err != nil { diff --git a/serverplugin/opentracing.go b/serverplugin/opentracing.go new file mode 100644 index 0000000..701cc18 --- /dev/null +++ b/serverplugin/opentracing.go @@ -0,0 +1,71 @@ +package serverplugin + +import ( + "context" + "net" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + "github.com/opentracing/opentracing-go/log" + "github.com/smallnest/rpcx/protocol" + "github.com/smallnest/rpcx/server" + "github.com/smallnest/rpcx/share" +) + +type OpenTracingPlugin struct{} + +func (p OpenTracingPlugin) Register(name string, rcvr interface{}, metadata string) error { + span1 := opentracing.StartSpan( + "rpcx.Register") + defer span1.Finish() + + span1.LogFields(log.String("register_service", name)) + + return nil +} + +func (p OpenTracingPlugin) RegisterFunction(name string, fn interface{}, metadata string) error { + span1 := opentracing.StartSpan( + "rpcx.RegisterFunction") + defer span1.Finish() + + span1.LogFields(log.String("register_function", name)) + return nil +} + +func (p OpenTracingPlugin) PostConnAccept(conn net.Conn) (net.Conn, bool) { + span1 := opentracing.StartSpan( + "rpcx.AcceptConn") + defer span1.Finish() + + span1.LogFields(log.String("remote_addr", conn.RemoteAddr().String())) + return conn, true +} + +func (p OpenTracingPlugin) PreHandleRequest(ctx context.Context, r *protocol.Message) error { + wireContext, err := share.GetSpanContextFromContext(ctx) + if err != nil { + return err + } + span1 := opentracing.StartSpan( + "rpcx.service."+r.ServicePath+"."+r.ServiceMethod, + ext.RPCServerOption(wireContext)) + + clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + span1.LogFields(log.String("remote_addr", clientConn.RemoteAddr().Network())) + + if rpcxContext, ok := ctx.(*share.Context); ok { + rpcxContext.SetValue(share.OpentracingSpanServerKey, span1) + } + return nil +} + +func (p OpenTracingPlugin) PostWriteResponse(ctx context.Context, req *protocol.Message, res *protocol.Message, err error) error { + if rpcxContext, ok := ctx.(*share.Context); ok { + span1 := rpcxContext.Value(share.OpentracingSpanServerKey) + if span1 != nil { + span1.(opentracing.Span).Finish() + } + } + return nil +} diff --git a/share/context.go b/share/context.go new file mode 100644 index 0000000..d74226f --- /dev/null +++ b/share/context.go @@ -0,0 +1,69 @@ +package share + +import ( + "context" + "fmt" + "reflect" + + opentracing "github.com/opentracing/opentracing-go" +) + +// var _ context.Context = &Context{} + +// Context is a rpcx customized Context that can contains multiple values. +type Context struct { + tags map[interface{}]interface{} + context.Context +} + +func NewContext(ctx context.Context) *Context { + tags := make(map[interface{}]interface{}) + return &Context{Context: ctx, tags: tags} +} +func (c *Context) Value(key interface{}) interface{} { + if v, ok := c.tags[key]; ok { + return v + } + return c.Context.Value(key) +} + +func (c *Context) SetValue(key, val interface{}) { + c.tags[key] = val +} + +func (c *Context) String() string { + return fmt.Sprintf("%v.WithValue(%v)", c.Context, c.tags) +} + +func WithValue(parent context.Context, key, val interface{}) *Context { + if key == nil { + panic("nil key") + } + if !reflect.TypeOf(key).Comparable() { + panic("key is not comparable") + } + + tags := make(map[interface{}]interface{}) + tags[key] = val + return &Context{Context: parent, tags: tags} +} + +func WithLocalValue(ctx *Context, key, val interface{}) *Context { + if key == nil { + panic("nil key") + } + if !reflect.TypeOf(key).Comparable() { + panic("key is not comparable") + } + + ctx.tags[key] = val + return ctx +} + +// GetSpanContextFromContext get opentracing.SpanContext from context.Context. +func GetSpanContextFromContext(ctx context.Context) (opentracing.SpanContext, error) { + reqMeta := ctx.Value(ReqMetaDataKey).(map[string]string) + return opentracing.GlobalTracer().Extract( + opentracing.TextMap, + opentracing.TextMapCarrier(reqMeta)) +} diff --git a/share/share.go b/share/share.go index 7373f32..333c30f 100644 --- a/share/share.go +++ b/share/share.go @@ -11,6 +11,11 @@ const ( // AuthKey is used in metadata. AuthKey = "__AUTH" + + // OpentracingSpanServerKey key in service context + OpentracingSpanServerKey = "opentracing_span_server_key" + // OpentracingSpanClientKey key in client context + OpentracingSpanClientKey = "opentracing_span_client_key" ) var (