Files
rpcx/server/server.go
2025-03-06 13:06:33 +08:00

1099 lines
26 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
"github.com/soheilhy/cmux"
"golang.org/x/net/websocket"
)
// ErrServerClosed is returned by the Server's Serve and ListenAndServe methods
// after a call to Shutdown or Close.
var ErrServerClosed = errors.New("http: Server closed")

// ErrReqReachLimit signals that a request exceeded a rate limit. When the
// request could be decoded, serveConn answers it with an error response and
// keeps the connection open.
var ErrReqReachLimit = errors.New("request reached rate limit")
// Buffer sizes used for buffered I/O on client connections.
const (
	// ReaderBuffsize is the size of the bufio.Reader that wraps each client
	// connection in serveConn.
	ReaderBuffsize = 1 << 10
	// WriterBuffsize is the size intended for a bufio.Writer.
	WriterBuffsize = 1 << 10
)
// contextKey is the type of the keys this package stores in contexts. Keys are
// handled as pointers so they fit in an interface{} without allocation, and two
// keys compare equal only when they are the same pointer.
type contextKey struct {
	name string
}

// String implements fmt.Stringer so keys are recognizable in debug output.
func (k *contextKey) String() string {
	const prefix = "rpcx context value "
	return prefix + k.name
}

var (
	// RemoteConnContextKey lets services read the connection a request arrived
	// on via context.Value. The associated value is of type net.Conn.
	RemoteConnContextKey = &contextKey{name: "remote-conn"}
	// StartRequestContextKey holds the time (Unix nanoseconds) at which a
	// request finished being read from the connection.
	StartRequestContextKey = &contextKey{name: "start-parse-request"}
	// StartSendRequestContextKey holds the time (Unix nanoseconds) at which
	// a server-initiated message started being sent.
	StartSendRequestContextKey = &contextKey{name: "start-send-request"}
	// TagContextKey records extra info while handling services. Its value is
	// a map[string]interface{}.
	TagContextKey = &contextKey{name: "service-tag"}
	// HttpConnContextKey is used to store the HTTP connection.
	HttpConnContextKey = &contextKey{name: "http-conn"}
)
// Handler serves a request routed to it through Server.AddHandler. A non-nil
// error is reported (via HandleServiceError or the log) and written back to the
// caller as an error response.
type Handler func(ctx *Context) error

// WorkerPool is an optional task pool. When one is configured, the server
// submits request processing and asynchronous response writes to it instead
// of starting a goroutine per task.
type WorkerPool interface {
	// Submit schedules task for execution.
	Submit(task func())
	// StopAndWaitFor stops the pool, waiting at most deadline; Close uses it.
	StopAndWaitFor(deadline time.Duration)
	// Stop and StopAndWait are further shutdown variants; their exact semantics
	// are defined by the implementation (not used in this file section).
	Stop() context.Context
	StopAndWait()
}
// Server is rpcx server that use TCP or UDP.
//
// Besides the raw rpcx protocol on a net.Listener, Serve can also expose rpcx
// over HTTP CONNECT ("http") and WebSocket ("ws"/"wss").
type Server struct {
	// ln is the listener currently being served; guarded by mu.
	ln net.Listener
	// readTimeout, when non-zero, is applied as a read deadline before each
	// request is read (and before a TLS handshake).
	readTimeout time.Duration
	// writeTimeout, when non-zero, is applied as a write deadline before each
	// response is written.
	writeTimeout time.Duration
	// gatewayHTTPServer is the HTTP gateway server, if one was started.
	gatewayHTTPServer *http.Server
	// jsonrpcHTTPServerLock presumably guards jsonrpcHTTPServer.
	// NOTE(review): in this file it is only held around the gatewayHTTPServer
	// close in Shutdown — confirm which field it is meant to protect.
	jsonrpcHTTPServerLock sync.Mutex
	jsonrpcHTTPServer     *http.Server

	DisableHTTPGateway bool // disable http invoke or not.
	DisableJSONRPC     bool // disable json rpc or not.
	AsyncWrite         bool // set true if your server only serves few clients

	// pool, when non-nil, runs request processing and async response writes
	// instead of one new goroutine per task.
	pool WorkerPool

	// serviceMapMu guards serviceMap.
	serviceMapMu sync.RWMutex
	serviceMap   map[string]*service
	// router maps "servicePath.serviceMethod" to handlers added with
	// AddHandler; it is consulted before serviceMap. It is not lock-protected.
	router map[string]Handler

	// mu guards ln, activeConn, onShutdown and onRestart, and serializes the
	// closing of doneChan.
	mu         sync.RWMutex
	activeConn map[net.Conn]struct{}
	// doneChan is closed (once) by Close or at the end of Shutdown.
	doneChan chan struct{}
	// seq numbers the server-initiated messages sent by SendMessage.
	seq atomic.Uint64
	// inShutdown is set to 1 (atomically) when Shutdown begins.
	inShutdown int32
	onShutdown []func(s *Server)
	onRestart  []func(s *Server)

	// TLSConfig for creating tls tcp connection.
	tlsConfig *tls.Config
	// options holds named settings such as "TCPKeepAlivePeriod" (NewServer
	// sets a default when it is absent).
	// NOTE(review): a previous comment here mentioned kcp.BlockCrypt; it is
	// presumably stored as one of these options — confirm.
	options map[string]interface{}
	// CORS options
	corsOptions *CORSOptions

	Plugins PluginContainer
	// AuthFunc can be used to auth. It receives the token stored under
	// share.AuthKey in the request metadata and is not called for heartbeats.
	// A non-nil error is answered and then the connection is closed.
	AuthFunc func(ctx context.Context, req *protocol.Message, token string) error

	// handlerMsgNum counts requests currently inside processOneRequest;
	// Shutdown polls it until it reaches zero.
	handlerMsgNum int32
	// requestCount — NOTE(review): not used in this part of the file.
	requestCount atomic.Uint64

	// HandleServiceError is used to get all service errors. You can use it write logs or others.
	HandleServiceError func(error)
	// ServerErrorFunc is a customized error handlers and you can use it to return customized error strings to clients.
	// If not set, it use err.Error()
	ServerErrorFunc func(res *protocol.Message, err error) string
	// Started is closed once serveListener has installed its listener and is
	// about to accept connections.
	Started chan struct{}
	// unregisterAllOnce presumably makes UnregisterAll run at most once — confirm.
	unregisterAllOnce sync.Once
}
// NewServer creates a Server and applies the given option functions in order.
//
// Defaults are installed before the options run. If, after all options, the
// "TCPKeepAlivePeriod" option is still unset (nil), it defaults to 3 minutes.
func NewServer(options ...OptionFn) *Server {
	srv := &Server{
		Plugins:    &pluginContainer{},
		options:    make(map[string]interface{}),
		activeConn: make(map[net.Conn]struct{}),
		doneChan:   make(chan struct{}),
		serviceMap: make(map[string]*service),
		router:     make(map[string]Handler),
		// Unless you want to do further optimization testing, keep this false.
		AsyncWrite: false,
		Started:    make(chan struct{}),
	}

	for i := range options {
		options[i](srv)
	}

	const keepAliveKey = "TCPKeepAlivePeriod"
	if srv.options[keepAliveKey] == nil {
		srv.options[keepAliveKey] = 3 * time.Minute
	}
	return srv
}
// Address returns the address of the listener being served, or nil when no
// listener has been installed yet.
func (s *Server) Address() net.Addr {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if ln := s.ln; ln != nil {
		return ln.Addr()
	}
	return nil
}
// AddHandler routes requests for servicePath.serviceMethod to handler.
// Routed handlers are looked up before registered services, so they take
// precedence over a service method with the same name.
//
// NOTE(review): the router map is written without locking; add handlers
// before the server starts serving.
func (s *Server) AddHandler(servicePath, serviceMethod string, handler func(*Context) error) {
	key := servicePath + "." + serviceMethod
	s.router[key] = handler
}
// ActiveClientConn returns a snapshot of the currently tracked client
// connections, in no particular order.
func (s *Server) ActiveClientConn() []net.Conn {
	s.mu.RLock()
	defer s.mu.RUnlock()

	conns := make([]net.Conn, len(s.activeConn))
	i := 0
	for c := range s.activeConn {
		conns[i] = c
		i++
	}
	return conns
}
// SendMessage pushes a one-way request to the client on conn.
//
// conn can be obtained inside services from the request context:
//
//	ctx.Value(RemoteConnContextKey)
//
// servicePath, serviceMethod and metadata may be zero values. data is sent
// as-is (SerializeNone). The PreWriteRequest and PostWriteRequest plugin hooks
// run around the write, and the write error is returned.
func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
	ctx := share.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano())
	s.Plugins.DoPreWriteRequest(ctx)

	msg := protocol.NewMessage()
	msg.SetMessageType(protocol.Request)
	msg.SetSeq(s.seq.Add(1))
	msg.SetOneway(true)
	msg.SetSerializeType(protocol.SerializeNone)
	msg.ServicePath, msg.ServiceMethod = servicePath, serviceMethod
	msg.Metadata, msg.Payload = metadata, data

	encoded := msg.EncodeSlicePointer()
	_, writeErr := conn.Write(*encoded)
	protocol.PutData(encoded)

	s.Plugins.DoPostWriteRequest(ctx, msg, writeErr)
	return writeErr
}
// getDoneChan exposes doneChan as receive-only; it is closed by Close or at
// the end of Shutdown.
func (s *Server) getDoneChan() <-chan struct{} {
	var done <-chan struct{} = s.doneChan
	return done
}
// Serve creates a listener for network/address and serves RPC requests on it.
// It blocks until the listener stops accepting connections.
//
// "http" serves rpcx over HTTP CONNECT and "ws"/"wss" over WebSocket; both
// return nil once the HTTP server stops. Any other network serves the raw
// rpcx protocol, after startGateway has had a chance to start the gateway.
// All services are unregistered when Serve returns.
func (s *Server) Serve(network, address string) (err error) {
	ln, err := s.makeListener(network, address)
	if err != nil {
		return err
	}
	defer s.UnregisterAll()

	switch network {
	case "http":
		s.serveByHTTP(ln, "")
		return nil
	case "ws", "wss":
		s.serveByWS(ln, "")
		return nil
	default:
		// startGateway returns the listener on which raw rpcx should be served.
		return s.serveListener(s.startGateway(network, ln))
	}
}
// ServeListener serves RPC requests on an existing listener and blocks until
// it stops accepting connections.
//
// network selects the protocol exactly as in Serve: "http" serves rpcx over
// HTTP CONNECT and "ws"/"wss" over WebSocket (both return nil once the HTTP
// server stops). Any other value serves the raw rpcx protocol, after
// startGateway has had a chance to start the gateway. All services are
// unregistered when ServeListener returns.
func (s *Server) ServeListener(network string, ln net.Listener) (err error) {
	defer s.UnregisterAll()

	switch network {
	case "http":
		s.serveByHTTP(ln, "")
		return nil
	case "ws", "wss":
		// Previously missing here: a WebSocket listener was served as raw rpcx.
		s.serveByWS(ln, "")
		return nil
	}

	// try to start gateway
	return s.serveListener(s.startGateway(network, ln))
}
// serveListener accepts incoming connections on ln and starts a goroutine per
// connection (serveConn) that reads requests and dispatches them to services.
//
// It installs ln as the server's listener and closes s.Started before the
// accept loop begins. Temporary accept errors are retried with exponential
// backoff (5ms, doubling, capped at 1s). Accepted TCP connections get
// keep-alive when the "TCPKeepAlivePeriod" option is set, and may be
// rejected by the PostConnAccept plugins.
//
// It returns ErrServerClosed once the listener has been closed by Shutdown,
// Close, or cmux; any other accept error is returned unchanged.
func (s *Server) serveListener(ln net.Listener) error {
	var tempDelay time.Duration

	s.mu.Lock()
	s.ln = ln
	close(s.Started)
	s.mu.Unlock()

	for {
		conn, e := ln.Accept()
		if e != nil {
			if s.isShutdown() {
				// Shutdown is in progress: wait for it to finish draining.
				<-s.doneChan
				return ErrServerClosed
			}

			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
				time.Sleep(tempDelay)
				continue
			}

			// Close() closes the listener without setting the shutdown flag, so
			// a closed listener must also map to ErrServerClosed, as documented.
			if errors.Is(e, cmux.ErrListenerClosed) || errors.Is(e, net.ErrClosed) {
				return ErrServerClosed
			}
			return e
		}
		tempDelay = 0

		if tc, ok := conn.(*net.TCPConn); ok {
			if period := s.options["TCPKeepAlivePeriod"]; period != nil {
				tc.SetKeepAlive(true)
				tc.SetKeepAlivePeriod(period.(time.Duration))
				tc.SetLinger(10)
			}
		}

		conn, ok := s.Plugins.DoPostConnAccept(conn)
		if !ok {
			conn.Close()
			continue
		}

		s.mu.Lock()
		s.activeConn[conn] = struct{}{}
		s.mu.Unlock()

		if share.Trace {
			log.Debugf("server accepted an conn: %v", conn.RemoteAddr().String())
		}

		go s.serveConn(conn)
	}
}
// serveByHTTP serves rpcx over HTTP. Clients issue an HTTP CONNECT to rpcPath,
// and the hijacked connection then carries the raw rpcx protocol (see
// ServeHTTP). If rpcPath is empty, share.DefaultRPCPath is used.
// It blocks until the HTTP server stops.
func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) {
	// Publish the listener under s.mu, as serveListener does, so that Address,
	// Close and Shutdown can read it without a data race.
	s.mu.Lock()
	s.ln = ln
	s.mu.Unlock()

	if rpcPath == "" {
		rpcPath = share.DefaultRPCPath
	}
	mux := http.NewServeMux()
	mux.Handle(rpcPath, s)
	srv := &http.Server{Handler: mux}
	// NOTE(review): the Serve error is discarded; Serve("http", ...) reports nil.
	srv.Serve(ln)
}
// serveByWS serves rpcx over WebSocket: each WebSocket connection accepted at
// rpcPath is handed to ServeWS. If rpcPath is empty, share.DefaultRPCPath is
// used. It blocks until the HTTP server stops.
func (s *Server) serveByWS(ln net.Listener, rpcPath string) {
	// Publish the listener under s.mu, as serveListener does, so that Address,
	// Close and Shutdown can read it without a data race.
	s.mu.Lock()
	s.ln = ln
	s.mu.Unlock()

	if rpcPath == "" {
		rpcPath = share.DefaultRPCPath
	}
	mux := http.NewServeMux()
	mux.Handle(rpcPath, websocket.Handler(s.ServeWS))
	srv := &http.Server{Handler: mux}
	// NOTE(review): the Serve error is discarded; Serve("ws", ...) reports nil.
	srv.Serve(ln)
}
// sendResponse encodes res and writes it to conn, running the PreWriteResponse
// and PostWriteResponse plugin hooks around the write.
//
// Responses whose payload exceeds 1024 bytes adopt the request's compression
// type (if any). With AsyncWrite set, the write runs on the worker pool, or on
// a new goroutine when there is no pool. PostWriteResponse may then fire
// before the bytes are written. Write errors are not reported.
func (s *Server) sendResponse(ctx *share.Context, conn net.Conn, err error, req, res *protocol.Message) {
	if len(res.Payload) > 1024 {
		if ct := req.CompressType(); ct != protocol.None {
			res.SetCompressType(ct)
		}
	}

	s.Plugins.DoPreWriteResponse(ctx, req, res, err)
	data := res.EncodeSlicePointer()

	write := func() {
		if s.writeTimeout != 0 {
			conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
		}
		conn.Write(*data)
		protocol.PutData(data)
	}

	switch {
	case !s.AsyncWrite:
		write()
	case s.pool != nil:
		s.pool.Submit(write)
	default:
		go write()
	}

	s.Plugins.DoPostWriteResponse(ctx, req, res, err)
}
// serveConn reads requests from conn and dispatches each one to
// processOneRequest, on the worker pool or on a new goroutine.
//
// The loop ends when reading fails, authentication fails, or the server is
// shutting down. TLS connections are handshaken first under the configured
// deadlines. A panic while serving is recovered and logged with this
// goroutine's stack. On exit the connection is untracked and closed; during
// shutdown it first waits for doneChan so in-flight requests can drain.
func (s *Server) serveConn(conn net.Conn) {
	if s.isShutdown() {
		s.closeConn(conn)
		return
	}

	defer func() {
		if err := recover(); err != nil {
			const size = 64 << 10
			buf := make([]byte, size)
			// runtime.Stack never writes more than len(buf), so no clamping is needed.
			buf = buf[:runtime.Stack(buf, false)]
			log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
		}
		if share.Trace {
			log.Debugf("server closed conn: %v", conn.RemoteAddr().String())
		}
		// make sure all inflight requests are handled and all drained
		if s.isShutdown() {
			<-s.doneChan
		}
		s.closeConn(conn)
	}()

	if tlsConn, ok := conn.(*tls.Conn); ok {
		if d := s.readTimeout; d != 0 {
			conn.SetReadDeadline(time.Now().Add(d))
		}
		if d := s.writeTimeout; d != 0 {
			conn.SetWriteDeadline(time.Now().Add(d))
		}
		if err := tlsConn.Handshake(); err != nil {
			log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
			return
		}
	}

	r := bufio.NewReaderSize(conn, ReaderBuffsize)

	// read requests and handle them
	for {
		if s.isShutdown() {
			return
		}

		t0 := time.Now()
		if s.readTimeout != 0 {
			conn.SetReadDeadline(t0.Add(s.readTimeout))
		}

		// create a rpcx Context
		ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn)

		// read a request from the underlying connection
		req, err := s.readRequest(ctx, r)
		if err != nil {
			switch {
			case errors.Is(err, io.EOF):
				log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
			case errors.Is(err, net.ErrClosed):
				log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
			case errors.Is(err, ErrReqReachLimit) && req != nil:
				// The request was decoded but rejected: answer it and keep serving.
				if !req.IsOneway() { // return a error response
					res := req.Clone()
					res.SetMessageType(protocol.Response)
					s.handleError(res, err)
					s.sendResponse(ctx, conn, err, req, res)
				} else { // Oneway and only call the plugins
					s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
				}
				continue
			default:
				// Wrong data or a plugin error. This also covers ErrReqReachLimit
				// from a PreReadRequest hook: nothing was decoded (req is nil), so
				// there is no request to answer. Previously this dereferenced the
				// nil req and panicked.
				log.Warnf("rpcx: failed to read request: %v", err)
			}
			if s.HandleServiceError != nil {
				s.HandleServiceError(err)
			}
			return
		}

		if share.Trace {
			log.Debugf("server received an request %+v from conn: %v", req, conn.RemoteAddr().String())
		}

		ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano())

		closeConn := false
		if !req.IsHeartbeat() {
			err = s.auth(ctx, req)
			closeConn = err != nil
		}

		if err != nil {
			if !req.IsOneway() { // return a error response
				res := req.Clone()
				res.SetMessageType(protocol.Response)
				s.handleError(res, err)
				s.sendResponse(ctx, conn, err, req, res)
			} else {
				s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
			}

			if s.HandleServiceError != nil {
				s.HandleServiceError(err)
			}

			// auth failed, close the connection
			if closeConn {
				log.Infof("auth failed for conn %s: %v", conn.RemoteAddr().String(), err)
				return
			}
			continue
		}

		if s.pool != nil {
			s.pool.Submit(func() {
				s.processOneRequest(ctx, req, conn)
			})
		} else {
			go s.processOneRequest(ctx, req, conn)
		}
	}
}
// processOneRequest handles a single decoded request and writes its response.
//
// Heartbeats are echoed straight back. Other requests get the optional
// per-request server timeout and request/response metadata in ctx. They are
// then routed to a handler added with AddHandler if one matches; otherwise
// they go to the registered service via handleRequest. Non-oneway requests
// receive a response; metadata a handler placed in the response-metadata map
// is copied into it without overwriting existing keys. A panic is recovered,
// reported, and answered with an error.
func (s *Server) processOneRequest(ctx *share.Context, req *protocol.Message, conn net.Conn) {
	defer func() {
		if r := recover(); r != nil {
			// Capture only this goroutine's stack with the same buffer size as
			// serveConn; a 1 KiB dump of every goroutine was truncated to noise.
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			if s.HandleServiceError != nil {
				s.HandleServiceError(fmt.Errorf("%v", r))
			} else {
				log.Errorf("[handler internal error]: servicepath: %s, servicemethod: %s, err: %v, stack: %s", req.ServicePath, req.ServiceMethod, r, string(buf))
			}
			sctx := NewContext(ctx, conn, req, s.AsyncWrite)
			sctx.WriteError(fmt.Errorf("%v", r))
		}
	}()

	// Track in-flight requests so Shutdown can wait for them.
	atomic.AddInt32(&s.handlerMsgNum, 1)
	defer atomic.AddInt32(&s.handlerMsgNum, -1)

	// Heartbeat request: handle it directly by echoing it back as a response.
	if req.IsHeartbeat() {
		s.Plugins.DoHeartbeatRequest(ctx, req)
		req.SetMessageType(protocol.Response)
		data := req.EncodeSlicePointer()
		if s.writeTimeout != 0 {
			conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
		}
		conn.Write(*data)
		protocol.PutData(data)
		return
	}

	cancelFunc := parseServerTimeout(ctx, req)
	if cancelFunc != nil {
		defer cancelFunc()
	}

	resMetadata := make(map[string]string)
	if req.Metadata == nil {
		req.Metadata = make(map[string]string)
	}
	ctx = share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata),
		share.ResMetaDataKey, resMetadata)

	s.Plugins.DoPreHandleRequest(ctx, req)

	if share.Trace {
		log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String())
	}

	// Handlers added with AddHandler take precedence over registered services.
	if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok {
		sctx := NewContext(ctx, conn, req, s.AsyncWrite)
		if err := handler(sctx); err != nil {
			if s.HandleServiceError != nil {
				s.HandleServiceError(err)
			} else {
				// The format previously lacked the servicemethod verb, which
				// shifted err into its slot and dropped it from the message.
				log.Errorf("[handler internal error]: servicepath: %s, servicemethod: %s, err: %v", req.ServicePath, req.ServiceMethod, err)
			}
			sctx.WriteError(err)
		}
		return
	}

	res, err := s.handleRequest(ctx, req)
	if err != nil {
		if s.HandleServiceError != nil {
			s.HandleServiceError(err)
		} else {
			log.Warnf("rpcx: failed to handle request: %v", err)
		}
	}

	if !req.IsOneway() {
		// Copy metadata set in the context into the response, keeping values
		// the response already has.
		if len(resMetadata) > 0 {
			meta := res.Metadata
			if meta == nil {
				res.Metadata = resMetadata
			} else {
				for k, v := range resMetadata {
					if meta[k] == "" {
						meta[k] = v
					}
				}
			}
		}
		s.sendResponse(ctx, conn, err, req, res)
	}

	if share.Trace {
		log.Debugf("server write response %+v for an request %+v from conn: %v", res, req, conn.RemoteAddr().String())
	}
}
// parseServerTimeout applies the per-request timeout a client may send in the
// share.ServerTimeout metadata entry, as base-10 milliseconds.
//
// When the entry parses, ctx.Context is replaced in place by a derived context
// with that timeout, and its cancel function is returned for the caller to
// invoke. A zero or negative value yields an already-expired context.
// Otherwise (no request, no metadata, missing or unparsable entry) ctx is left
// untouched and nil is returned.
func parseServerTimeout(ctx *share.Context, req *protocol.Message) context.CancelFunc {
	if req == nil || req.Metadata == nil {
		return nil
	}

	raw := req.Metadata[share.ServerTimeout]
	if raw == "" {
		return nil
	}

	ms, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr != nil {
		return nil
	}

	timeoutCtx, cancel := context.WithTimeout(ctx.Context, time.Duration(ms)*time.Millisecond)
	ctx.Context = timeoutCtx
	return cancel
}
// isShutdown reports whether Shutdown has begun, i.e. inShutdown was
// atomically set to 1.
func (s *Server) isShutdown() bool {
	state := atomic.LoadInt32(&s.inShutdown)
	return state == 1
}
// closeConn stops tracking conn, closes it (ignoring the close error), and
// then runs the PostConnClose plugin hook.
func (s *Server) closeConn(conn net.Conn) {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.activeConn, conn)
	}()

	conn.Close()
	s.Plugins.DoPostConnClose(conn)
}
// readRequest decodes the next message from r, surrounded by the
// PreReadRequest/PostReadRequest plugin hooks.
//
// A PreReadRequest error is returned with a nil message. A clean io.EOF is
// returned immediately without running PostReadRequest. Otherwise
// PostReadRequest always runs; a decode error takes precedence over the
// plugin's error.
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
	if preErr := s.Plugins.DoPreReadRequest(ctx); preErr != nil {
		return nil, preErr
	}

	// NOTE(review): messages are allocated per request; pooling is an open question.
	req = protocol.NewMessage()
	if err = req.Decode(r); err == io.EOF {
		return req, err
	}

	postErr := s.Plugins.DoPostReadRequest(ctx, req, err)
	if err != nil {
		return req, err
	}
	return req, postErr
}
// auth runs AuthFunc, if configured, with the token stored under
// share.AuthKey in the request metadata. Without an AuthFunc every request
// is accepted.
func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
	if s.AuthFunc == nil {
		return nil
	}
	return s.AuthFunc(ctx, req, req.Metadata[share.AuthKey])
}
// handleRequest dispatches req to the registered service method and returns
// the response message to send back.
//
// res is always a clone of req retyped as a Response. Lookup failures, codec
// errors, plugin errors and method errors are recorded on res via handleError
// and also returned as err. If no method matches but the service has a raw
// function with that name, the request is delegated to
// handleRequestForFunction. Argument and reply values are borrowed from
// reflectTypePools; see the NOTE comments for paths that do not return them.
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	serviceName := req.ServicePath
	methodName := req.ServiceMethod

	res = req.Clone()
	res.SetMessageType(protocol.Response)

	s.serviceMapMu.RLock()
	service := s.serviceMap[serviceName]
	if share.Trace {
		log.Debugf("server get service %+v for an request %+v", service, req)
	}
	s.serviceMapMu.RUnlock()

	if service == nil {
		err = errors.New("rpcx: can't find service " + serviceName)
		return s.handleError(res, err)
	}
	mtype := service.method[methodName]
	if mtype == nil {
		if service.function[methodName] != nil { // check raw functions
			return s.handleRequestForFunction(ctx, req)
		}
		err = errors.New("rpcx: can't find method " + methodName)
		return s.handleError(res, err)
	}

	// get a argv object from object pool
	argv := reflectTypePools.Get(mtype.ArgType)

	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		// NOTE(review): argv is not returned to the pool on this path.
		err = fmt.Errorf("can not find codec for %d", req.SerializeType())
		return s.handleError(res, err)
	}

	err = codec.Decode(req.Payload, argv)
	if err != nil {
		// NOTE(review): argv is not returned to the pool on this path either.
		return s.handleError(res, err)
	}

	// and get a reply object from object pool
	replyv := reflectTypePools.Get(mtype.ReplyType)

	// PreCall plugins may replace the decoded argument.
	argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
	if err != nil {
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		return s.handleError(res, err)
	}

	// For non-pointer argument types the pooled value is dereferenced before
	// the call (presumably the pool hands out pointers — confirm).
	if mtype.ArgType.Kind() != reflect.Ptr {
		err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
	} else {
		err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
	}

	// PostCall plugins may replace the reply; their error only surfaces when
	// the method itself succeeded.
	replyv, err1 := s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err)
	if err == nil {
		err = err1
	}

	// return argc to object pool
	reflectTypePools.Put(mtype.ArgType, argv)

	if err != nil {
		// Even on error a non-nil reply is encoded into the payload, so the
		// caller receives it together with the error status.
		if replyv != nil {
			// This err shadows the outer one: it is the encode error only.
			data, err := codec.Encode(replyv)
			// return reply to object pool
			reflectTypePools.Put(mtype.ReplyType, replyv)
			if err != nil {
				return s.handleError(res, err)
			}
			res.Payload = data
		}
		// Outer err: the method / PostCall error.
		return s.handleError(res, err)
	}

	if !req.IsOneway() {
		data, err := codec.Encode(replyv)
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		if err != nil {
			return s.handleError(res, err)
		}
		res.Payload = data
	} else if replyv != nil {
		// Oneway calls send no payload; just recycle the reply.
		reflectTypePools.Put(mtype.ReplyType, replyv)
	}

	if share.Trace {
		log.Debugf("server called service %+v for an request %+v", service, req)
	}

	return res, nil
}
// handleRequestForFunction is the counterpart of handleRequest for raw
// functions registered on a service (service.function) rather than methods.
//
// It follows the same flow: decode into a pooled argument, run the
// PreCall/PostCall plugins around service.callForFunction, encode the reply
// for non-oneway requests, and record every failure on res via handleError.
// Unlike handleRequest, on error the reply is recycled without being encoded.
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	res = req.Clone()
	res.SetMessageType(protocol.Response)

	serviceName := req.ServicePath
	methodName := req.ServiceMethod
	s.serviceMapMu.RLock()
	service := s.serviceMap[serviceName]
	s.serviceMapMu.RUnlock()
	if service == nil {
		err = errors.New("rpcx: can't find service for func raw function")
		return s.handleError(res, err)
	}
	mtype := service.function[methodName]
	if mtype == nil {
		err = errors.New("rpcx: can't find method " + methodName)
		return s.handleError(res, err)
	}

	// Borrow the argument value from the pool.
	argv := reflectTypePools.Get(mtype.ArgType)

	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		// NOTE(review): argv is not returned to the pool on this path.
		err = fmt.Errorf("can not find codec for %d", req.SerializeType())
		return s.handleError(res, err)
	}

	err = codec.Decode(req.Payload, argv)
	if err != nil {
		// NOTE(review): argv is not returned to the pool on this path either.
		return s.handleError(res, err)
	}

	replyv := reflectTypePools.Get(mtype.ReplyType)

	// PreCall plugins may replace the decoded argument.
	argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
	if err != nil {
		// return reply to object pool
		reflectTypePools.Put(mtype.ReplyType, replyv)
		return s.handleError(res, err)
	}

	if mtype.ArgType.Kind() != reflect.Ptr {
		err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
	} else {
		err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
	}

	// The PostCall error only surfaces when the function itself succeeded.
	replyv, err1 := s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err)
	if err == nil {
		err = err1
	}

	reflectTypePools.Put(mtype.ArgType, argv)

	if err != nil {
		reflectTypePools.Put(mtype.ReplyType, replyv)
		return s.handleError(res, err)
	}

	if !req.IsOneway() {
		// This err shadows the outer one: it is the encode error only.
		data, err := codec.Encode(replyv)
		reflectTypePools.Put(mtype.ReplyType, replyv)
		if err != nil {
			return s.handleError(res, err)
		}
		res.Payload = data
	} else if replyv != nil {
		reflectTypePools.Put(mtype.ReplyType, replyv)
	}

	return res, nil
}
// handleError marks res as an error response and stores the error text under
// protocol.ServiceError in its metadata, creating the map if needed. The text
// comes from ServerErrorFunc when set, otherwise from err.Error(). It returns
// res and err unchanged so callers can `return s.handleError(res, err)`.
func (s *Server) handleError(res *protocol.Message, err error) (*protocol.Message, error) {
	res.SetMessageStatusType(protocol.Error)
	if res.Metadata == nil {
		res.Metadata = make(map[string]string)
	}

	var text string
	if toText := s.ServerErrorFunc; toText != nil {
		text = toText(res, err)
	} else {
		text = err.Error()
	}
	res.Metadata[protocol.ServiceError] = text

	return res, err
}
// Can connect to RPC service using HTTP CONNECT to rpcPath.
// connected is the status text sent back once the CONNECT succeeds.
var connected = "200 Connected to rpcx"

// ServeHTTP implements an http.Handler that answers RPC requests.
//
// Only the CONNECT method is accepted (others get 405). The connection is
// hijacked and answered with "HTTP/1.0 200 Connected to rpcx". From then on it
// speaks the raw rpcx protocol and is served by serveConn until it closes.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodConnect {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 must CONNECT\n")
		return
	}

	// Not every ResponseWriter can be hijacked (HTTP/2 writers, for example);
	// report it instead of panicking on an unchecked type assertion.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		log.Info("rpc hijacking ", req.RemoteAddr, ": response writer does not support hijacking")
		http.Error(w, "rpcx: connection hijacking not supported", http.StatusInternalServerError)
		return
	}

	conn, _, err := hijacker.Hijack()
	if err != nil {
		log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}

	// After Hijack the connection is ours to close; do so if the reply fails.
	if _, err := io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n"); err != nil {
		log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		conn.Close()
		return
	}

	s.mu.Lock()
	s.activeConn[conn] = struct{}{}
	s.mu.Unlock()

	s.serveConn(conn)
}
// ServeWS serves rpcx over an established WebSocket connection. The connection
// is tracked in activeConn and switched to binary frames. The call then blocks
// in serveConn until the connection is done.
func (s *Server) ServeWS(conn *websocket.Conn) {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.activeConn[conn] = struct{}{}
	}()

	conn.PayloadType = websocket.BinaryFrame
	s.serveConn(conn)
}
// Close immediately closes all active net.Listeners.
//
// Under s.mu it closes the listener and every tracked connection, firing the
// PostConnClose plugin hook for each. It then closes doneChan and, if a worker
// pool is configured, stops the pool, waiting up to 10 seconds. It returns the
// listener's Close error, if any.
//
// Unlike Shutdown it does not set the shutdown flag and does not poll
// handlerMsgNum for in-flight requests.
// NOTE(review): s.mu stays held for the entire call, including the pool wait.
func (s *Server) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	if s.ln != nil {
		err = s.ln.Close()
	}

	// Deleting from a map while ranging over it is permitted in Go.
	for c := range s.activeConn {
		c.Close()
		delete(s.activeConn, c)
		s.Plugins.DoPostConnClose(c)
	}

	s.closeDoneChanLocked()

	if s.pool != nil {
		s.pool.StopAndWaitFor(10 * time.Second)
	}
	return err
}
// RegisterOnShutdown appends f to the server's shutdown hooks, for example to
// shut connections down gracefully.
//
// NOTE(review): Shutdown in this file does not iterate these hooks; confirm
// where onShutdown is invoked.
func (s *Server) RegisterOnShutdown(f func(s *Server)) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.onShutdown = append(s.onShutdown, f)
}
// RegisterOnRestart appends f to the server's restart hooks.
//
// NOTE(review): Restart in this file does not iterate these hooks; confirm
// where onRestart is invoked.
func (s *Server) RegisterOnRestart(f func(s *Server)) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.onRestart = append(s.onRestart, f)
}
// shutdownPollInterval is how often Shutdown re-checks for in-flight requests.
var shutdownPollInterval = 1000 * time.Millisecond

// Shutdown gracefully stops the server. Only the first call does any work;
// later calls return nil immediately.
//
// The steps are:
//  1. Set the shutdown flag, unregister all services, close the listener, and
//     half-close (CloseRead) every tracked TCP connection. serveConn stops
//     reading new requests, but responses can still be written.
//  2. Every shutdownPollInterval, check whether in-flight requests
//     (handlerMsgNum) have reached zero, or stop early when ctx is done.
//  3. Close the HTTP gateway and JSON-RPC servers, if running.
//  4. Close every remaining connection and close doneChan.
//
// If ctx expires before in-flight requests finish, Shutdown still performs
// steps 3-4 and then returns ctx.Err(). Otherwise it returns nil; errors from
// closing the listener are not reported.
func (s *Server) Shutdown(ctx context.Context) error {
	var err error
	if atomic.CompareAndSwapInt32(&s.inShutdown, 0, 1) {
		log.Info("shutdown begin")

		s.mu.Lock()

		// Proactively unregister the registered services.
		s.UnregisterAll()

		if s.ln != nil {
			s.ln.Close()
		}
		for conn := range s.activeConn {
			if tcpConn, ok := conn.(*net.TCPConn); ok {
				tcpConn.CloseRead()
			}
		}
		s.mu.Unlock()

		// wait all in-processing requests finish.
		ticker := time.NewTicker(shutdownPollInterval)
		defer ticker.Stop()
	outer:
		for {
			if s.checkProcessMsg() {
				break
			}
			select {
			case <-ctx.Done():
				err = ctx.Err()
				break outer
			case <-ticker.C:
			}
		}

		// NOTE(review): jsonrpcHTTPServerLock is held around the gateway close,
		// while jsonrpcHTTPServer below is read without it — confirm intent.
		s.jsonrpcHTTPServerLock.Lock()
		if s.gatewayHTTPServer != nil {
			if err := s.closeHTTP1APIGateway(ctx); err != nil {
				log.Warnf("failed to close gateway: %v", err)
			} else {
				log.Info("closed gateway")
			}
		}
		s.jsonrpcHTTPServerLock.Unlock()

		if s.jsonrpcHTTPServer != nil {
			if err := s.closeJSONRPC2(ctx); err != nil {
				log.Warnf("failed to close JSONRPC: %v", err)
			} else {
				log.Info("closed JSONRPC")
			}
		}

		s.mu.Lock()
		for conn := range s.activeConn {
			conn.Close()
			delete(s.activeConn, conn)
			s.Plugins.DoPostConnClose(conn)
		}
		s.closeDoneChanLocked()
		s.mu.Unlock()

		log.Info("shutdown end")

	}
	return err
}
// Restart re-executes the current binary as a new server process, waits three
// seconds, and then gracefully shuts this server down with Shutdown(ctx).
//
// The new process is expected to bind the same port; the original design
// relies on the SO_REUSEPORT socket option for that (not visible here —
// confirm in the listener setup). If the new process cannot be started, that
// error is returned and this server keeps running.
func (s *Server) Restart(ctx context.Context) error {
	pid, err := s.startProcess()
	if err != nil {
		return err
	}
	log.Infof("restart a new rpcx server: %d", pid)

	// Presumably gives the child time to start listening before this process
	// stops. TODO: is a fixed delay necessary (or sufficient)?
	const childStartupGrace = 3 * time.Second
	time.Sleep(childStartupGrace)

	return s.Shutdown(ctx)
}
// startProcess starts a copy of the running program and returns its PID
// without waiting for it.
//
// The binary is os.Args[0] resolved with exec.LookPath. The child gets the
// same arguments, a copy of the environment, the current working directory,
// and this process's stdin/stdout/stderr.
func (s *Server) startProcess() (int, error) {
	binary, err := exec.LookPath(os.Args[0])
	if err != nil {
		return 0, err
	}

	// If Getwd fails, dir is empty and the child inherits this process's
	// working directory (os.ProcAttr semantics).
	dir, _ := os.Getwd()

	attr := &os.ProcAttr{
		Dir:   dir,
		Env:   append([]string(nil), os.Environ()...),
		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
	}

	proc, err := os.StartProcess(binary, os.Args, attr)
	if err != nil {
		return 0, err
	}
	return proc.Pid, nil
}
// checkProcessMsg logs the number of requests still inside processOneRequest
// and reports whether there are none left.
func (s *Server) checkProcessMsg() bool {
	inFlight := atomic.LoadInt32(&s.handlerMsgNum)
	log.Info("need handle in-processing msg size:", inFlight)
	return inFlight == 0
}
// closeDoneChanLocked closes s.doneChan unless it is already closed.
//
// The caller must hold s.mu. Nothing in this file sends on doneChan, so a
// successful receive means it is closed. Holding the lock makes the
// check-then-close safe against concurrent closers.
func (s *Server) closeDoneChanLocked() {
	select {
	case <-s.doneChan:
		return // already closed; closing again would panic
	default:
	}
	close(s.doneChan)
}
// ip4Reg matches a dotted-quad IPv4 address whose octets are 0-255, written
// without leading zeros.
var ip4Reg = regexp.MustCompile(`^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$`)

// validIP4 reports whether the host part of ipAddress is a dotted-quad IPv4
// address. ipAddress is normally "host:port", optionally surrounded by spaces;
// everything from the last ':' on is treated as the port. An address with no
// ':' is checked as a bare host. Previously that case sliced with index -1 and
// panicked.
func validIP4(ipAddress string) bool {
	host := strings.Trim(ipAddress, " ")
	if i := strings.LastIndex(host, ":"); i >= 0 {
		host = host[:i] // remove port
	}
	return ip4Reg.MatchString(host)
}
// validIP6 reports whether the host part of ipAddress is an IPv6 address
// (IPv4 and IPv4-mapped forms are rejected). ipAddress is normally
// "[host]:port", optionally surrounded by spaces. Everything from the last ':'
// on is treated as the port, and surrounding brackets are removed. An address
// with no ':' is checked as-is. Previously that case sliced with index -1 and
// panicked.
func validIP6(ipAddress string) bool {
	host := strings.Trim(ipAddress, " ")
	if i := strings.LastIndex(host, ":"); i >= 0 {
		host = host[:i] // remove port
	}
	host = strings.TrimPrefix(host, "[")
	host = strings.TrimSuffix(host, "]")

	ip := net.ParseIP(host)
	return ip != nil && ip.To4() == nil
}