diff --git a/client/arith_service_test.go b/_testutils/arith_service.go similarity index 99% rename from client/arith_service_test.go rename to _testutils/arith_service.go index 0ee2fca..c1ec37f 100644 --- a/client/arith_service_test.go +++ b/_testutils/arith_service.go @@ -12,7 +12,7 @@ ProtoArgs ProtoReply */ -package client +package testutils import proto "github.com/gogo/protobuf/proto" import fmt "fmt" diff --git a/client/arith_service.proto b/_testutils/arith_service.proto similarity index 100% rename from client/arith_service.proto rename to _testutils/arith_service.proto diff --git a/client/client_test.go b/client/client_test.go index 5aa1492..ae2bf05 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/smallnest/rpcx/_testutils" "github.com/smallnest/rpcx/protocol" "github.com/smallnest/rpcx/server" ) @@ -27,20 +28,20 @@ func (t *Arith) Mul(ctx context.Context, args *Args, reply *Reply) error { type PBArith int -func (t *PBArith) Mul(ctx context.Context, args *ProtoArgs, reply *ProtoReply) error { +func (t *PBArith) Mul(ctx context.Context, args *testutils.ProtoArgs, reply *testutils.ProtoReply) error { reply.C = args.A * args.B return nil } func TestClient_IT(t *testing.T) { - server := server.Server{} - server.RegisterName("Arith", new(Arith)) - server.RegisterName("PBArith", new(PBArith)) - go server.Serve("tcp", "127.0.0.1:0") - defer server.Close() + s := server.Server{} + s.RegisterName("Arith", new(Arith), "") + s.RegisterName("PBArith", new(PBArith), "") + go s.Serve("tcp", "127.0.0.1:0") + defer s.Close() time.Sleep(500 * time.Millisecond) - addr := server.Address().String() + addr := s.Address().String() client := &Client{ SerializeType: protocol.JSON, @@ -86,11 +87,11 @@ func TestClient_IT(t *testing.T) { client.SerializeType = protocol.ProtoBuffer - pbArgs := &ProtoArgs{ + pbArgs := &testutils.ProtoArgs{ A: 10, B: 20, } - pbReply := &ProtoReply{} + pbReply := &testutils.ProtoReply{} err = client.Call(context.Background(), "PBArith", "Mul", pbArgs, pbReply) if err != nil { t.Fatalf("failed to call: %v", err) diff --git a/errors/error.go b/errors/error.go new file mode 100644 index 0000000..26b7a10 --- /dev/null +++ b/errors/error.go @@ -0,0 +1,18 @@ +package errors + +import "fmt" + +// MultiError holds multiple errors +type MultiError struct { + Errors []error +} + +// Error returns the message of the actual error +func (e *MultiError) Error() string { + return fmt.Sprintf("%v", e.Errors) +} + +// NewMultiError creates and returns an Error with error splice +func NewMultiError(errors []error) *MultiError { + return &MultiError{Errors: errors} +} diff --git a/server/plugin.go b/server/plugin.go new file mode 100644 index 0000000..97a5b94 --- /dev/null +++ b/server/plugin.go @@ -0,0 +1,184 @@ +package server + +import ( + "context" + "net" + + "github.com/smallnest/rpcx/errors" + "github.com/smallnest/rpcx/protocol" +) + +//PluginContainer represents a plugin container that defines all methods to manage plugins. +//And it also defines all extension points. +type PluginContainer interface { + Add(plugin Plugin) + Remove(plugin Plugin) + All() []Plugin + + DoRegister(name string, rcvr interface{}, metadata string) error + + DoPostConnAccept(net.Conn) (net.Conn, bool) + + DoPreReadRequest(ctx context.Context) error + DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error + + DoPreWriteResponse(context.Context, *protocol.Message) error + DoPostWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error +} + +// Plugin is the server plugin interface. +type Plugin interface { +} + +type ( + // RegisterPlugin is . + RegisterPlugin interface { + Register(name string, rcvr interface{}, metadata string) error + } + + // PostConnAcceptPlugin represents connection accept plugin. + // if returns false, it means subsequent IPostConnAcceptPlugins should not contiune to handle this conn + // and this conn has been closed. + PostConnAcceptPlugin interface { + HandleConnAccept(net.Conn) (net.Conn, bool) + } + + //PreReadRequestPlugin represents . + PreReadRequestPlugin interface { + PreReadRequest(ctx context.Context) error + } + + //PostReadRequestPlugin represents . + PostReadRequestPlugin interface { + PostReadRequest(ctx context.Context, r *protocol.Message, e error) error + } + + //PreWriteResponsePlugin represents . + PreWriteResponsePlugin interface { + PreWriteResponse(context.Context, *protocol.Message) error + } + + //PostWriteResponsePlugin represents . + PostWriteResponsePlugin interface { + PostWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error + } +) + +// pluginContainer implements PluginContainer interface. +type pluginContainer struct { + plugins []Plugin +} + +// Add adds a plugin. +func (p *pluginContainer) Add(plugin Plugin) { + p.plugins = append(p.plugins, plugin) +} + +// Remove removes a plugin by it's name. +func (p *pluginContainer) Remove(plugin Plugin) { + if p.plugins == nil { + return + } + + var plugins []Plugin + for _, p := range p.plugins { + if p != plugin { + plugins = append(plugins, p) + } + } + + p.plugins = plugins +} + +func (p *pluginContainer) All() []Plugin { + return p.plugins +} + +// DoRegister invokes DoRegister plugin. +func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata string) error { + var es []error + for i := range p.plugins { + if plugin, ok := p.plugins[i].(RegisterPlugin); ok { + err := plugin.Register(name, rcvr, metadata) + if err != nil { + es = append(es, err) + } + } + } + + if len(es) > 0 { + return errors.NewMultiError(es) + } + return nil +} + +//DoPostConnAccept handles accepted conn +func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) { + var flag bool + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PostConnAcceptPlugin); ok { + conn, flag = plugin.HandleConnAccept(conn) + if !flag { //interrupt + conn.Close() + return conn, false + } + } + } + return conn, true +} + +// DoPreReadRequest invokes PreReadRequest plugin. +func (p *pluginContainer) DoPreReadRequest(ctx context.Context) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PreReadRequestPlugin); ok { + err := plugin.PreReadRequest(ctx) + if err != nil { + return err + } + } + } + + return nil +} + +// DoPostReadRequest invokes PostReadRequest plugin. +func (p *pluginContainer) DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PostReadRequestPlugin); ok { + err := plugin.PostReadRequest(ctx, r, e) + if err != nil { + return err + } + } + } + + return nil +} + +// DoPreWriteResponse invokes PreWriteResponse plugin. +func (p *pluginContainer) DoPreWriteResponse(ctx context.Context, req *protocol.Message) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PreWriteResponsePlugin); ok { + err := plugin.PreWriteResponse(ctx, req) + if err != nil { + return err + } + } + } + + return nil +} + +// DoPostWriteResponse invokes PostWriteResponse plugin. +func (p *pluginContainer) DoPostWriteResponse(ctx context.Context, req *protocol.Message, resp *protocol.Message, e error) error { + for i := range p.plugins { + if plugin, ok := p.plugins[i].(PostWriteResponsePlugin); ok { + err := plugin.PostWriteResponse(ctx, req, resp, e) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/server/server.go b/server/server.go index 409358e..4bf2092 100644 --- a/server/server.go +++ b/server/server.go @@ -68,6 +68,8 @@ type Server struct { // KCPConfig KCPConfig // // for QUIC // QUICConfig QUICConfig + + Plugins pluginContainer } // // KCPConfig is config of KCP. @@ -165,6 +167,11 @@ func (s *Server) serveListener(ln net.Listener) error { s.activeConn[conn] = struct{}{} s.mu.Unlock() + conn, ok := s.Plugins.DoPostConnAccept(conn) + if !ok { + continue + } + go s.serveConn(conn) } } @@ -237,18 +244,22 @@ func (s *Server) serveConn(conn net.Conn) { } go func() { + s.Plugins.DoPreWriteResponse(ctx, req) res, err := s.handleRequest(ctx, req) if err != nil { log.Errorf("rpcx: failed to handle request: %v", err) } res.WriteTo(w) w.Flush() + s.Plugins.DoPostWriteResponse(ctx, req, res, err) }() } } func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) { + s.Plugins.DoPreReadRequest(ctx) req, err = protocol.Read(r) + s.Plugins.DoPostReadRequest(ctx, req, err) return req, err } @@ -349,7 +360,10 @@ func (s *Server) Close() error { s.mu.Lock() defer s.mu.Unlock() s.closeDoneChanLocked() - err := s.ln.Close() + var err error + if s.ln != nil { + err = s.ln.Close() + } for c := range s.activeConn { c.Close() diff --git a/server/server_test.go b/server/server_test.go index 606fcbe..2efc05f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -54,7 +54,7 @@ func TestHandleRequest(t *testing.T) { req.Payload = data server := &Server{} - server.RegisterName("Arith", new(Arith)) + server.RegisterName("Arith", new(Arith), "") res, err := server.handleRequest(context.Background(), req) if err != nil { t.Fatalf("failed to hand request: %v", err) diff --git a/server/service.go b/server/service.go index 91edb90..5187ba1 100644 --- a/server/service.go +++ b/server/service.go @@ -57,13 +57,14 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // no suitable methods. It also logs the error. // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. -func (s *Server) Register(rcvr interface{}) error { +func (s *Server) Register(rcvr interface{}, metadata string) error { return s.register(rcvr, "", false) } // RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. -func (s *Server) RegisterName(name string, rcvr interface{}) error { +func (s *Server) RegisterName(name string, rcvr interface{}, metadata string) error { + s.Plugins.DoRegister(name, rcvr, metadata) return s.register(rcvr, name, true) } diff --git a/serverplugin/consul.go b/serverplugin/consul.go new file mode 100644 index 0000000..a08e251 --- /dev/null +++ b/serverplugin/consul.go @@ -0,0 +1,138 @@ +// +build consul + +package serverplugin + +import ( + "errors" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/docker/libkv" + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/consul" + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/log" +) + +// ConsulRegisterPlugin implements consul registry. +type ConsulRegisterPlugin struct { + // service address, for example, tcp@127.0.0.1:8972, quic@127.0.0.1:1234 + ServiceAddress string + // consul addresses + ConsulServers []string + // base path for rpcx server, for example com/example/rpcx + BasePath string + Metrics metrics.Registry + // Registered services + Services []string + UpdateInterval time.Duration + + kv store.Store +} + +// Start starts to connect consul cluster +func (p *ConsulRegisterPlugin) Start() error { + consul.Register() + kv, err := libkv.NewStore(store.CONSUL, p.ConsulServers, nil) + if err != nil { + log.Errorf("cannot create consul registry: %v", err) + return err + } + p.kv = kv + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + + err = kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create consul path %s: %v", p.BasePath, err) + return err + } + + if p.UpdateInterval > 0 { + ticker := time.NewTicker(p.UpdateInterval) + go func() { + defer p.kv.Close() + + // refresh service TTL + for range ticker.C { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + data := []byte(strconv.FormatInt(clientMeter.Count()/60, 10)) + //set this same metrics for all services at this server + for _, name := range p.Services { + nodePath := fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + kvPaire, err := p.kv.Get(nodePath) + if err != nil { + log.Infof("can't get data of node: %s, because of %v", nodePath, err.Error()) + } else { + v, _ := url.ParseQuery(string(kvPaire.Value)) + v.Set("tps", string(data)) + p.kv.Put(nodePath, []byte(v.Encode()), &store.WriteOptions{TTL: p.UpdateInterval * 2}) + } + } + + } + }() + } + + return nil +} + +// HandleConnAccept handles connections from clients +func (p *ConsulRegisterPlugin) HandleConnAccept(conn net.Conn) (net.Conn, bool) { + if p.Metrics != nil { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + clientMeter.Mark(1) + } + return conn, true +} + +// Register handles registering event. +// this service is registered at BASE/serviceName/thisIpAddress node +func (p *ConsulRegisterPlugin) Register(name string, rcvr interface{}, metadata ...string) (err error) { + if "" == strings.TrimSpace(name) { + err = errors.New("Register service `name` can't be empty") + return + } + + if p.kv == nil { + consul.Register() + kv, err := libkv.NewStore(store.CONSUL, p.ConsulServers, nil) + if err != nil { + log.Errorf("cannot create consul registry: %v", err) + return err + } + p.kv = kv + } + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + err = p.kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create consul path %s: %v", p.BasePath, err) + return err + } + + nodePath := fmt.Sprintf("%s/%s", p.BasePath, name) + err = p.kv.Put(nodePath, []byte(name), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create consul path %s: %v", nodePath, err) + return err + } + + nodePath = fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + err = p.kv.Put(nodePath, []byte(p.ServiceAddress), &store.WriteOptions{TTL: p.UpdateInterval * 2}) + if err != nil { + log.Errorf("cannot create consul path %s: %v", nodePath, err) + return err + } + + p.Services = append(p.Services, name) + return +} diff --git a/serverplugin/consul_test.go b/serverplugin/consul_test.go new file mode 100644 index 0000000..55c5663 --- /dev/null +++ b/serverplugin/consul_test.go @@ -0,0 +1,38 @@ +// +build consul + +package serverplugin + +import ( + "testing" + "time" + + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/server" +) + +func TestConsulRegistry(t *testing.T) { + s := &server.Server{} + + r := &ConsulRegisterPlugin{ + ServiceAddress: "tcp@127.0.0.1:8972", + ConsulServers: []string{"127.0.0.1:8500"}, + BasePath: "/rpcx_test", + Metrics: metrics.NewRegistry(), + Services: make([]string, 1), + UpdateInterval: time.Minute, + } + err := r.Start() + if err != nil { + t.Fatal(err) + } + s.Plugins.Add(r) + + s.RegisterName("Arith", new(Arith), "") + go s.Serve("tcp", "127.0.0.1:8972") + defer s.Close() + + if len(r.Services) != 1 { + t.Fatal("failed to register services in consul") + } + +} diff --git a/serverplugin/etcd.go b/serverplugin/etcd.go new file mode 100644 index 0000000..a836997 --- /dev/null +++ b/serverplugin/etcd.go @@ -0,0 +1,138 @@ +// +build etcd + +package serverplugin + +import ( + "errors" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/docker/libkv" + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/etcd" + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/log" +) + +// EtcdRegisterPlugin implements etcd registry. +type EtcdRegisterPlugin struct { + // service address, for example, tcp@127.0.0.1:8972, quic@127.0.0.1:1234 + ServiceAddress string + // etcd addresses + EtcdServers []string + // base path for rpcx server, for example com/example/rpcx + BasePath string + Metrics metrics.Registry + // Registered services + Services []string + UpdateInterval time.Duration + + kv store.Store +} + +// Start starts to connect etcd cluster +func (p *EtcdRegisterPlugin) Start() error { + etcd.Register() + kv, err := libkv.NewStore(store.ETCD, p.EtcdServers, nil) + if err != nil { + log.Errorf("cannot create etcd registry: %v", err) + return err + } + p.kv = kv + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + + err = kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil && !strings.Contains(err.Error(), "Not a file") { + log.Errorf("cannot create etcd path %s: %v", p.BasePath, err) + return err + } + + if p.UpdateInterval > 0 { + ticker := time.NewTicker(p.UpdateInterval) + go func() { + defer p.kv.Close() + + // refresh service TTL + for range ticker.C { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + data := []byte(strconv.FormatInt(clientMeter.Count()/60, 10)) + //set this same metrics for all services at this server + for _, name := range p.Services { + nodePath := fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + kvPaire, err := p.kv.Get(nodePath) + if err != nil { + log.Infof("can't get data of node: %s, because of %v", nodePath, err.Error()) + } else { + v, _ := url.ParseQuery(string(kvPaire.Value)) + v.Set("tps", string(data)) + p.kv.Put(nodePath, []byte(v.Encode()), &store.WriteOptions{TTL: p.UpdateInterval * 2}) + } + } + + } + }() + } + + return nil +} + +// HandleConnAccept handles connections from clients +func (p *EtcdRegisterPlugin) HandleConnAccept(conn net.Conn) (net.Conn, bool) { + if p.Metrics != nil { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + clientMeter.Mark(1) + } + return conn, true +} + +// Register handles registering event. +// this service is registered at BASE/serviceName/thisIpAddress node +func (p *EtcdRegisterPlugin) Register(name string, rcvr interface{}, metadata ...string) (err error) { + if "" == strings.TrimSpace(name) { + err = errors.New("Register service `name` can't be empty") + return + } + + if p.kv == nil { + etcd.Register() + kv, err := libkv.NewStore(store.ETCD, p.EtcdServers, nil) + if err != nil { + log.Errorf("cannot create etcd registry: %v", err) + return err + } + p.kv = kv + } + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + err = p.kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil && !strings.Contains(err.Error(), "Not a file") { + log.Errorf("cannot create etcd path %s: %v", p.BasePath, err) + return err + } + + nodePath := fmt.Sprintf("%s/%s", p.BasePath, name) + err = p.kv.Put(nodePath, []byte(name), &store.WriteOptions{IsDir: true}) + if err != nil && !strings.Contains(err.Error(), "Not a file") { + log.Errorf("cannot create etcd path %s: %v", nodePath, err) + return err + } + + nodePath = fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + err = p.kv.Put(nodePath, []byte(p.ServiceAddress), &store.WriteOptions{TTL: p.UpdateInterval * 2}) + if err != nil { + log.Errorf("cannot create etcd path %s: %v", nodePath, err) + return err + } + + p.Services = append(p.Services, name) + return +} diff --git a/serverplugin/etcd_test.go b/serverplugin/etcd_test.go new file mode 100644 index 0000000..6e30962 --- /dev/null +++ b/serverplugin/etcd_test.go @@ -0,0 +1,38 @@ +// +build etcd + +package serverplugin + +import ( + "testing" + "time" + + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/server" +) + +func TestEtcdRegistry(t *testing.T) { + s := &server.Server{} + + r := &EtcdRegisterPlugin{ + ServiceAddress: "tcp@127.0.0.1:8972", + EtcdServers: []string{"127.0.0.1:2379"}, + BasePath: "/rpcx_test", + Metrics: metrics.NewRegistry(), + Services: make([]string, 1), + UpdateInterval: time.Minute, + } + err := r.Start() + if err != nil { + t.Fatal(err) + } + s.Plugins.Add(r) + + s.RegisterName("Arith", new(Arith), "") + go s.Serve("tcp", "127.0.0.1:8972") + defer s.Close() + + if len(r.Services) != 1 { + t.Fatal("failed to register services in etcd") + } + +} diff --git a/serverplugin/plugin.go b/serverplugin/plugin.go new file mode 100644 index 0000000..3849daf --- /dev/null +++ b/serverplugin/plugin.go @@ -0,0 +1 @@ +package serverplugin diff --git a/serverplugin/registry_test.go b/serverplugin/registry_test.go new file mode 100644 index 0000000..eed536b --- /dev/null +++ b/serverplugin/registry_test.go @@ -0,0 +1,19 @@ +package serverplugin + +import "context" + +type Args struct { + A int + B int +} + +type Reply struct { + C int +} + +type Arith int + +func (t *Arith) Mul(ctx context.Context, args *Args, reply *Reply) error { + reply.C = args.A * args.B + return nil +} diff --git a/serverplugin/zookeeper.go b/serverplugin/zookeeper.go new file mode 100644 index 0000000..346f09e --- /dev/null +++ b/serverplugin/zookeeper.go @@ -0,0 +1,139 @@ +// +build zookeeper + +package serverplugin + +import ( + "errors" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/docker/libkv" + "github.com/docker/libkv/store/zookeeper" + + "github.com/docker/libkv/store" + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/log" +) + +// ZooKeeperRegisterPlugin implements zookeeper registry. +type ZooKeeperRegisterPlugin struct { + // service address, for example, tcp@127.0.0.1:8972, quic@127.0.0.1:1234 + ServiceAddress string + // zookeeper addresses + ZooKeeperServers []string + // base path for rpcx server, for example com/example/rpcx + BasePath string + Metrics metrics.Registry + // Registered services + Services []string + UpdateInterval time.Duration + + kv store.Store +} + +// Start starts to connect zookeeper cluster +func (p *ZooKeeperRegisterPlugin) Start() error { + zookeeper.Register() + kv, err := libkv.NewStore(store.ZK, p.ZooKeeperServers, nil) + if err != nil { + log.Errorf("cannot create zk registry: %v", err) + return err + } + p.kv = kv + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + + err = p.kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create zk path %s: %v", p.BasePath, err) + return err + } + + if p.UpdateInterval > 0 { + ticker := time.NewTicker(p.UpdateInterval) + go func() { + defer p.kv.Close() + + // refresh service TTL + for range ticker.C { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + data := []byte(strconv.FormatInt(clientMeter.Count()/60, 10)) + //set this same metrics for all services at this server + for _, name := range p.Services { + nodePath := fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + kvPaire, err := p.kv.Get(nodePath) + if err != nil { + log.Infof("can't get data of node: %s, because of %v", nodePath, err.Error()) + } else { + v, _ := url.ParseQuery(string(kvPaire.Value)) + v.Set("tps", string(data)) + p.kv.Put(nodePath, []byte(v.Encode()), &store.WriteOptions{TTL: p.UpdateInterval * 2}) + } + } + + } + }() + } + + return nil +} + +// HandleConnAccept handles connections from clients +func (p *ZooKeeperRegisterPlugin) HandleConnAccept(conn net.Conn) (net.Conn, bool) { + if p.Metrics != nil { + clientMeter := metrics.GetOrRegisterMeter("clientMeter", p.Metrics) + clientMeter.Mark(1) + } + return conn, true +} + +// Register handles registering event. +// this service is registered at BASE/serviceName/thisIpAddress node +func (p *ZooKeeperRegisterPlugin) Register(name string, rcvr interface{}, metadata ...string) (err error) { + if "" == strings.TrimSpace(name) { + err = errors.New("Register service `name` can't be empty") + return + } + + if p.kv == nil { + zookeeper.Register() + kv, err := libkv.NewStore(store.ZK, p.ZooKeeperServers, nil) + if err != nil { + log.Errorf("cannot create zk registry: %v", err) + return err + } + p.kv = kv + } + + if p.BasePath[0] == '/' { + p.BasePath = p.BasePath[1:] + } + err = p.kv.Put(p.BasePath, []byte("rpcx_path"), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create zk path %s: %v", p.BasePath, err) + return err + } + + nodePath := fmt.Sprintf("%s/%s", p.BasePath, name) + err = p.kv.Put(nodePath, []byte(name), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create zk path %s: %v", nodePath, err) + return err + } + + nodePath = fmt.Sprintf("%s/%s/%s", p.BasePath, name, p.ServiceAddress) + err = p.kv.Put(nodePath, []byte(p.ServiceAddress), &store.WriteOptions{IsDir: true}) + if err != nil { + log.Errorf("cannot create zk path %s: %v", nodePath, err) + return err + } + + p.Services = append(p.Services, name) + return +} diff --git a/serverplugin/zookeeper_test.go b/serverplugin/zookeeper_test.go new file mode 100644 index 0000000..99cce13 --- /dev/null +++ b/serverplugin/zookeeper_test.go @@ -0,0 +1,38 @@ +// +build zookeeper + +package serverplugin + +import ( + "testing" + "time" + + metrics "github.com/rcrowley/go-metrics" + "github.com/smallnest/rpcx/server" +) + +func TestZookeeperRegistry(t *testing.T) { + s := &server.Server{} + + r := &ZooKeeperRegisterPlugin{ + ServiceAddress: "tcp@127.0.0.1:8972", + ZooKeeperServers: []string{"127.0.0.1:2181"}, + BasePath: "/rpcx_test", + Metrics: metrics.NewRegistry(), + Services: make([]string, 1), + UpdateInterval: time.Minute, + } + err := r.Start() + if err != nil { + t.Fatal(err) + } + s.Plugins.Add(r) + + s.RegisterName("Arith", new(Arith), "") + go s.Serve("tcp", "127.0.0.1:8972") + defer s.Close() + + if len(r.Services) != 1 { + t.Fatal("failed to register services in zookeeper") + } + +}