package server import ( "bufio" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "reflect" "regexp" "runtime" "strings" "sync" "sync/atomic" "time" "github.com/smallnest/rpcx/log" "github.com/smallnest/rpcx/protocol" "github.com/smallnest/rpcx/share" "os" "os/signal" "syscall" ) // ErrServerClosed is returned by the Server's Serve, ListenAndServe after a call to Shutdown or Close. var ErrServerClosed = errors.New("http: Server closed") const ( // ReaderBuffsize is used for bufio reader. ReaderBuffsize = 1024 // WriterBuffsize is used for bufio writer. WriterBuffsize = 1024 ) // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { name string } func (k *contextKey) String() string { return "rpcx context value " + k.name } var ( // RemoteConnContextKey is a context key. It can be used in // services with context.WithValue to access the connection arrived on. // The associated value will be of type net.Conn. RemoteConnContextKey = &contextKey{"remote-conn"} // StartRequestContextKey records the start time StartRequestContextKey = &contextKey{"start-parse-request"} // StartSendRequestContextKey records the start time StartSendRequestContextKey = &contextKey{"start-send-request"} ) // Server is rpcx server that use TCP or UDP. type Server struct { ln net.Listener readTimeout time.Duration writeTimeout time.Duration serviceMapMu sync.RWMutex serviceMap map[string]*service mu sync.RWMutex activeConn map[net.Conn]struct{} doneChan chan struct{} seq uint64 inShutdown int32 onShutdown []func() // TLSConfig for creating tls tcp connection. tlsConfig *tls.Config // BlockCrypt for kcp.BlockCrypt options map[string]interface{} // // use for KCP // KCPConfig KCPConfig // // for QUIC // QUICConfig QUICConfig Plugins PluginContainer // AuthFunc can be used to auth. AuthFunc func(ctx context.Context, req *protocol.Message, token string) error ShutdownFunc func(s *Server) HandleMsgChan chan struct{} } // NewServer returns a server. func NewServer(options ...OptionFn) *Server { s := &Server{ Plugins: &pluginContainer{}, options: make(map[string]interface{}), } s.HandleMsgChan = make(chan struct{}, 100000) for _, op := range options { op(s) } return s } // Address returns listened address. func (s *Server) Address() net.Addr { s.mu.RLock() defer s.mu.RUnlock() if s.ln == nil { return nil } return s.ln.Addr() } // SendMessage a request to the specified client. // The client is designated by the conn. // conn can be gotten from context in services: // // ctx.Value(RemoteConnContextKey) // // servicePath, serviceMethod, metadata can be set to zero values. func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error { ctx := context.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano()) s.Plugins.DoPreWriteRequest(ctx) req := protocol.GetPooledMsg() req.SetMessageType(protocol.Request) seq := atomic.AddUint64(&s.seq, 1) req.SetSeq(seq) req.SetOneway(true) req.SetSerializeType(protocol.SerializeNone) req.ServicePath = servicePath req.ServiceMethod = serviceMethod req.Metadata = metadata req.Payload = data reqData := req.Encode() _, err := conn.Write(reqData) s.Plugins.DoPostWriteRequest(ctx, req, err) protocol.FreeMsg(req) return err } func (s *Server) getDoneChan() <-chan struct{} { s.mu.Lock() defer s.mu.Unlock() if s.doneChan == nil { s.doneChan = make(chan struct{}) } return s.doneChan } func (s *Server)startShutdownListener() { go func(s *Server) { log.Info("server pid:", os.Getpid()) c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGTERM) si := <-c if si.String() == "terminated" { if nil != s.ShutdownFunc { s.ShutdownFunc(s) } os.Exit(0) } }(s) } // Serve starts and listens RPC requests. // It is blocked until receiving connectings from clients. func (s *Server) Serve(network, address string) (err error) { s.startShutdownListener() var ln net.Listener ln, err = s.makeListener(network, address) if err != nil { return } if network == "http" { s.serveByHTTP(ln, "") return nil } // try to start gateway ln = s.startGateway(network, ln) return s.serveListener(ln) } // serveListener accepts incoming connections on the Listener ln, // creating a new service goroutine for each. // The service goroutines read requests and then call services to reply to them. func (s *Server) serveListener(ln net.Listener) error { if s.Plugins == nil { s.Plugins = &pluginContainer{} } var tempDelay time.Duration s.mu.Lock() s.ln = ln if s.activeConn == nil { s.activeConn = make(map[net.Conn]struct{}) } s.mu.Unlock() for { conn, e := ln.Accept() if e != nil { select { case <-s.getDoneChan(): return ErrServerClosed default: } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay > max { tempDelay = max } log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay) time.Sleep(tempDelay) continue } return e } tempDelay = 0 if tc, ok := conn.(*net.TCPConn); ok { tc.SetKeepAlive(true) tc.SetKeepAlivePeriod(3 * time.Minute) tc.SetLinger(10) } s.mu.Lock() s.activeConn[conn] = struct{}{} s.mu.Unlock() conn, ok := s.Plugins.DoPostConnAccept(conn) if !ok { continue } go s.serveConn(conn) } } // serveByHTTP serves by HTTP. // if rpcPath is an empty string, use share.DefaultRPCPath. func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) { s.ln = ln if s.Plugins == nil { s.Plugins = &pluginContainer{} } if rpcPath == "" { rpcPath = share.DefaultRPCPath } http.Handle(rpcPath, s) srv := &http.Server{Handler: nil} s.mu.Lock() if s.activeConn == nil { s.activeConn = make(map[net.Conn]struct{}) } s.mu.Unlock() srv.Serve(ln) } func (s *Server) serveConn(conn net.Conn) { defer func() { if err := recover(); err != nil { const size = 64 << 10 buf := make([]byte, size) ss := runtime.Stack(buf, false) if ss > size { ss = size } buf = buf[:ss] log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf) } s.mu.Lock() delete(s.activeConn, conn) s.mu.Unlock() conn.Close() if s.Plugins == nil { s.Plugins = &pluginContainer{} } s.Plugins.DoPostConnClose(conn) }() if isShutdown(s) { closeChannel(s,conn) return } if tlsConn, ok := conn.(*tls.Conn); ok { if d := s.readTimeout; d != 0 { conn.SetReadDeadline(time.Now().Add(d)) } if d := s.writeTimeout; d != 0 { conn.SetWriteDeadline(time.Now().Add(d)) } if err := tlsConn.Handshake(); err != nil { log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err) return } } r := bufio.NewReaderSize(conn, ReaderBuffsize) for { if isShutdown(s) { closeChannel(s,conn) return } t0 := time.Now() if s.readTimeout != 0 { conn.SetReadDeadline(t0.Add(s.readTimeout)) } ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn) req, err := s.readRequest(ctx, r) if err != nil { if err == io.EOF { log.Infof("client has closed this connection: %s", conn.RemoteAddr().String()) } else if strings.Contains(err.Error(), "use of closed network connection") { log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String()) } else { log.Warnf("rpcx: failed to read request: %v", err) } return } s.HandleMsgChan <- struct{}{} if s.writeTimeout != 0 { conn.SetWriteDeadline(t0.Add(s.writeTimeout)) } ctx = context.WithValue(ctx, StartRequestContextKey, time.Now().UnixNano()) err = s.auth(ctx, req) if err != nil { if !req.IsOneway() { res := req.Clone() res.SetMessageType(protocol.Response) if len(res.Payload) > 1024 && req.CompressType() != protocol.None { res.SetCompressType(req.CompressType()) } handleError(res, err) data := res.Encode() s.Plugins.DoPreWriteResponse(ctx, req, res) conn.Write(data) s.Plugins.DoPostWriteResponse(ctx, req, res, err) protocol.FreeMsg(res) } else { s.Plugins.DoPreWriteResponse(ctx, req, nil) } <-s.HandleMsgChan protocol.FreeMsg(req) continue } go func() { defer func(){ <-s.HandleMsgChan }() if req.IsHeartbeat() { req.SetMessageType(protocol.Response) data := req.Encode() conn.Write(data) return } resMetadata := make(map[string]string) newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata), share.ResMetaDataKey, resMetadata) res, err := s.handleRequest(newCtx, req) if err != nil { log.Warnf("rpcx: failed to handle request: %v", err) } s.Plugins.DoPreWriteResponse(newCtx, req, res) if !req.IsOneway() { if len(resMetadata) > 0 { //copy meta in context to request meta := res.Metadata if meta == nil { res.Metadata = resMetadata } else { for k, v := range resMetadata { meta[k] = v } } } if len(res.Payload) > 1024 && req.CompressType() != protocol.None { res.SetCompressType(req.CompressType()) } data := res.Encode() conn.Write(data) //res.WriteTo(conn) } s.Plugins.DoPostWriteResponse(newCtx, req, res, err) protocol.FreeMsg(req) protocol.FreeMsg(res) }() } } func isShutdown(s *Server) (bool) { return atomic.LoadInt32(&s.inShutdown) == 1 } func closeChannel(s *Server,conn net.Conn) { s.mu.Lock() delete(s.activeConn, conn) s.mu.Unlock() conn.Close() } func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) { err = s.Plugins.DoPreReadRequest(ctx) if err != nil { return nil, err } // pool req? req = protocol.GetPooledMsg() err = req.Decode(r) perr := s.Plugins.DoPostReadRequest(ctx, req, err) if err == nil { err = perr } return req, err } func (s *Server) auth(ctx context.Context, req *protocol.Message) error { if s.AuthFunc != nil { token := req.Metadata[share.AuthKey] return s.AuthFunc(ctx, req, token) } return nil } func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) { serviceName := req.ServicePath methodName := req.ServiceMethod res = req.Clone() res.SetMessageType(protocol.Response) s.serviceMapMu.RLock() service := s.serviceMap[serviceName] s.serviceMapMu.RUnlock() if service == nil { err = errors.New("rpcx: can't find service " + serviceName) return handleError(res, err) } mtype := service.method[methodName] if mtype == nil { if service.function[methodName] != nil { //check raw functions return s.handleRequestForFunction(ctx, req) } err = errors.New("rpcx: can't find method " + methodName) return handleError(res, err) } var argv = argsReplyPools.Get(mtype.ArgType) codec := share.Codecs[req.SerializeType()] if codec == nil { err = fmt.Errorf("can not find codec for %d", req.SerializeType()) return handleError(res, err) } err = codec.Decode(req.Payload, argv) if err != nil { return handleError(res, err) } replyv := argsReplyPools.Get(mtype.ReplyType) if mtype.ArgType.Kind() != reflect.Ptr { err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv)) } else { err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv)) } argsReplyPools.Put(mtype.ArgType, argv) if err != nil { argsReplyPools.Put(mtype.ReplyType, replyv) return handleError(res, err) } if !req.IsOneway() { data, err := codec.Encode(replyv) argsReplyPools.Put(mtype.ReplyType, replyv) if err != nil { return handleError(res, err) } res.Payload = data } return res, nil } func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) { res = req.Clone() res.SetMessageType(protocol.Response) serviceName := req.ServicePath methodName := req.ServiceMethod s.serviceMapMu.RLock() service := s.serviceMap[serviceName] s.serviceMapMu.RUnlock() if service == nil { err = errors.New("rpcx: can't find service for func raw function") return handleError(res, err) } mtype := service.function[methodName] if mtype == nil { err = errors.New("rpcx: can't find method " + methodName) return handleError(res, err) } var argv = argsReplyPools.Get(mtype.ArgType) codec := share.Codecs[req.SerializeType()] if codec == nil { err = fmt.Errorf("can not find codec for %d", req.SerializeType()) return handleError(res, err) } err = codec.Decode(req.Payload, argv) if err != nil { return handleError(res, err) } replyv := argsReplyPools.Get(mtype.ReplyType) err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv)) argsReplyPools.Put(mtype.ArgType, argv) if err != nil { argsReplyPools.Put(mtype.ReplyType, replyv) return handleError(res, err) } if !req.IsOneway() { data, err := codec.Encode(replyv) argsReplyPools.Put(mtype.ReplyType, replyv) if err != nil { return handleError(res, err) } res.Payload = data } return res, nil } func handleError(res *protocol.Message, err error) (*protocol.Message, error) { res.SetMessageStatusType(protocol.Error) if res.Metadata == nil { res.Metadata = make(map[string]string) } res.Metadata[protocol.ServiceError] = err.Error() return res, err } // Can connect to RPC service using HTTP CONNECT to rpcPath. var connected = "200 Connected to rpcx" // ServeHTTP implements an http.Handler that answers RPC requests. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "CONNECT" { w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusMethodNotAllowed) io.WriteString(w, "405 must CONNECT\n") return } conn, _, err := w.(http.Hijacker).Hijack() if err != nil { log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) return } io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") s.mu.Lock() s.activeConn[conn] = struct{}{} s.mu.Unlock() s.serveConn(conn) } // Close immediately closes all active net.Listeners. func (s *Server) Close() error { s.closeDoneChanLocked() var err error if s.ln != nil { err = s.ln.Close() } for c := range s.activeConn { c.Close() delete(s.activeConn, c) s.Plugins.DoPostConnClose(c) } return err } // RegisterOnShutdown registers a function to call on Shutdown. // This can be used to gracefully shutdown connections. func (s *Server) RegisterOnShutdown(f func()) { s.mu.Lock() s.onShutdown = append(s.onShutdown, f) s.mu.Unlock() } var shutdownPollInterval = 500 * time.Millisecond // // Shutdown gracefully shuts down the server without interrupting any // // active connections. Shutdown works by first closing the // // listener, then closing all idle connections, and then waiting // // indefinitely for connections to return to idle and then shut down. // // If the provided context expires before the shutdown is complete, // // Shutdown returns the context's error, otherwise it returns any // // error returned from closing the Server's underlying Listener. func (s *Server) Shutdown(ctx context.Context) error { if atomic.CompareAndSwapInt32(&s.inShutdown, 0, 1) { log.Info("shutdown begin") ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() for { if s.checkProcessMsg() { break } select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } s.Close() log.Info("shutdown end") } return nil } func (s *Server) checkProcessMsg() (bool) { size := len(s.HandleMsgChan) log.Info("need handle msg size:",size) if size == 0 { return true } return false } func (s *Server) closeDoneChanLocked() { ch := s.getDoneChanLocked() select { case <-ch: // Already closed. Don't close again. default: // Safe to close here. We're the only closer, guarded // by s.mu. close(ch) } } func (s *Server) getDoneChanLocked() chan struct{} { if s.doneChan == nil { s.doneChan = make(chan struct{}) } return s.doneChan } var ip4Reg = regexp.MustCompile(`^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$`) func validIP4(ipAddress string) bool { ipAddress = strings.Trim(ipAddress, " ") i := strings.LastIndex(ipAddress, ":") ipAddress = ipAddress[:i] //remove port return ip4Reg.MatchString(ipAddress) }