Files
rpcx/server/server.go
2022-02-16 13:15:46 +08:00

1012 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/soheilhy/cmux"
"io"
"net"
"net/http"
"os"
"os/exec"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
"golang.org/x/net/websocket"
)
// ErrServerClosed is returned by the Server's Serve, ListenAndServe after a call to Shutdown or Close.
var (
ErrServerClosed = errors.New("http: Server closed")
ErrReqReachLimit = errors.New("request reached rate limit")
)
const (
// ReaderBuffsize is used for bufio reader.
ReaderBuffsize = 1024
// WriterBuffsize is used for bufio writer.
WriterBuffsize = 1024
// WriteChanSize is used for response.
WriteChanSize = 1024 * 1024
)
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}
func (k *contextKey) String() string { return "rpcx context value " + k.name }
var (
// RemoteConnContextKey is a context key. It can be used in
// services with context.WithValue to access the connection arrived on.
// The associated value will be of type net.Conn.
RemoteConnContextKey = &contextKey{"remote-conn"}
// StartRequestContextKey records the start time
StartRequestContextKey = &contextKey{"start-parse-request"}
// StartSendRequestContextKey records the start time
StartSendRequestContextKey = &contextKey{"start-send-request"}
// TagContextKey is used to record extra info in handling services. Its value is a map[string]interface{}
TagContextKey = &contextKey{"service-tag"}
// HttpConnContextKey is used to store http connection.
HttpConnContextKey = &contextKey{"http-conn"}
)
type Handler func(ctx *Context) error
// Server is rpcx server that use TCP or UDP.
type Server struct {
ln net.Listener
readTimeout time.Duration
writeTimeout time.Duration
gatewayHTTPServer *http.Server
DisableHTTPGateway bool // should disable http invoke or not.
DisableJSONRPC bool // should disable json rpc or not.
AsyncWrite bool // set true if your server only serves few clients
serviceMapMu sync.RWMutex
serviceMap map[string]*service
router map[string]Handler
mu sync.RWMutex
activeConn map[net.Conn]struct{}
doneChan chan struct{}
seq uint64
inShutdown int32
onShutdown []func(s *Server)
onRestart []func(s *Server)
// TLSConfig for creating tls tcp connection.
tlsConfig *tls.Config
// BlockCrypt for kcp.BlockCrypt
options map[string]interface{}
// CORS options
corsOptions *CORSOptions
Plugins PluginContainer
// AuthFunc can be used to auth.
AuthFunc func(ctx context.Context, req *protocol.Message, token string) error
handlerMsgNum int32
HandleServiceError func(error)
}
// NewServer returns a server.
func NewServer(options ...OptionFn) *Server {
s := &Server{
Plugins: &pluginContainer{},
options: make(map[string]interface{}),
activeConn: make(map[net.Conn]struct{}),
doneChan: make(chan struct{}),
serviceMap: make(map[string]*service),
router: make(map[string]Handler),
AsyncWrite: false, // 除非你想benchmark或者极致优化否则建议你设置为false
}
for _, op := range options {
op(s)
}
if s.options["TCPKeepAlivePeriod"] == nil {
s.options["TCPKeepAlivePeriod"] = 3 * time.Minute
}
return s
}
// Address returns listened address.
func (s *Server) Address() net.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
if s.ln == nil {
return nil
}
return s.ln.Addr()
}
func (s *Server) AddHandler(servicePath, serviceMethod string, handler func(*Context) error) {
s.router[servicePath+"."+serviceMethod] = handler
}
// ActiveClientConn returns active connections.
func (s *Server) ActiveClientConn() []net.Conn {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]net.Conn, 0, len(s.activeConn))
for clientConn := range s.activeConn {
result = append(result, clientConn)
}
return result
}
// SendMessage a request to the specified client.
// The client is designated by the conn.
// conn can be gotten from context in services:
//
// ctx.Value(RemoteConnContextKey)
//
// servicePath, serviceMethod, metadata can be set to zero values.
func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
ctx := share.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano())
s.Plugins.DoPreWriteRequest(ctx)
req := protocol.GetPooledMsg()
req.SetMessageType(protocol.Request)
seq := atomic.AddUint64(&s.seq, 1)
req.SetSeq(seq)
req.SetOneway(true)
req.SetSerializeType(protocol.SerializeNone)
req.ServicePath = servicePath
req.ServiceMethod = serviceMethod
req.Metadata = metadata
req.Payload = data
b := req.EncodeSlicePointer()
_, err := conn.Write(*b)
protocol.PutData(b)
s.Plugins.DoPostWriteRequest(ctx, req, err)
protocol.FreeMsg(req)
return err
}
func (s *Server) getDoneChan() <-chan struct{} {
return s.doneChan
}
// Serve starts and listens RPC requests.
// It is blocked until receiving connections from clients.
func (s *Server) Serve(network, address string) (err error) {
var ln net.Listener
ln, err = s.makeListener(network, address)
if err != nil {
return
}
if network == "http" {
s.serveByHTTP(ln, "")
return nil
}
if network == "ws" || network == "wss" {
s.serveByWS(ln, "")
return nil
}
// try to start gateway
ln = s.startGateway(network, ln)
return s.serveListener(ln)
}
// ServeListener listens RPC requests.
// It is blocked until receiving connections from clients.
func (s *Server) ServeListener(network string, ln net.Listener) (err error) {
if network == "http" {
s.serveByHTTP(ln, "")
return nil
}
// try to start gateway
ln = s.startGateway(network, ln)
return s.serveListener(ln)
}
// serveListener accepts incoming connections on the Listener ln,
// creating a new service goroutine for each.
// The service goroutines read requests and then call services to reply to them.
func (s *Server) serveListener(ln net.Listener) error {
var tempDelay time.Duration
s.mu.Lock()
s.ln = ln
s.mu.Unlock()
for {
conn, e := ln.Accept()
if e != nil {
select {
case <-s.getDoneChan():
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
if e == cmux.ErrListenerClosed {
return ErrServerClosed
}
return e
}
tempDelay = 0
if tc, ok := conn.(*net.TCPConn); ok {
period := s.options["TCPKeepAlivePeriod"]
if period != nil {
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(period.(time.Duration))
tc.SetLinger(10)
}
}
conn, ok := s.Plugins.DoPostConnAccept(conn)
if !ok {
conn.Close()
continue
}
s.mu.Lock()
s.activeConn[conn] = struct{}{}
s.mu.Unlock()
if share.Trace {
log.Debugf("server accepted an conn: %v", conn.RemoteAddr().String())
}
go s.serveConn(conn)
}
}
// serveByHTTP serves by HTTP.
// if rpcPath is an empty string, use share.DefaultRPCPath.
func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) {
s.ln = ln
if rpcPath == "" {
rpcPath = share.DefaultRPCPath
}
mux := http.NewServeMux()
mux.Handle(rpcPath, s)
srv := &http.Server{Handler: mux}
srv.Serve(ln)
}
func (s *Server) serveByWS(ln net.Listener, rpcPath string) {
s.ln = ln
if rpcPath == "" {
rpcPath = share.DefaultRPCPath
}
mux := http.NewServeMux()
mux.Handle(rpcPath, websocket.Handler(s.ServeWS))
srv := &http.Server{Handler: mux}
srv.Serve(ln)
}
func (s *Server) serveConn(conn net.Conn) {
if s.isShutdown() {
s.closeConn(conn)
return
}
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
buf := make([]byte, size)
ss := runtime.Stack(buf, false)
if ss > size {
ss = size
}
buf = buf[:ss]
log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
}
if share.Trace {
log.Debugf("server closed conn: %v", conn.RemoteAddr().String())
}
s.closeConn(conn)
}()
if tlsConn, ok := conn.(*tls.Conn); ok {
if d := s.readTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}
if d := s.writeTimeout; d != 0 {
conn.SetWriteDeadline(time.Now().Add(d))
}
if err := tlsConn.Handshake(); err != nil {
log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
return
}
}
r := bufio.NewReaderSize(conn, ReaderBuffsize)
var writeCh chan *[]byte
if s.AsyncWrite {
writeCh = make(chan *[]byte, 1)
defer close(writeCh)
go s.serveAsyncWrite(conn, writeCh)
}
for {
if s.isShutdown() {
return
}
t0 := time.Now()
if s.readTimeout != 0 {
conn.SetReadDeadline(t0.Add(s.readTimeout))
}
ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn)
req, err := s.readRequest(ctx, r)
if err != nil {
if err == io.EOF {
log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
} else if errors.Is(err, net.ErrClosed) {
log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
} else if errors.Is(err, ErrReqReachLimit) {
if !req.IsOneway() {
res := req.Clone()
res.SetMessageType(protocol.Response)
if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
res.SetCompressType(req.CompressType())
}
handleError(res, err)
s.Plugins.DoPreWriteResponse(ctx, req, res, err)
data := res.EncodeSlicePointer()
if s.AsyncWrite {
writeCh <- data
} else {
conn.Write(*data)
protocol.PutData(data)
}
s.Plugins.DoPostWriteResponse(ctx, req, res, err)
protocol.FreeMsg(res)
} else {
s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
}
protocol.FreeMsg(req)
continue
} else {
log.Warnf("rpcx: failed to read request: %v", err)
}
protocol.FreeMsg(req)
return
}
if s.writeTimeout != 0 {
conn.SetWriteDeadline(t0.Add(s.writeTimeout))
}
if share.Trace {
log.Debugf("server received an request %+v from conn: %v", req, conn.RemoteAddr().String())
}
ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano())
closeConn := false
if !req.IsHeartbeat() {
err = s.auth(ctx, req)
closeConn = err != nil
}
if err != nil {
if !req.IsOneway() {
res := req.Clone()
res.SetMessageType(protocol.Response)
if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
res.SetCompressType(req.CompressType())
}
handleError(res, err)
s.Plugins.DoPreWriteResponse(ctx, req, res, err)
data := res.EncodeSlicePointer()
if s.AsyncWrite {
writeCh <- data
} else {
conn.Write(*data)
protocol.PutData(data)
}
s.Plugins.DoPostWriteResponse(ctx, req, res, err)
protocol.FreeMsg(res)
} else {
s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
}
protocol.FreeMsg(req)
// auth failed, closed the connection
if closeConn {
log.Infof("auth failed for conn %s: %v", conn.RemoteAddr().String(), err)
return
}
continue
}
go func() {
defer func() {
if r := recover(); r != nil {
// maybe panic because the writeCh is closed.
log.Errorf("failed to handle request: %v", r)
}
}()
atomic.AddInt32(&s.handlerMsgNum, 1)
defer atomic.AddInt32(&s.handlerMsgNum, -1)
if req.IsHeartbeat() {
s.Plugins.DoHeartbeatRequest(ctx, req)
req.SetMessageType(protocol.Response)
data := req.EncodeSlicePointer()
if s.AsyncWrite {
writeCh <- data
} else {
conn.Write(*data)
protocol.PutData(data)
}
protocol.FreeMsg(req)
return
}
resMetadata := make(map[string]string)
ctx = share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata),
share.ResMetaDataKey, resMetadata)
cancelFunc := parseServerTimeout(ctx, req)
if cancelFunc != nil {
defer cancelFunc()
}
s.Plugins.DoPreHandleRequest(ctx, req)
if share.Trace {
log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String())
}
// first use handler
if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok {
sctx := NewContext(ctx, conn, req, writeCh)
err := handler(sctx)
if err != nil {
log.Errorf("[handler internal error]: servicepath: %s, servicemethod, err: %v", req.ServicePath, req.ServiceMethod, err)
}
return
}
//
res, err := s.handleRequest(ctx, req)
if err != nil {
if s.HandleServiceError != nil {
s.HandleServiceError(err)
} else {
log.Warnf("rpcx: failed to handle request: %v", err)
}
}
s.Plugins.DoPreWriteResponse(ctx, req, res, err)
if !req.IsOneway() {
if len(resMetadata) > 0 { // copy meta in context to request
meta := res.Metadata
if meta == nil {
res.Metadata = resMetadata
} else {
for k, v := range resMetadata {
if meta[k] == "" {
meta[k] = v
}
}
}
}
if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
res.SetCompressType(req.CompressType())
}
data := res.EncodeSlicePointer()
if s.AsyncWrite {
writeCh <- data
} else {
conn.Write(*data)
protocol.PutData(data)
}
}
s.Plugins.DoPostWriteResponse(ctx, req, res, err)
if share.Trace {
log.Debugf("server write response %+v for an request %+v from conn: %v", res, req, conn.RemoteAddr().String())
}
protocol.FreeMsg(req)
protocol.FreeMsg(res)
}()
}
}
func (s *Server) serveAsyncWrite(conn net.Conn, writeCh chan *[]byte) {
for {
select {
case <-s.doneChan:
return
case data := <-writeCh:
if data == nil {
return
}
conn.Write(*data)
protocol.PutData(data)
}
}
}
func parseServerTimeout(ctx *share.Context, req *protocol.Message) context.CancelFunc {
if req == nil || req.Metadata == nil {
return nil
}
st := req.Metadata[share.ServerTimeout]
if st == "" {
return nil
}
timeout, err := strconv.ParseInt(st, 10, 64)
if err != nil {
return nil
}
newCtx, cancel := context.WithTimeout(ctx.Context, time.Duration(timeout)*time.Millisecond)
ctx.Context = newCtx
return cancel
}
func (s *Server) isShutdown() bool {
return atomic.LoadInt32(&s.inShutdown) == 1
}
func (s *Server) closeConn(conn net.Conn) {
s.mu.Lock()
delete(s.activeConn, conn)
s.mu.Unlock()
conn.Close()
s.Plugins.DoPostConnClose(conn)
}
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
err = s.Plugins.DoPreReadRequest(ctx)
if err != nil {
return nil, err
}
// pool req?
req = protocol.GetPooledMsg()
err = req.Decode(r)
if err == io.EOF {
return req, err
}
perr := s.Plugins.DoPostReadRequest(ctx, req, err)
if err == nil {
err = perr
}
return req, err
}
func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
if s.AuthFunc != nil {
token := req.Metadata[share.AuthKey]
return s.AuthFunc(ctx, req, token)
}
return nil
}
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
serviceName := req.ServicePath
methodName := req.ServiceMethod
res = req.Clone()
res.SetMessageType(protocol.Response)
s.serviceMapMu.RLock()
service := s.serviceMap[serviceName]
if share.Trace {
log.Debugf("server get service %+v for an request %+v", service, req)
}
s.serviceMapMu.RUnlock()
if service == nil {
err = errors.New("rpcx: can't find service " + serviceName)
return handleError(res, err)
}
mtype := service.method[methodName]
if mtype == nil {
if service.function[methodName] != nil { // check raw functions
return s.handleRequestForFunction(ctx, req)
}
err = errors.New("rpcx: can't find method " + methodName)
return handleError(res, err)
}
// get a argv object from object pool
argv := reflectTypePools.Get(mtype.ArgType)
codec := share.Codecs[req.SerializeType()]
if codec == nil {
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
return handleError(res, err)
}
err = codec.Decode(req.Payload, argv)
if err != nil {
return handleError(res, err)
}
// and get a reply object from object pool
replyv := reflectTypePools.Get(mtype.ReplyType)
argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
if err != nil {
// return reply to object pool
reflectTypePools.Put(mtype.ReplyType, replyv)
return handleError(res, err)
}
if mtype.ArgType.Kind() != reflect.Ptr {
err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
} else {
err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
}
if err == nil {
replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv)
}
// return argc to object pool
reflectTypePools.Put(mtype.ArgType, argv)
if err != nil {
if replyv != nil {
data, err := codec.Encode(replyv)
// return reply to object pool
reflectTypePools.Put(mtype.ReplyType, replyv)
if err != nil {
return handleError(res, err)
}
res.Payload = data
}
return handleError(res, err)
}
if !req.IsOneway() {
data, err := codec.Encode(replyv)
// return reply to object pool
reflectTypePools.Put(mtype.ReplyType, replyv)
if err != nil {
return handleError(res, err)
}
res.Payload = data
} else if replyv != nil {
reflectTypePools.Put(mtype.ReplyType, replyv)
}
if share.Trace {
log.Debugf("server called service %+v for an request %+v", service, req)
}
return res, nil
}
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
res = req.Clone()
res.SetMessageType(protocol.Response)
serviceName := req.ServicePath
methodName := req.ServiceMethod
s.serviceMapMu.RLock()
service := s.serviceMap[serviceName]
s.serviceMapMu.RUnlock()
if service == nil {
err = errors.New("rpcx: can't find service for func raw function")
return handleError(res, err)
}
mtype := service.function[methodName]
if mtype == nil {
err = errors.New("rpcx: can't find method " + methodName)
return handleError(res, err)
}
argv := reflectTypePools.Get(mtype.ArgType)
codec := share.Codecs[req.SerializeType()]
if codec == nil {
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
return handleError(res, err)
}
err = codec.Decode(req.Payload, argv)
if err != nil {
return handleError(res, err)
}
replyv := reflectTypePools.Get(mtype.ReplyType)
if mtype.ArgType.Kind() != reflect.Ptr {
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
} else {
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
}
reflectTypePools.Put(mtype.ArgType, argv)
if err != nil {
reflectTypePools.Put(mtype.ReplyType, replyv)
return handleError(res, err)
}
if !req.IsOneway() {
data, err := codec.Encode(replyv)
reflectTypePools.Put(mtype.ReplyType, replyv)
if err != nil {
return handleError(res, err)
}
res.Payload = data
} else if replyv != nil {
reflectTypePools.Put(mtype.ReplyType, replyv)
}
return res, nil
}
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
res.SetMessageStatusType(protocol.Error)
if res.Metadata == nil {
res.Metadata = make(map[string]string)
}
res.Metadata[protocol.ServiceError] = err.Error()
return res, err
}
// Can connect to RPC service using HTTP CONNECT to rpcPath.
var connected = "200 Connected to rpcx"
// ServeHTTP implements an http.Handler that answers RPC requests.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
io.WriteString(w, "405 must CONNECT\n")
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
s.mu.Lock()
s.activeConn[conn] = struct{}{}
s.mu.Unlock()
s.serveConn(conn)
}
func (s *Server) ServeWS(conn *websocket.Conn) {
s.mu.Lock()
s.activeConn[conn] = struct{}{}
s.mu.Unlock()
conn.PayloadType = websocket.BinaryFrame
s.serveConn(conn)
}
// Close immediately closes all active net.Listeners.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err error
if s.ln != nil {
err = s.ln.Close()
}
for c := range s.activeConn {
c.Close()
delete(s.activeConn, c)
s.Plugins.DoPostConnClose(c)
}
s.closeDoneChanLocked()
return err
}
// RegisterOnShutdown registers a function to call on Shutdown.
// This can be used to gracefully shutdown connections.
func (s *Server) RegisterOnShutdown(f func(s *Server)) {
s.mu.Lock()
s.onShutdown = append(s.onShutdown, f)
s.mu.Unlock()
}
// RegisterOnRestart registers a function to call on Restart.
func (s *Server) RegisterOnRestart(f func(s *Server)) {
s.mu.Lock()
s.onRestart = append(s.onRestart, f)
s.mu.Unlock()
}
var shutdownPollInterval = 1000 * time.Millisecond
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing the
// listener, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener.
func (s *Server) Shutdown(ctx context.Context) error {
var err error
if atomic.CompareAndSwapInt32(&s.inShutdown, 0, 1) {
log.Info("shutdown begin")
s.mu.Lock()
s.ln.Close()
for conn := range s.activeConn {
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.CloseRead()
}
}
s.mu.Unlock()
// wait all in-processing requests finish.
ticker := time.NewTicker(shutdownPollInterval)
defer ticker.Stop()
outer:
for {
if s.checkProcessMsg() {
break
}
select {
case <-ctx.Done():
err = ctx.Err()
break outer
case <-ticker.C:
}
}
if s.gatewayHTTPServer != nil {
if err := s.closeHTTP1APIGateway(ctx); err != nil {
log.Warnf("failed to close gateway: %v", err)
} else {
log.Info("closed gateway")
}
}
s.mu.Lock()
for conn := range s.activeConn {
conn.Close()
delete(s.activeConn, conn)
s.Plugins.DoPostConnClose(conn)
}
s.closeDoneChanLocked()
s.mu.Unlock()
log.Info("shutdown end")
}
return err
}
// Restart restarts this server gracefully.
// It starts a new rpcx server with the same port with SO_REUSEPORT socket option,
// and shutdown this rpcx server gracefully.
func (s *Server) Restart(ctx context.Context) error {
pid, err := s.startProcess()
if err != nil {
return err
}
log.Infof("restart a new rpcx server: %d", pid)
// TODO: is it necessary?
time.Sleep(3 * time.Second)
return s.Shutdown(ctx)
}
func (s *Server) startProcess() (int, error) {
argv0, err := exec.LookPath(os.Args[0])
if err != nil {
return 0, err
}
// Pass on the environment and replace the old count key with the new one.
var env []string
env = append(env, os.Environ()...)
originalWD, _ := os.Getwd()
allFiles := []*os.File{os.Stdin, os.Stdout, os.Stderr}
process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{
Dir: originalWD,
Env: env,
Files: allFiles,
})
if err != nil {
return 0, err
}
return process.Pid, nil
}
func (s *Server) checkProcessMsg() bool {
size := atomic.LoadInt32(&s.handlerMsgNum)
log.Info("need handle in-processing msg size:", size)
return size == 0
}
func (s *Server) closeDoneChanLocked() {
select {
case <-s.doneChan:
// Already closed. Don't close again.
default:
// Safe to close here. We're the only closer, guarded
// by s.mu.RegisterName
close(s.doneChan)
}
}
var ip4Reg = regexp.MustCompile(`^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$`)
func validIP4(ipAddress string) bool {
ipAddress = strings.Trim(ipAddress, " ")
i := strings.LastIndex(ipAddress, ":")
ipAddress = ipAddress[:i] // remove port
return ip4Reg.MatchString(ipAddress)
}