diff --git a/server/plugin.go b/server/plugin.go index 005a84d..d688647 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -34,6 +34,8 @@ type PluginContainer interface { DoPreWriteRequest(ctx context.Context) error DoPostWriteRequest(ctx context.Context, r *protocol.Message, e error) error + + DoHeartbeatRequest(ctx context.Context, req *protocol.Message) error } // Plugin is the server plugin interface. @@ -106,6 +108,11 @@ type ( PostWriteRequestPlugin interface { PostWriteRequest(ctx context.Context, r *protocol.Message, e error) error } + + // HeartbeatPlugin is . + HeartbeatPlugin interface { + OnHeartbeat(ctx context.Context, req *protocol.Message) error + } ) // pluginContainer implements PluginContainer interface. @@ -348,3 +355,17 @@ func (p *pluginContainer) DoPostWriteRequest(ctx context.Context, r *protocol.Me return nil } + +// DoHeartbeatRequest invokes HeartbeatRequest plugin. +func (p *pluginContainer) DoHeartbeatRequest(ctx context.Context, r *protocol.Message) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(HeartbeatPlugin); ok { + err := plugin.OnHeartbeat(ctx, r) + if err != nil { + return err + } + } + } + + return nil +} \ No newline at end of file diff --git a/server/plugin_test.go b/server/plugin_test.go new file mode 100644 index 0000000..255ace7 --- /dev/null +++ b/server/plugin_test.go @@ -0,0 +1,69 @@ +package server + +import ( + "context" + "github.com/smallnest/rpcx/client" + "github.com/smallnest/rpcx/protocol" + "net" + "sync" + "testing" + "time" +) + +type HeartbeatHandler struct{} + +func (h *HeartbeatHandler) OnHeartbeat(ctx context.Context, req *protocol.Message) error { + conn := ctx.Value(RemoteConnContextKey).(net.Conn) + println("OnHeartbeat:", conn.RemoteAddr().String()) + return nil +} + +// TestPluginHeartbeat: go test -v -test.run TestPluginHeartbeat +func TestPluginHeartbeat(t *testing.T) { + h := &HeartbeatHandler{} + s := NewServer( + WithReadTimeout(time.Duration(5)*time.Second), + WithWriteTimeout(time.Duration(5)*time.Second), + ) + s.Plugins.Add(h) + s.RegisterName("Arith", new(Arith), "") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + // server + defer wg.Done() + err := s.Serve("tcp", "127.0.0.1:9001") + if err != nil { + t.Log(err.Error()) + } + }() + go func() { + defer wg.Done() + // client + opts := client.DefaultOption + opts.Heartbeat = true + opts.HeartbeatInterval = time.Second + opts.ReadTimeout = time.Duration(5) * time.Second + opts.WriteTimeout = time.Duration(5) * time.Second + opts.ConnectTimeout = time.Duration(5) * time.Second + // PeerDiscovery + d := client.NewPeer2PeerDiscovery("tcp@127.0.0.1:9001", "") + c := client.NewXClient("Arith", client.Failtry, client.RoundRobin, d, opts) + i := 0 + for { + i++ + resp := &Reply{} + c.Call(context.Background(), "Mul", &Args{A: 1, B: 5}, resp) + t.Log("call Mul resp:", resp.C) + time.Sleep(time.Second) + if i > 10 { + break + } + } + c.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s.Shutdown(ctx) + }() + wg.Wait() +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index f268140..5d607d3 100644 --- a/server/server.go +++ b/server/server.go @@ -405,6 +405,7 @@ func (s *Server) serveConn(conn net.Conn) { defer atomic.AddInt32(&s.handlerMsgNum, -1) if req.IsHeartbeat() { + s.Plugins.DoHeartbeatRequest(ctx, req) req.SetMessageType(protocol.Response) data := req.EncodeSlicePointer() conn.Write(*data)