mirror of
https://github.com/smallnest/rpcx.git
synced 2025-09-27 04:26:26 +08:00
1099 lines
26 KiB
Go
1099 lines
26 KiB
Go
package server
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"crypto/tls"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"os/exec"
|
||
"reflect"
|
||
"regexp"
|
||
"runtime"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/smallnest/rpcx/log"
|
||
"github.com/smallnest/rpcx/protocol"
|
||
"github.com/smallnest/rpcx/share"
|
||
"github.com/soheilhy/cmux"
|
||
"golang.org/x/net/websocket"
|
||
)
|
||
|
||
// ErrServerClosed is returned by Serve and ServeListener when the accept loop
// stops because the server is shutting down or its multiplexed (cmux)
// listener was closed.
var ErrServerClosed = errors.New("http: Server closed")

// ErrReqReachLimit signals that a request was rejected by rate limiting.
// When reading a request yields this error, serveConn answers the request
// with an error response (unless it is oneway) and keeps the connection open.
var ErrReqReachLimit = errors.New("request reached rate limit")
|
||
|
||
// ReaderBuffsize is the size, in bytes, of the bufio.Reader wrapped around
// each client connection.
const ReaderBuffsize = 1024

// WriterBuffsize is the size, in bytes, intended for bufio writers.
const WriterBuffsize = 1024
|
||
|
||
// contextKey is the key type used with context.WithValue. Keys are always
// handled through pointers so storing one in an interface{} does not
// allocate.
type contextKey struct {
	name string
}

// String describes the key, e.g. for debugging output.
func (k *contextKey) String() string {
	const prefix = "rpcx context value "
	return prefix + k.name
}

// RemoteConnContextKey is a context key through which services can reach the
// connection a request arrived on. Its value is a net.Conn.
var RemoteConnContextKey = &contextKey{name: "remote-conn"}

// StartRequestContextKey stores the time (UnixNano) at which the server
// started handling a request.
var StartRequestContextKey = &contextKey{name: "start-parse-request"}

// StartSendRequestContextKey stores the time (UnixNano) at which the server
// started sending a request to a client.
var StartSendRequestContextKey = &contextKey{name: "start-send-request"}

// TagContextKey carries extra information recorded while handling services.
// Its value is a map[string]interface{}.
var TagContextKey = &contextKey{name: "service-tag"}

// HttpConnContextKey stores the underlying HTTP connection.
var HttpConnContextKey = &contextKey{name: "http-conn"}
|
||
|
||
type Handler func(ctx *Context) error
|
||
|
||
type WorkerPool interface {
|
||
Submit(task func())
|
||
StopAndWaitFor(deadline time.Duration)
|
||
Stop() context.Context
|
||
StopAndWait()
|
||
}
|
||
|
||
// Server is an rpcx server. It accepts client connections on a net.Listener,
// decodes rpcx protocol messages from them and dispatches the requests to
// handlers registered with AddHandler or to registered services.
type Server struct {
	// ln is the listener currently being served; read under mu by Address
	// and Close.
	ln net.Listener
	// readTimeout, if non-zero, is set as a read deadline before each request
	// is read and before the TLS handshake.
	readTimeout time.Duration
	// writeTimeout, if non-zero, is set as a write deadline before each
	// response write and before the TLS handshake.
	writeTimeout time.Duration
	// gatewayHTTPServer is the HTTP gateway server; Shutdown closes it when
	// it is non-nil.
	gatewayHTTPServer *http.Server

	// jsonrpcHTTPServerLock is held by Shutdown while it closes the HTTP
	// gateway. NOTE(review): Shutdown checks jsonrpcHTTPServer outside this
	// lock — confirm which fields it is meant to guard.
	jsonrpcHTTPServerLock sync.Mutex
	// jsonrpcHTTPServer is the JSON-RPC HTTP server; Shutdown closes it when
	// it is non-nil.
	jsonrpcHTTPServer *http.Server

	DisableHTTPGateway bool // disable http invoke or not.
	DisableJSONRPC bool // disable json rpc or not.
	// AsyncWrite makes sendResponse write responses asynchronously (on pool
	// when set, otherwise on a new goroutine). Set true if your server only
	// serves few clients.
	AsyncWrite bool
	// pool, when non-nil, runs request processing and asynchronous writes
	// instead of new goroutines; Close stops it.
	pool WorkerPool

	// serviceMapMu guards serviceMap.
	serviceMapMu sync.RWMutex
	// serviceMap holds the registered services keyed by service path.
	serviceMap map[string]*service

	// router holds handlers registered with AddHandler, keyed by
	// "servicePath.serviceMethod". It is accessed without locking.
	router map[string]Handler

	// mu guards ln, activeConn, onShutdown, onRestart and the closing of
	// doneChan and Started.
	mu sync.RWMutex
	// activeConn is the set of currently open client connections.
	activeConn map[net.Conn]struct{}
	// doneChan is closed (once, via closeDoneChanLocked) by Close and at the
	// end of Shutdown; goroutines that stop during shutdown wait on it.
	doneChan chan struct{}
	// seq provides sequence numbers for requests sent to clients by
	// SendMessage.
	seq atomic.Uint64

	// inShutdown is set to 1 when Shutdown starts; accessed atomically.
	inShutdown int32
	// onShutdown holds hooks registered with RegisterOnShutdown.
	onShutdown []func(s *Server)
	// onRestart holds hooks registered with RegisterOnRestart.
	onRestart []func(s *Server)

	// TLSConfig for creating tls tcp connections.
	tlsConfig *tls.Config

	// options holds named, transport-specific settings applied by OptionFn
	// values, e.g. "TCPKeepAlivePeriod" (a time.Duration that NewServer
	// defaults to three minutes). Other transports presumably keep their
	// settings (such as a kcp BlockCrypt) here too — confirm.
	options map[string]interface{}
	// corsOptions holds CORS options, presumably for the HTTP endpoints —
	// confirm.
	corsOptions *CORSOptions

	// Plugins are invoked at the various stages of connection and request
	// handling.
	Plugins PluginContainer

	// AuthFunc, when set, authenticates every non-heartbeat request using the
	// token found in the request's share.AuthKey metadata. A non-nil error is
	// reported to the client and the connection is closed.
	AuthFunc func(ctx context.Context, req *protocol.Message, token string) error

	// handlerMsgNum counts requests currently inside processOneRequest;
	// Shutdown waits for it to reach zero. Accessed atomically.
	handlerMsgNum int32
	// requestCount is not referenced in this file; presumably a request
	// counter maintained elsewhere — confirm.
	requestCount atomic.Uint64

	// HandleServiceError is used to get all service errors. You can use it write logs or others.
	HandleServiceError func(error)

	// ServerErrorFunc is a customized error handlers and you can use it to return customized error strings to clients.
	// If not set, it use err.Error()
	ServerErrorFunc func(res *protocol.Message, err error) string

	// Started is closed once the server has stored its listener and is about
	// to accept connections.
	Started chan struct{}
	// unregisterAllOnce presumably ensures UnregisterAll runs only once —
	// confirm.
	unregisterAllOnce sync.Once
}
|
||
|
||
// NewServer returns a server.
|
||
func NewServer(options ...OptionFn) *Server {
|
||
s := &Server{
|
||
Plugins: &pluginContainer{},
|
||
options: make(map[string]interface{}),
|
||
activeConn: make(map[net.Conn]struct{}),
|
||
doneChan: make(chan struct{}),
|
||
serviceMap: make(map[string]*service),
|
||
router: make(map[string]Handler),
|
||
AsyncWrite: false, // 除非你想做进一步的优化测试,否则建议你设置为false
|
||
Started: make(chan struct{}),
|
||
}
|
||
|
||
for _, op := range options {
|
||
op(s)
|
||
}
|
||
|
||
if s.options["TCPKeepAlivePeriod"] == nil {
|
||
s.options["TCPKeepAlivePeriod"] = 3 * time.Minute
|
||
}
|
||
|
||
return s
|
||
}
|
||
|
||
// Address returns listened address.
|
||
func (s *Server) Address() net.Addr {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
if s.ln == nil {
|
||
return nil
|
||
}
|
||
return s.ln.Addr()
|
||
}
|
||
|
||
func (s *Server) AddHandler(servicePath, serviceMethod string, handler func(*Context) error) {
|
||
s.router[servicePath+"."+serviceMethod] = handler
|
||
}
|
||
|
||
// ActiveClientConn returns active connections.
|
||
func (s *Server) ActiveClientConn() []net.Conn {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
result := make([]net.Conn, 0, len(s.activeConn))
|
||
for clientConn := range s.activeConn {
|
||
result = append(result, clientConn)
|
||
}
|
||
return result
|
||
}
|
||
|
||
// SendMessage a request to the specified client.
|
||
// The client is designated by the conn.
|
||
// conn can be gotten from context in services:
|
||
//
|
||
// ctx.Value(RemoteConnContextKey)
|
||
//
|
||
// servicePath, serviceMethod, metadata can be set to zero values.
|
||
func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
|
||
ctx := share.WithValue(context.Background(), StartSendRequestContextKey, time.Now().UnixNano())
|
||
s.Plugins.DoPreWriteRequest(ctx)
|
||
|
||
req := protocol.NewMessage()
|
||
req.SetMessageType(protocol.Request)
|
||
|
||
seq := s.seq.Add(1)
|
||
req.SetSeq(seq)
|
||
req.SetOneway(true)
|
||
req.SetSerializeType(protocol.SerializeNone)
|
||
req.ServicePath = servicePath
|
||
req.ServiceMethod = serviceMethod
|
||
req.Metadata = metadata
|
||
req.Payload = data
|
||
|
||
b := req.EncodeSlicePointer()
|
||
_, err := conn.Write(*b)
|
||
protocol.PutData(b)
|
||
|
||
s.Plugins.DoPostWriteRequest(ctx, req, err)
|
||
|
||
return err
|
||
}
|
||
|
||
func (s *Server) getDoneChan() <-chan struct{} {
|
||
return s.doneChan
|
||
}
|
||
|
||
// Serve starts and listens RPC requests.
|
||
// It is blocked until receiving connections from clients.
|
||
func (s *Server) Serve(network, address string) (err error) {
|
||
var ln net.Listener
|
||
ln, err = s.makeListener(network, address)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
defer s.UnregisterAll()
|
||
|
||
if network == "http" {
|
||
s.serveByHTTP(ln, "")
|
||
return nil
|
||
}
|
||
|
||
if network == "ws" || network == "wss" {
|
||
s.serveByWS(ln, "")
|
||
return nil
|
||
}
|
||
|
||
// try to start gateway
|
||
ln = s.startGateway(network, ln)
|
||
|
||
return s.serveListener(ln)
|
||
}
|
||
|
||
// ServeListener listens RPC requests.
|
||
// It is blocked until receiving connections from clients.
|
||
func (s *Server) ServeListener(network string, ln net.Listener) (err error) {
|
||
defer s.UnregisterAll()
|
||
|
||
if network == "http" {
|
||
s.serveByHTTP(ln, "")
|
||
return nil
|
||
}
|
||
|
||
// try to start gateway
|
||
ln = s.startGateway(network, ln)
|
||
|
||
return s.serveListener(ln)
|
||
}
|
||
|
||
// serveListener accepts incoming connections on the Listener ln,
|
||
// creating a new service goroutine for each.
|
||
// The service goroutines read requests and then call services to reply to them.
|
||
func (s *Server) serveListener(ln net.Listener) error {
|
||
var tempDelay time.Duration
|
||
|
||
s.mu.Lock()
|
||
s.ln = ln
|
||
close(s.Started)
|
||
s.mu.Unlock()
|
||
|
||
for {
|
||
conn, e := ln.Accept()
|
||
if e != nil {
|
||
if s.isShutdown() {
|
||
<-s.doneChan
|
||
return ErrServerClosed
|
||
}
|
||
|
||
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||
if tempDelay == 0 {
|
||
tempDelay = 5 * time.Millisecond
|
||
} else {
|
||
tempDelay *= 2
|
||
}
|
||
|
||
if max := 1 * time.Second; tempDelay > max {
|
||
tempDelay = max
|
||
}
|
||
|
||
log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
|
||
time.Sleep(tempDelay)
|
||
continue
|
||
}
|
||
|
||
if errors.Is(e, cmux.ErrListenerClosed) {
|
||
return ErrServerClosed
|
||
}
|
||
return e
|
||
}
|
||
tempDelay = 0
|
||
|
||
if tc, ok := conn.(*net.TCPConn); ok {
|
||
period := s.options["TCPKeepAlivePeriod"]
|
||
if period != nil {
|
||
tc.SetKeepAlive(true)
|
||
tc.SetKeepAlivePeriod(period.(time.Duration))
|
||
tc.SetLinger(10)
|
||
}
|
||
}
|
||
|
||
conn, ok := s.Plugins.DoPostConnAccept(conn)
|
||
if !ok {
|
||
conn.Close()
|
||
continue
|
||
}
|
||
|
||
s.mu.Lock()
|
||
s.activeConn[conn] = struct{}{}
|
||
s.mu.Unlock()
|
||
|
||
if share.Trace {
|
||
log.Debugf("server accepted an conn: %v", conn.RemoteAddr().String())
|
||
}
|
||
|
||
go s.serveConn(conn)
|
||
}
|
||
}
|
||
|
||
// serveByHTTP serves by HTTP.
|
||
// if rpcPath is an empty string, use share.DefaultRPCPath.
|
||
func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) {
|
||
s.ln = ln
|
||
|
||
if rpcPath == "" {
|
||
rpcPath = share.DefaultRPCPath
|
||
}
|
||
mux := http.NewServeMux()
|
||
mux.Handle(rpcPath, s)
|
||
srv := &http.Server{Handler: mux}
|
||
|
||
srv.Serve(ln)
|
||
}
|
||
|
||
func (s *Server) serveByWS(ln net.Listener, rpcPath string) {
|
||
s.ln = ln
|
||
|
||
if rpcPath == "" {
|
||
rpcPath = share.DefaultRPCPath
|
||
}
|
||
mux := http.NewServeMux()
|
||
mux.Handle(rpcPath, websocket.Handler(s.ServeWS))
|
||
srv := &http.Server{Handler: mux}
|
||
|
||
srv.Serve(ln)
|
||
}
|
||
|
||
func (s *Server) sendResponse(ctx *share.Context, conn net.Conn, err error, req, res *protocol.Message) {
|
||
if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
|
||
res.SetCompressType(req.CompressType())
|
||
}
|
||
|
||
s.Plugins.DoPreWriteResponse(ctx, req, res, err)
|
||
|
||
data := res.EncodeSlicePointer()
|
||
if s.AsyncWrite {
|
||
if s.pool != nil {
|
||
s.pool.Submit(func() {
|
||
if s.writeTimeout != 0 {
|
||
conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
|
||
}
|
||
conn.Write(*data)
|
||
protocol.PutData(data)
|
||
})
|
||
} else {
|
||
go func() {
|
||
if s.writeTimeout != 0 {
|
||
conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
|
||
}
|
||
conn.Write(*data)
|
||
protocol.PutData(data)
|
||
}()
|
||
}
|
||
} else {
|
||
if s.writeTimeout != 0 {
|
||
conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
|
||
}
|
||
conn.Write(*data)
|
||
protocol.PutData(data)
|
||
}
|
||
s.Plugins.DoPostWriteResponse(ctx, req, res, err)
|
||
}
|
||
|
||
func (s *Server) serveConn(conn net.Conn) {
|
||
if s.isShutdown() {
|
||
s.closeConn(conn)
|
||
return
|
||
}
|
||
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
const size = 64 << 10
|
||
buf := make([]byte, size)
|
||
ss := runtime.Stack(buf, false)
|
||
if ss > size {
|
||
ss = size
|
||
}
|
||
buf = buf[:ss]
|
||
log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
|
||
}
|
||
|
||
if share.Trace {
|
||
log.Debugf("server closed conn: %v", conn.RemoteAddr().String())
|
||
}
|
||
|
||
// make sure all inflight requests are handled and all drained
|
||
if s.isShutdown() {
|
||
<-s.doneChan
|
||
}
|
||
|
||
s.closeConn(conn)
|
||
}()
|
||
|
||
if tlsConn, ok := conn.(*tls.Conn); ok {
|
||
if d := s.readTimeout; d != 0 {
|
||
conn.SetReadDeadline(time.Now().Add(d))
|
||
}
|
||
if d := s.writeTimeout; d != 0 {
|
||
conn.SetWriteDeadline(time.Now().Add(d))
|
||
}
|
||
if err := tlsConn.Handshake(); err != nil {
|
||
log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
|
||
return
|
||
}
|
||
}
|
||
|
||
r := bufio.NewReaderSize(conn, ReaderBuffsize)
|
||
|
||
// read requests and handle it
|
||
for {
|
||
if s.isShutdown() {
|
||
return
|
||
}
|
||
|
||
t0 := time.Now()
|
||
if s.readTimeout != 0 {
|
||
conn.SetReadDeadline(t0.Add(s.readTimeout))
|
||
}
|
||
|
||
// create a rpcx Context
|
||
ctx := share.WithValue(context.Background(), RemoteConnContextKey, conn)
|
||
|
||
// read a request from the underlying connection
|
||
req, err := s.readRequest(ctx, r)
|
||
if err != nil {
|
||
if errors.Is(err, io.EOF) {
|
||
log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
|
||
} else if errors.Is(err, net.ErrClosed) {
|
||
log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
|
||
} else if errors.Is(err, ErrReqReachLimit) {
|
||
if !req.IsOneway() { // return a error response
|
||
res := req.Clone()
|
||
res.SetMessageType(protocol.Response)
|
||
|
||
s.handleError(res, err)
|
||
s.sendResponse(ctx, conn, err, req, res)
|
||
} else { // Oneway and only call the plugins
|
||
s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
|
||
}
|
||
continue
|
||
} else { // wrong data
|
||
log.Warnf("rpcx: failed to read request: %v", err)
|
||
}
|
||
|
||
if s.HandleServiceError != nil {
|
||
s.HandleServiceError(err)
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
if share.Trace {
|
||
log.Debugf("server received an request %+v from conn: %v", req, conn.RemoteAddr().String())
|
||
}
|
||
|
||
ctx = share.WithLocalValue(ctx, StartRequestContextKey, time.Now().UnixNano())
|
||
closeConn := false
|
||
if !req.IsHeartbeat() {
|
||
err = s.auth(ctx, req)
|
||
closeConn = err != nil
|
||
}
|
||
|
||
if err != nil {
|
||
if !req.IsOneway() { // return a error response
|
||
res := req.Clone()
|
||
res.SetMessageType(protocol.Response)
|
||
s.handleError(res, err)
|
||
s.sendResponse(ctx, conn, err, req, res)
|
||
} else {
|
||
s.Plugins.DoPreWriteResponse(ctx, req, nil, err)
|
||
}
|
||
|
||
if s.HandleServiceError != nil {
|
||
s.HandleServiceError(err)
|
||
}
|
||
|
||
// auth failed, closed the connection
|
||
if closeConn {
|
||
log.Infof("auth failed for conn %s: %v", conn.RemoteAddr().String(), err)
|
||
return
|
||
}
|
||
continue
|
||
}
|
||
|
||
if s.pool != nil {
|
||
s.pool.Submit(func() {
|
||
s.processOneRequest(ctx, req, conn)
|
||
})
|
||
} else {
|
||
go s.processOneRequest(ctx, req, conn)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (s *Server) processOneRequest(ctx *share.Context, req *protocol.Message, conn net.Conn) {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
buf := make([]byte, 1024)
|
||
buf = buf[:runtime.Stack(buf, true)]
|
||
if s.HandleServiceError != nil {
|
||
s.HandleServiceError(fmt.Errorf("%v", r))
|
||
} else {
|
||
log.Errorf("[handler internal error]: servicepath: %s, servicemethod: %s, err: %v,stacks: %s", req.ServicePath, req.ServiceMethod, r, string(buf))
|
||
}
|
||
sctx := NewContext(ctx, conn, req, s.AsyncWrite)
|
||
sctx.WriteError(fmt.Errorf("%v", r))
|
||
}
|
||
}()
|
||
|
||
atomic.AddInt32(&s.handlerMsgNum, 1)
|
||
defer atomic.AddInt32(&s.handlerMsgNum, -1)
|
||
|
||
// 心跳请求,直接处理返回
|
||
if req.IsHeartbeat() {
|
||
s.Plugins.DoHeartbeatRequest(ctx, req)
|
||
req.SetMessageType(protocol.Response)
|
||
data := req.EncodeSlicePointer()
|
||
|
||
if s.writeTimeout != 0 {
|
||
conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
|
||
}
|
||
conn.Write(*data)
|
||
|
||
protocol.PutData(data)
|
||
|
||
return
|
||
}
|
||
|
||
cancelFunc := parseServerTimeout(ctx, req)
|
||
if cancelFunc != nil {
|
||
defer cancelFunc()
|
||
}
|
||
|
||
resMetadata := make(map[string]string)
|
||
if req.Metadata == nil {
|
||
req.Metadata = make(map[string]string)
|
||
}
|
||
ctx = share.WithLocalValue(share.WithLocalValue(ctx, share.ReqMetaDataKey, req.Metadata),
|
||
share.ResMetaDataKey, resMetadata)
|
||
|
||
s.Plugins.DoPreHandleRequest(ctx, req)
|
||
|
||
if share.Trace {
|
||
log.Debugf("server handle request %+v from conn: %v", req, conn.RemoteAddr().String())
|
||
}
|
||
|
||
// use handlers first
|
||
if handler, ok := s.router[req.ServicePath+"."+req.ServiceMethod]; ok {
|
||
sctx := NewContext(ctx, conn, req, s.AsyncWrite)
|
||
err := handler(sctx)
|
||
if err != nil {
|
||
if s.HandleServiceError != nil {
|
||
s.HandleServiceError(err)
|
||
} else {
|
||
log.Errorf("[handler internal error]: servicepath: %s, servicemethod, err: %v", req.ServicePath, req.ServiceMethod, err)
|
||
}
|
||
sctx.WriteError(err)
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
res, err := s.handleRequest(ctx, req)
|
||
if err != nil {
|
||
if s.HandleServiceError != nil {
|
||
s.HandleServiceError(err)
|
||
} else {
|
||
log.Warnf("rpcx: failed to handle request: %v", err)
|
||
}
|
||
}
|
||
|
||
if !req.IsOneway() {
|
||
if len(resMetadata) > 0 { // copy meta in context to responses
|
||
meta := res.Metadata
|
||
if meta == nil {
|
||
res.Metadata = resMetadata
|
||
} else {
|
||
for k, v := range resMetadata {
|
||
if meta[k] == "" {
|
||
meta[k] = v
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
s.sendResponse(ctx, conn, err, req, res)
|
||
}
|
||
|
||
if share.Trace {
|
||
log.Debugf("server write response %+v for an request %+v from conn: %v", res, req, conn.RemoteAddr().String())
|
||
}
|
||
}
|
||
|
||
func parseServerTimeout(ctx *share.Context, req *protocol.Message) context.CancelFunc {
|
||
if req == nil || req.Metadata == nil {
|
||
return nil
|
||
}
|
||
|
||
st := req.Metadata[share.ServerTimeout]
|
||
if st == "" {
|
||
return nil
|
||
}
|
||
|
||
timeout, err := strconv.ParseInt(st, 10, 64)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
|
||
newCtx, cancel := context.WithTimeout(ctx.Context, time.Duration(timeout)*time.Millisecond)
|
||
ctx.Context = newCtx
|
||
return cancel
|
||
}
|
||
|
||
func (s *Server) isShutdown() bool {
|
||
return atomic.LoadInt32(&s.inShutdown) == 1
|
||
}
|
||
|
||
func (s *Server) closeConn(conn net.Conn) {
|
||
s.mu.Lock()
|
||
delete(s.activeConn, conn)
|
||
s.mu.Unlock()
|
||
|
||
conn.Close()
|
||
|
||
s.Plugins.DoPostConnClose(conn)
|
||
}
|
||
|
||
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
|
||
err = s.Plugins.DoPreReadRequest(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// pool req?
|
||
req = protocol.NewMessage()
|
||
err = req.Decode(r)
|
||
if err == io.EOF {
|
||
return req, err
|
||
}
|
||
perr := s.Plugins.DoPostReadRequest(ctx, req, err)
|
||
if err == nil {
|
||
err = perr
|
||
}
|
||
return req, err
|
||
}
|
||
|
||
func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
|
||
if s.AuthFunc != nil {
|
||
token := req.Metadata[share.AuthKey]
|
||
return s.AuthFunc(ctx, req, token)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||
serviceName := req.ServicePath
|
||
methodName := req.ServiceMethod
|
||
|
||
res = req.Clone()
|
||
|
||
res.SetMessageType(protocol.Response)
|
||
s.serviceMapMu.RLock()
|
||
service := s.serviceMap[serviceName]
|
||
|
||
if share.Trace {
|
||
log.Debugf("server get service %+v for an request %+v", service, req)
|
||
}
|
||
|
||
s.serviceMapMu.RUnlock()
|
||
if service == nil {
|
||
err = errors.New("rpcx: can't find service " + serviceName)
|
||
return s.handleError(res, err)
|
||
}
|
||
mtype := service.method[methodName]
|
||
if mtype == nil {
|
||
if service.function[methodName] != nil { // check raw functions
|
||
return s.handleRequestForFunction(ctx, req)
|
||
}
|
||
err = errors.New("rpcx: can't find method " + methodName)
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
// get a argv object from object pool
|
||
argv := reflectTypePools.Get(mtype.ArgType)
|
||
|
||
codec := share.Codecs[req.SerializeType()]
|
||
if codec == nil {
|
||
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
err = codec.Decode(req.Payload, argv)
|
||
if err != nil {
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
// and get a reply object from object pool
|
||
replyv := reflectTypePools.Get(mtype.ReplyType)
|
||
|
||
argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
|
||
if err != nil {
|
||
// return reply to object pool
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
if mtype.ArgType.Kind() != reflect.Ptr {
|
||
err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
|
||
} else {
|
||
err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
|
||
}
|
||
|
||
replyv, err1 := s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err)
|
||
if err == nil {
|
||
err = err1
|
||
}
|
||
|
||
// return argc to object pool
|
||
reflectTypePools.Put(mtype.ArgType, argv)
|
||
|
||
if err != nil {
|
||
if replyv != nil {
|
||
data, err := codec.Encode(replyv)
|
||
// return reply to object pool
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
if err != nil {
|
||
return s.handleError(res, err)
|
||
}
|
||
res.Payload = data
|
||
}
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
if !req.IsOneway() {
|
||
data, err := codec.Encode(replyv)
|
||
// return reply to object pool
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
if err != nil {
|
||
return s.handleError(res, err)
|
||
}
|
||
res.Payload = data
|
||
} else if replyv != nil {
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
}
|
||
|
||
if share.Trace {
|
||
log.Debugf("server called service %+v for an request %+v", service, req)
|
||
}
|
||
|
||
return res, nil
|
||
}
|
||
|
||
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
||
res = req.Clone()
|
||
|
||
res.SetMessageType(protocol.Response)
|
||
|
||
serviceName := req.ServicePath
|
||
methodName := req.ServiceMethod
|
||
s.serviceMapMu.RLock()
|
||
service := s.serviceMap[serviceName]
|
||
s.serviceMapMu.RUnlock()
|
||
if service == nil {
|
||
err = errors.New("rpcx: can't find service for func raw function")
|
||
return s.handleError(res, err)
|
||
}
|
||
mtype := service.function[methodName]
|
||
if mtype == nil {
|
||
err = errors.New("rpcx: can't find method " + methodName)
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
argv := reflectTypePools.Get(mtype.ArgType)
|
||
|
||
codec := share.Codecs[req.SerializeType()]
|
||
if codec == nil {
|
||
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
err = codec.Decode(req.Payload, argv)
|
||
if err != nil {
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
replyv := reflectTypePools.Get(mtype.ReplyType)
|
||
argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv)
|
||
if err != nil {
|
||
// return reply to object pool
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
if mtype.ArgType.Kind() != reflect.Ptr {
|
||
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
|
||
} else {
|
||
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
|
||
}
|
||
|
||
replyv, err1 := s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err)
|
||
if err == nil {
|
||
err = err1
|
||
}
|
||
|
||
reflectTypePools.Put(mtype.ArgType, argv)
|
||
if err != nil {
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
return s.handleError(res, err)
|
||
}
|
||
|
||
if !req.IsOneway() {
|
||
data, err := codec.Encode(replyv)
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
if err != nil {
|
||
return s.handleError(res, err)
|
||
}
|
||
res.Payload = data
|
||
} else if replyv != nil {
|
||
reflectTypePools.Put(mtype.ReplyType, replyv)
|
||
}
|
||
|
||
return res, nil
|
||
}
|
||
|
||
func (s *Server) handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
||
res.SetMessageStatusType(protocol.Error)
|
||
if res.Metadata == nil {
|
||
res.Metadata = make(map[string]string)
|
||
}
|
||
|
||
if s.ServerErrorFunc != nil {
|
||
res.Metadata[protocol.ServiceError] = s.ServerErrorFunc(res, err)
|
||
} else {
|
||
res.Metadata[protocol.ServiceError] = err.Error()
|
||
}
|
||
|
||
return res, err
|
||
}
|
||
|
||
// Can connect to RPC service using HTTP CONNECT to rpcPath.
|
||
var connected = "200 Connected to rpcx"
|
||
|
||
// ServeHTTP implements an http.Handler that answers RPC requests.
|
||
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||
if req.Method != http.MethodConnect {
|
||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
io.WriteString(w, "405 must CONNECT\n")
|
||
return
|
||
}
|
||
conn, _, err := w.(http.Hijacker).Hijack()
|
||
if err != nil {
|
||
log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
|
||
return
|
||
}
|
||
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
|
||
|
||
s.mu.Lock()
|
||
s.activeConn[conn] = struct{}{}
|
||
s.mu.Unlock()
|
||
|
||
s.serveConn(conn)
|
||
}
|
||
|
||
func (s *Server) ServeWS(conn *websocket.Conn) {
|
||
s.mu.Lock()
|
||
s.activeConn[conn] = struct{}{}
|
||
s.mu.Unlock()
|
||
|
||
conn.PayloadType = websocket.BinaryFrame
|
||
s.serveConn(conn)
|
||
}
|
||
|
||
// Close immediately closes all active net.Listeners.
|
||
func (s *Server) Close() error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
var err error
|
||
if s.ln != nil {
|
||
err = s.ln.Close()
|
||
}
|
||
for c := range s.activeConn {
|
||
c.Close()
|
||
delete(s.activeConn, c)
|
||
s.Plugins.DoPostConnClose(c)
|
||
}
|
||
s.closeDoneChanLocked()
|
||
|
||
if s.pool != nil {
|
||
s.pool.StopAndWaitFor(10 * time.Second)
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// RegisterOnShutdown registers a function to call on Shutdown.
|
||
// This can be used to gracefully shutdown connections.
|
||
func (s *Server) RegisterOnShutdown(f func(s *Server)) {
|
||
s.mu.Lock()
|
||
s.onShutdown = append(s.onShutdown, f)
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
// RegisterOnRestart registers a function to call on Restart.
|
||
func (s *Server) RegisterOnRestart(f func(s *Server)) {
|
||
s.mu.Lock()
|
||
s.onRestart = append(s.onRestart, f)
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
var shutdownPollInterval = 1000 * time.Millisecond
|
||
|
||
// Shutdown gracefully shuts down the server without interrupting any
|
||
// active connections. Shutdown works by first closing the
|
||
// listener, then closing all idle connections, and then waiting
|
||
// indefinitely for connections to return to idle and then shut down.
|
||
// If the provided context expires before the shutdown is complete,
|
||
// Shutdown returns the context's error, otherwise it returns any
|
||
// error returned from closing the Server's underlying Listener.
|
||
func (s *Server) Shutdown(ctx context.Context) error {
|
||
var err error
|
||
if atomic.CompareAndSwapInt32(&s.inShutdown, 0, 1) {
|
||
log.Info("shutdown begin")
|
||
|
||
s.mu.Lock()
|
||
|
||
// 主动注销注册的服务
|
||
s.UnregisterAll()
|
||
|
||
if s.ln != nil {
|
||
s.ln.Close()
|
||
}
|
||
for conn := range s.activeConn {
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||
tcpConn.CloseRead()
|
||
}
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
// wait all in-processing requests finish.
|
||
ticker := time.NewTicker(shutdownPollInterval)
|
||
defer ticker.Stop()
|
||
outer:
|
||
for {
|
||
if s.checkProcessMsg() {
|
||
break
|
||
}
|
||
select {
|
||
case <-ctx.Done():
|
||
err = ctx.Err()
|
||
break outer
|
||
case <-ticker.C:
|
||
}
|
||
}
|
||
|
||
s.jsonrpcHTTPServerLock.Lock()
|
||
if s.gatewayHTTPServer != nil {
|
||
if err := s.closeHTTP1APIGateway(ctx); err != nil {
|
||
log.Warnf("failed to close gateway: %v", err)
|
||
} else {
|
||
log.Info("closed gateway")
|
||
}
|
||
}
|
||
s.jsonrpcHTTPServerLock.Unlock()
|
||
|
||
if s.jsonrpcHTTPServer != nil {
|
||
if err := s.closeJSONRPC2(ctx); err != nil {
|
||
log.Warnf("failed to close JSONRPC: %v", err)
|
||
} else {
|
||
log.Info("closed JSONRPC")
|
||
}
|
||
}
|
||
|
||
s.mu.Lock()
|
||
for conn := range s.activeConn {
|
||
conn.Close()
|
||
delete(s.activeConn, conn)
|
||
s.Plugins.DoPostConnClose(conn)
|
||
}
|
||
s.closeDoneChanLocked()
|
||
|
||
s.mu.Unlock()
|
||
|
||
log.Info("shutdown end")
|
||
|
||
}
|
||
return err
|
||
}
|
||
|
||
// Restart restarts this server gracefully.
|
||
// It starts a new rpcx server with the same port with SO_REUSEPORT socket option,
|
||
// and shutdown this rpcx server gracefully.
|
||
func (s *Server) Restart(ctx context.Context) error {
|
||
pid, err := s.startProcess()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
log.Infof("restart a new rpcx server: %d", pid)
|
||
|
||
// TODO: is it necessary?
|
||
time.Sleep(3 * time.Second)
|
||
return s.Shutdown(ctx)
|
||
}
|
||
|
||
func (s *Server) startProcess() (int, error) {
|
||
argv0, err := exec.LookPath(os.Args[0])
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
// Pass on the environment and replace the old count key with the new one.
|
||
var env []string
|
||
env = append(env, os.Environ()...)
|
||
|
||
originalWD, _ := os.Getwd()
|
||
allFiles := []*os.File{os.Stdin, os.Stdout, os.Stderr}
|
||
process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{
|
||
Dir: originalWD,
|
||
Env: env,
|
||
Files: allFiles,
|
||
})
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return process.Pid, nil
|
||
}
|
||
|
||
func (s *Server) checkProcessMsg() bool {
|
||
size := atomic.LoadInt32(&s.handlerMsgNum)
|
||
log.Info("need handle in-processing msg size:", size)
|
||
return size == 0
|
||
}
|
||
|
||
func (s *Server) closeDoneChanLocked() {
|
||
select {
|
||
case <-s.doneChan:
|
||
// Already closed. Don't close again.
|
||
default:
|
||
// Safe to close here. We're the only closer, guarded
|
||
// by s.mu.RegisterName
|
||
close(s.doneChan)
|
||
}
|
||
}
|
||
|
||
// ip4Reg matches a dotted-decimal IPv4 address whose octets are 0-255
// written without leading zeros.
var ip4Reg = regexp.MustCompile(`^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$`)

// trimTrailingPort trims surrounding spaces from addr and drops everything
// from its last ':' (the port) onward. An address without ':' is returned
// unchanged instead of being sliced with index -1, which would panic.
func trimTrailingPort(addr string) string {
	addr = strings.Trim(addr, " ")
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		addr = addr[:i] // remove port
	}
	return addr
}

// validIP4 reports whether ipAddress, given as "host:port" or a bare host,
// has a dotted-decimal IPv4 host.
func validIP4(ipAddress string) bool {
	return ip4Reg.MatchString(trimTrailingPort(ipAddress))
}

// validIP6 reports whether ipAddress, given as "[host]:port", has an IPv6
// host (IPv4 and IPv4-mapped addresses are rejected).
func validIP6(ipAddress string) bool {
	host := trimTrailingPort(ipAddress)
	host = strings.TrimPrefix(host, "[")
	host = strings.TrimSuffix(host, "]")

	ip := net.ParseIP(host)
	return ip != nil && ip.To4() == nil
}
|