package server

import (
	"context"
	"net"

	"github.com/smallnest/rpcx/errors"
	"github.com/smallnest/rpcx/protocol"
)

// PluginContainer manages a set of plugins and exposes one Do* method per
// extension point. Each Do* method dispatches to the registered plugins that
// implement the matching plugin interface.
type PluginContainer interface {
	Add(plugin Plugin)
	Remove(plugin Plugin)
	All() []Plugin

	DoRegister(name string, rcvr interface{}, metadata string) error

	DoPostConnAccept(conn net.Conn) (net.Conn, bool)

	DoPreReadRequest(ctx context.Context) error
	DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error

	DoPreWriteResponse(ctx context.Context, req *protocol.Message) error
	DoPostWriteResponse(ctx context.Context, req *protocol.Message, resp *protocol.Message, e error) error
}

// Plugin is the server plugin interface. It has no methods: a plugin takes
// part in an extension point by implementing one or more of the *Plugin
// interfaces declared below, which the container detects by type assertion.
type Plugin interface{}

// RegisterPlugin is implemented by plugins that want to be called from
// PluginContainer.DoRegister with the name, receiver and metadata passed to it.
type RegisterPlugin interface {
	Register(name string, rcvr interface{}, metadata string) error
}

// PostConnAcceptPlugin is implemented by plugins that handle an accepted
// connection. HandleConnAccept returns the connection to hand to the next
// plugin. If it returns false, the remaining PostConnAcceptPlugins do not
// continue to handle this connection and the connection is closed.
type PostConnAcceptPlugin interface {
	HandleConnAccept(conn net.Conn) (net.Conn, bool)
}

// PreReadRequestPlugin is implemented by plugins that want to be called from
// PluginContainer.DoPreReadRequest.
type PreReadRequestPlugin interface {
	PreReadRequest(ctx context.Context) error
}

// PostReadRequestPlugin is implemented by plugins that want to be called from
// PluginContainer.DoPostReadRequest with the message r and the error e.
type PostReadRequestPlugin interface {
	PostReadRequest(ctx context.Context, r *protocol.Message, e error) error
}

// PreWriteResponsePlugin is implemented by plugins that want to be called from
// PluginContainer.DoPreWriteResponse.
type PreWriteResponsePlugin interface {
	PreWriteResponse(ctx context.Context, req *protocol.Message) error
}

// PostWriteResponsePlugin is implemented by plugins that want to be called
// from PluginContainer.DoPostWriteResponse with the request, the response and
// the error e.
type PostWriteResponsePlugin interface {
	PostWriteResponse(ctx context.Context, req *protocol.Message, resp *protocol.Message, e error) error
}

// pluginContainer implements the PluginContainer interface on top of a plain
// slice; plugins are kept in the order in which they were added.
type pluginContainer struct {
	plugins []Plugin
}

// Add appends plugin to the end of the container's plugin list.
func (p *pluginContainer) Add(plugin Plugin) {
	p.plugins = append(p.plugins, plugin)
}

// Remove removes a plugin from the container.
func (p *pluginContainer) Remove(plugin Plugin) { if p.plugins == nil { return } var plugins []Plugin for _, p := range p.plugins { if p != plugin { plugins = append(plugins, p) } } p.plugins = plugins } func (p *pluginContainer) All() []Plugin { return p.plugins } // DoRegister invokes DoRegister plugin. func (p *pluginContainer) DoRegister(name string, rcvr interface{}, metadata string) error { var es []error for _, rp := range p.plugins { if plugin, ok := rp.(RegisterPlugin); ok { err := plugin.Register(name, rcvr, metadata) if err != nil { es = append(es, err) } } } if len(es) > 0 { return errors.NewMultiError(es) } return nil } //DoPostConnAccept handles accepted conn func (p *pluginContainer) DoPostConnAccept(conn net.Conn) (net.Conn, bool) { var flag bool for i := range p.plugins { if plugin, ok := p.plugins[i].(PostConnAcceptPlugin); ok { conn, flag = plugin.HandleConnAccept(conn) if !flag { //interrupt conn.Close() return conn, false } } } return conn, true } // DoPreReadRequest invokes PreReadRequest plugin. func (p *pluginContainer) DoPreReadRequest(ctx context.Context) error { for i := range p.plugins { if plugin, ok := p.plugins[i].(PreReadRequestPlugin); ok { err := plugin.PreReadRequest(ctx) if err != nil { return err } } } return nil } // DoPostReadRequest invokes PostReadRequest plugin. func (p *pluginContainer) DoPostReadRequest(ctx context.Context, r *protocol.Message, e error) error { for i := range p.plugins { if plugin, ok := p.plugins[i].(PostReadRequestPlugin); ok { err := plugin.PostReadRequest(ctx, r, e) if err != nil { return err } } } return nil } // DoPreWriteResponse invokes PreWriteResponse plugin. func (p *pluginContainer) DoPreWriteResponse(ctx context.Context, req *protocol.Message) error { for i := range p.plugins { if plugin, ok := p.plugins[i].(PreWriteResponsePlugin); ok { err := plugin.PreWriteResponse(ctx, req) if err != nil { return err } } } return nil } // DoPostWriteResponse invokes PostWriteResponse plugin. 
func (p *pluginContainer) DoPostWriteResponse(ctx context.Context, req *protocol.Message, resp *protocol.Message, e error) error { for i := range p.plugins { if plugin, ok := p.plugins[i].(PostWriteResponsePlugin); ok { err := plugin.PostWriteResponse(ctx, req, resp, e) if err != nil { return err } } } return nil }