Files
rpcx/server/server.go
goodjava@qq.com d2ad510535 [+]优雅关机
2018-06-29 10:15:20 +08:00

707 lines
16 KiB
Go

package server
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
"os"
"os/signal"
"syscall"
)
// ErrServerClosed is the error Serve reports (via serveListener) once the
// server has been stopped by Close or Shutdown.
var ErrServerClosed = errors.New("http: Server closed")

// Sizes, in bytes, of the buffered I/O wrappers placed around connections.
const (
	// ReaderBuffsize is the buffer size of the bufio.Reader used to read requests.
	ReaderBuffsize = 1 << 10
	// WriterBuffsize is the buffer size intended for bufio writers.
	WriterBuffsize = 1 << 10
)
// contextKey is the key type this package uses with context.WithValue.
// Keys are handed out as pointers, so storing one in an interface{} does not
// allocate and each key is distinct from every other key.
type contextKey struct {
	name string
}

// String describes the key for debugging output.
func (k *contextKey) String() string {
	const prefix = "rpcx context value "
	return prefix + k.name
}

// Context keys that the server attaches to request contexts.
var (
	// RemoteConnContextKey lets services look up the connection a request
	// arrived on; the associated value is of type net.Conn.
	RemoteConnContextKey = &contextKey{name: "remote-conn"}
	// StartRequestContextKey records when request handling started
	// (time.Now().UnixNano()).
	StartRequestContextKey = &contextKey{name: "start-parse-request"}
	// StartSendRequestContextKey records when SendMessage started
	// (time.Now().UnixNano()).
	StartSendRequestContextKey = &contextKey{name: "start-send-request"}
)
// Server is an rpcx server. Serve creates the listener (see makeListener for
// the supported networks) and dispatches decoded requests to registered
// services.
type Server struct {
	// ln is the listener being served; guarded by mu.
	ln net.Listener
	// readTimeout / writeTimeout are per-request connection deadlines applied in
	// serveConn; zero disables them.
	readTimeout  time.Duration
	writeTimeout time.Duration
	// serviceMapMu guards serviceMap.
	serviceMapMu sync.RWMutex
	// serviceMap holds registered services keyed by service path
	// (looked up with req.ServicePath).
	serviceMap map[string]*service
	// mu guards ln, activeConn, doneChan and onShutdown.
	mu sync.RWMutex
	// activeConn is the set of connections currently tracked by the server.
	activeConn map[net.Conn]struct{}
	// doneChan is closed by Close so serveListener can return ErrServerClosed.
	doneChan chan struct{}
	// seq is the sequence number of the last message sent by SendMessage;
	// updated atomically.
	seq uint64
	// inShutdown is set to 1 (atomically) once Shutdown has begun.
	inShutdown int32
	// onShutdown holds the functions registered with RegisterOnShutdown.
	onShutdown []func()
	// TLSConfig for creating tls tcp connection.
	tlsConfig *tls.Config
	// options holds additional named settings, presumably set by OptionFn
	// values (e.g. a KCP BlockCrypt) — confirm in the options code.
	options map[string]interface{}
	// // use for KCP
	// KCPConfig KCPConfig
	// // for QUIC
	// QUICConfig QUICConfig
	// Plugins receives the server's lifecycle hooks (connection accept/close,
	// request read, response write).
	Plugins PluginContainer
	// AuthFunc, if non-nil, is called for every request with the token found in
	// req.Metadata[share.AuthKey]; a non-nil error rejects the request.
	AuthFunc func(ctx context.Context, req *protocol.Message, token string) error
	// ShutdownFunc, if non-nil, is called when the process receives SIGTERM,
	// right before the process exits with status 0.
	ShutdownFunc func(s *Server)
	// HandleMsgChan counts in-flight requests: one value is sent per request
	// accepted for processing and received when it finishes. Shutdown waits for
	// it to drain. Senders block once its buffer (set by NewServer) is full.
	HandleMsgChan chan struct{}
}
// NewServer creates a Server with an empty plugin container and option map,
// then applies each OptionFn in order.
func NewServer(options ...OptionFn) *Server {
	srv := new(Server)
	srv.Plugins = &pluginContainer{}
	srv.options = make(map[string]interface{})
	// In-flight request counter; sending blocks after 100000 pending requests.
	srv.HandleMsgChan = make(chan struct{}, 100000)

	for i := range options {
		options[i](srv)
	}
	return srv
}
// Address reports the address of the listener the server is serving on, or
// nil when no listener has been installed yet.
func (s *Server) Address() net.Addr {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if ln := s.ln; ln != nil {
		return ln.Addr()
	}
	return nil
}
// SendMessage pushes a one-way request to the client connected on conn.
// conn can be obtained inside a service from the request context:
//
//	ctx.Value(RemoteConnContextKey)
//
// servicePath, serviceMethod and metadata may be zero values. The message is
// marked protocol.SerializeNone and carries data as its payload. The
// PreWriteRequest/PostWriteRequest plugin hooks run around the write, and the
// write error is returned.
func (s *Server) SendMessage(conn net.Conn, servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
	startedAt := time.Now().UnixNano()
	ctx := context.WithValue(context.Background(), StartSendRequestContextKey, startedAt)
	s.Plugins.DoPreWriteRequest(ctx)

	msg := protocol.GetPooledMsg()
	msg.SetMessageType(protocol.Request)
	msg.SetSeq(atomic.AddUint64(&s.seq, 1))
	msg.SetOneway(true)
	msg.SetSerializeType(protocol.SerializeNone)
	msg.ServicePath, msg.ServiceMethod = servicePath, serviceMethod
	msg.Metadata, msg.Payload = metadata, data

	_, err := conn.Write(msg.Encode())
	s.Plugins.DoPostWriteRequest(ctx, msg, err)
	protocol.FreeMsg(msg)
	return err
}
// getDoneChan returns the server's done channel, creating it on first use.
// It takes s.mu itself.
func (s *Server) getDoneChan() <-chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()

	ch := s.doneChan
	if ch == nil {
		ch = make(chan struct{})
		s.doneChan = ch
	}
	return ch
}
// startShutdownListener starts a goroutine that logs the process id and then
// waits for SIGTERM. When the signal arrives it calls s.ShutdownFunc, if set,
// and then exits the process with status 0.
//
// NOTE(review): os.Exit does not run deferred functions. Without a
// ShutdownFunc the process exits without waiting for in-flight requests.
func (s *Server) startShutdownListener() {
	go func(s *Server) {
		log.Info("server pid:", os.Getpid())
		c := make(chan os.Signal, 1)
		signal.Notify(c, syscall.SIGTERM)
		si := <-c
		// Compare the signal value itself rather than its human-readable
		// String() description ("terminated").
		if si == syscall.SIGTERM {
			if s.ShutdownFunc != nil {
				s.ShutdownFunc(s)
			}
			os.Exit(0)
		}
	}(s)
}
// Serve creates a listener for network/address and serves RPC requests on it,
// blocking until serving stops.
//
// For network "http", requests are served through HTTP CONNECT and Serve
// returns nil once the HTTP server stops. Otherwise the listener may be wrapped
// by the gateway and Serve returns serveListener's error (ErrServerClosed after
// Close/Shutdown).
//
// The SIGTERM listener is installed only after the listener was created. A
// failing Serve call (e.g. port already in use, retried by the caller) therefore
// no longer leaves a goroutine and a process-wide signal handler behind.
func (s *Server) Serve(network, address string) (err error) {
	var ln net.Listener
	ln, err = s.makeListener(network, address)
	if err != nil {
		return err
	}
	s.startShutdownListener()

	if network == "http" {
		s.serveByHTTP(ln, "")
		return nil
	}

	// try to start gateway
	ln = s.startGateway(network, ln)
	return s.serveListener(ln)
}
// serveListener accepts incoming connections on ln and serves each one on its
// own goroutine (serveConn), which reads requests and calls services to reply.
//
// Temporary Accept errors are retried with exponential backoff starting at
// 5ms and capped at one second. It returns ErrServerClosed once the server's
// done channel has been closed. Otherwise it returns the first non-temporary
// Accept error.
func (s *Server) serveListener(ln net.Listener) error {
	if s.Plugins == nil {
		s.Plugins = &pluginContainer{}
	}

	var tempDelay time.Duration

	s.mu.Lock()
	s.ln = ln
	if s.activeConn == nil {
		s.activeConn = make(map[net.Conn]struct{})
	}
	s.mu.Unlock()

	for {
		conn, e := ln.Accept()
		if e != nil {
			select {
			case <-s.getDoneChan():
				return ErrServerClosed
			default:
			}

			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return e
		}
		tempDelay = 0

		if tc, ok := conn.(*net.TCPConn); ok {
			tc.SetKeepAlive(true)
			tc.SetKeepAlivePeriod(3 * time.Minute)
			tc.SetLinger(10)
		}

		// Run the accept plugins before tracking the connection. Tracking first
		// left rejected connections in activeConn forever. It also recorded the
		// raw conn as the key even when a plugin wrapped it, while serveConn
		// removes the wrapped conn on exit.
		conn, ok := s.Plugins.DoPostConnAccept(conn)
		if !ok {
			// NOTE(review): closing a rejected conn is assumed to be done by the
			// plugin container; confirm in PluginContainer.DoPostConnAccept.
			continue
		}

		s.mu.Lock()
		s.activeConn[conn] = struct{}{}
		s.mu.Unlock()

		go s.serveConn(conn)
	}
}
// serveByHTTP serves RPC over HTTP CONNECT on ln, registering s on
// http.DefaultServeMux at rpcPath. If rpcPath is empty, share.DefaultRPCPath
// is used. It blocks until the HTTP server stops; that error is discarded.
//
// NOTE(review): http.Handle panics if rpcPath is already registered on the
// default mux, so this can only be used once per path per process.
func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) {
	if s.Plugins == nil {
		s.Plugins = &pluginContainer{}
	}
	if rpcPath == "" {
		rpcPath = share.DefaultRPCPath
	}

	// Publish the listener and the connection table under s.mu. Address reads
	// s.ln under the same lock, and ServeHTTP writes to s.activeConn.
	s.mu.Lock()
	s.ln = ln
	if s.activeConn == nil {
		s.activeConn = make(map[net.Conn]struct{})
	}
	s.mu.Unlock()

	http.Handle(rpcPath, s)
	srv := &http.Server{Handler: nil}
	srv.Serve(ln)
}
// serveConn serves a single client connection. It stops when reading a
// request fails, when the server is shutting down, or when the TLS handshake
// fails.
//
// Requests are read sequentially on this goroutine. Each request that passes
// authentication is then handled on its own goroutine, so responses on one
// connection may be written out of request order. Every successfully read
// request is counted in s.HandleMsgChan until its handling finishes; Shutdown
// polls that count.
//
// On return (including after a recovered panic) the connection is removed
// from s.activeConn, closed, and reported to Plugins.DoPostConnClose.
func (s *Server) serveConn(conn net.Conn) {
	defer func() {
		// Recover a panic raised on this goroutine and log it with its stack.
		// NOTE(review): panics inside the per-request goroutines started below
		// are not covered by this recover.
		if err := recover(); err != nil {
			const size = 64 << 10
			buf := make([]byte, size)
			ss := runtime.Stack(buf, false)
			if ss > size {
				ss = size
			}
			buf = buf[:ss]
			log.Errorf("serving %s panic error: %s, stack:\n %s", conn.RemoteAddr(), err, buf)
		}
		s.mu.Lock()
		delete(s.activeConn, conn)
		s.mu.Unlock()
		conn.Close()
		// NOTE(review): if s.Plugins were nil this would be an unsynchronized
		// write from a connection goroutine; NewServer already sets it.
		if s.Plugins == nil {
			s.Plugins = &pluginContainer{}
		}
		s.Plugins.DoPostConnClose(conn)
	}()
	// Refuse new connections once shutdown has begun. The deferred cleanup
	// closes conn again; the second Close is redundant but harmless.
	if isShutdown(s) {
		closeChannel(s,conn)
		return
	}
	// Perform the TLS handshake up front, bounded by the configured deadlines.
	if tlsConn, ok := conn.(*tls.Conn); ok {
		if d := s.readTimeout; d != 0 {
			conn.SetReadDeadline(time.Now().Add(d))
		}
		if d := s.writeTimeout; d != 0 {
			conn.SetWriteDeadline(time.Now().Add(d))
		}
		if err := tlsConn.Handshake(); err != nil {
			log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
			return
		}
	}
	r := bufio.NewReaderSize(conn, ReaderBuffsize)
	for {
		// Stop reading new requests on this connection once shutdown begins.
		if isShutdown(s) {
			closeChannel(s,conn)
			return
		}
		t0 := time.Now()
		// The read deadline is refreshed for every request.
		if s.readTimeout != 0 {
			conn.SetReadDeadline(t0.Add(s.readTimeout))
		}
		ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn)
		req, err := s.readRequest(ctx, r)
		if err != nil {
			// Classify the failure only for logging; every read error ends
			// the connection.
			if err == io.EOF {
				log.Infof("client has closed this connection: %s", conn.RemoteAddr().String())
			} else if strings.Contains(err.Error(), "use of closed network connection") {
				log.Infof("rpcx: connection %s is closed", conn.RemoteAddr().String())
			} else {
				log.Warnf("rpcx: failed to read request: %v", err)
			}
			return
		}
		// Count the request as in flight; blocks while the buffer is full.
		s.HandleMsgChan <- struct{}{}
		// NOTE(review): responses are written from separate goroutines, so this
		// deadline is shared by (and overwritten for) all in-flight responses.
		if s.writeTimeout != 0 {
			conn.SetWriteDeadline(t0.Add(s.writeTimeout))
		}
		ctx = context.WithValue(ctx, StartRequestContextKey, time.Now().UnixNano())
		err = s.auth(ctx, req)
		if err != nil {
			// Authentication failed: send an error response unless the request
			// is one-way, release the in-flight slot, and read the next request.
			if !req.IsOneway() {
				res := req.Clone()
				res.SetMessageType(protocol.Response)
				if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
					res.SetCompressType(req.CompressType())
				}
				handleError(res, err)
				data := res.Encode()
				s.Plugins.DoPreWriteResponse(ctx, req, res)
				conn.Write(data)
				s.Plugins.DoPostWriteResponse(ctx, req, res, err)
				protocol.FreeMsg(res)
			} else {
				s.Plugins.DoPreWriteResponse(ctx, req, nil)
			}
			<-s.HandleMsgChan
			protocol.FreeMsg(req)
			continue
		}
		// Handle the request concurrently so the read loop can continue.
		go func() {
			// Release the in-flight slot however this goroutine exits.
			defer func(){
				<-s.HandleMsgChan
			}()
			// Heartbeats are echoed back as responses without calling a service.
			// NOTE(review): req is not returned to the pool on this path.
			if req.IsHeartbeat() {
				req.SetMessageType(protocol.Response)
				data := req.Encode()
				conn.Write(data)
				return
			}
			// resMetadata collects metadata that services set through the
			// context (share.ResMetaDataKey); it is merged into the response.
			resMetadata := make(map[string]string)
			newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata),
				share.ResMetaDataKey, resMetadata)
			res, err := s.handleRequest(newCtx, req)
			if err != nil {
				log.Warnf("rpcx: failed to handle request: %v", err)
			}
			s.Plugins.DoPreWriteResponse(newCtx, req, res)
			if !req.IsOneway() {
				if len(resMetadata) > 0 { // copy metadata set via the context into the response
					meta := res.Metadata
					if meta == nil {
						res.Metadata = resMetadata
					} else {
						for k, v := range resMetadata {
							meta[k] = v
						}
					}
				}
				// Responses larger than 1024 bytes use the request's compression.
				if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
					res.SetCompressType(req.CompressType())
				}
				data := res.Encode()
				conn.Write(data)
				//res.WriteTo(conn)
			}
			s.Plugins.DoPostWriteResponse(newCtx, req, res, err)
			protocol.FreeMsg(req)
			protocol.FreeMsg(res)
		}()
	}
}
// isShutdown reports whether Shutdown has begun for s.
func isShutdown(s *Server) bool {
	state := atomic.LoadInt32(&s.inShutdown)
	return state == 1
}
// closeChannel stops tracking conn in s.activeConn and then closes it,
// ignoring the Close error.
func closeChannel(s *Server, conn net.Conn) {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.activeConn, conn)
	}()
	conn.Close()
}
// readRequest decodes one message from r and runs the PreReadRequest and
// PostReadRequest plugin hooks around the decode.
//
// If the pre-read hook fails, nothing is read and (nil, err) is returned.
// Otherwise a pooled message is returned even on failure. The decode error
// takes precedence over the post-read hook's error.
func (s *Server) readRequest(ctx context.Context, r io.Reader) (req *protocol.Message, err error) {
	if err = s.Plugins.DoPreReadRequest(ctx); err != nil {
		return nil, err
	}

	// NOTE(review): callers do not free this pooled message on error paths.
	req = protocol.GetPooledMsg()
	decodeErr := req.Decode(r)
	hookErr := s.Plugins.DoPostReadRequest(ctx, req, decodeErr)
	if decodeErr != nil {
		return req, decodeErr
	}
	return req, hookErr
}
// auth calls s.AuthFunc, if one is configured, with the token stored in the
// request metadata under share.AuthKey. Without an AuthFunc every request is
// accepted.
func (s *Server) auth(ctx context.Context, req *protocol.Message) error {
	if s.AuthFunc == nil {
		return nil
	}
	return s.AuthFunc(ctx, req, req.Metadata[share.AuthKey])
}
// handleRequest dispatches req to the registered service method named by
// req.ServicePath / req.ServiceMethod and builds the response message.
//
// If no method matches but the service has a raw function with that name, the
// request is delegated to handleRequestForFunction. The argument is decoded
// and the reply encoded with the codec selected by req.SerializeType(). All
// failures are reported through handleError, which marks res as an error
// response and returns the error alongside it.
func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	res = req.Clone()
	res.SetMessageType(protocol.Response)

	s.serviceMapMu.RLock()
	svc := s.serviceMap[req.ServicePath]
	s.serviceMapMu.RUnlock()
	if svc == nil {
		return handleError(res, errors.New("rpcx: can't find service "+req.ServicePath))
	}

	methodName := req.ServiceMethod
	mtype := svc.method[methodName]
	if mtype == nil {
		// Not a method: fall back to raw functions registered on the service.
		if svc.function[methodName] == nil {
			return handleError(res, errors.New("rpcx: can't find method "+methodName))
		}
		return s.handleRequestForFunction(ctx, req)
	}

	argv := argsReplyPools.Get(mtype.ArgType)

	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		return handleError(res, fmt.Errorf("can not find codec for %d", req.SerializeType()))
	}
	if err = codec.Decode(req.Payload, argv); err != nil {
		return handleError(res, err)
	}

	replyv := argsReplyPools.Get(mtype.ReplyType)

	// argv evidently holds a pointer (hence Elem); methods declared with a
	// value argument receive the pointed-to value.
	argValue := reflect.ValueOf(argv)
	if mtype.ArgType.Kind() != reflect.Ptr {
		argValue = argValue.Elem()
	}
	err = svc.call(ctx, mtype, argValue, reflect.ValueOf(replyv))
	argsReplyPools.Put(mtype.ArgType, argv)
	if err != nil {
		argsReplyPools.Put(mtype.ReplyType, replyv)
		return handleError(res, err)
	}

	if req.IsOneway() {
		return res, nil
	}
	data, encErr := codec.Encode(replyv)
	argsReplyPools.Put(mtype.ReplyType, replyv)
	if encErr != nil {
		return handleError(res, encErr)
	}
	res.Payload = data
	return res, nil
}
// handleRequestForFunction dispatches req to a raw function registered on the
// service named by req.ServicePath and builds the response message.
//
// The argument is decoded and the reply encoded with the codec selected by
// req.SerializeType(). All failures are reported through handleError, which
// marks res as an error response and returns the error alongside it.
func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	res = req.Clone()
	res.SetMessageType(protocol.Response)

	serviceName := req.ServicePath
	methodName := req.ServiceMethod

	s.serviceMapMu.RLock()
	service := s.serviceMap[serviceName]
	s.serviceMapMu.RUnlock()
	if service == nil {
		err = errors.New("rpcx: can't find service for func raw function")
		return handleError(res, err)
	}
	mtype := service.function[methodName]
	if mtype == nil {
		err = errors.New("rpcx: can't find method " + methodName)
		return handleError(res, err)
	}

	var argv = argsReplyPools.Get(mtype.ArgType)

	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		err = fmt.Errorf("can not find codec for %d", req.SerializeType())
		return handleError(res, err)
	}

	err = codec.Decode(req.Payload, argv)
	if err != nil {
		return handleError(res, err)
	}

	replyv := argsReplyPools.Get(mtype.ReplyType)

	// argv comes from the pool as a pointer. As in handleRequest, a function
	// declared with a non-pointer argument must receive the pointed-to value;
	// passing the pointer does not match its parameter type.
	argValue := reflect.ValueOf(argv)
	if mtype.ArgType.Kind() != reflect.Ptr {
		argValue = argValue.Elem()
	}
	err = service.callForFunction(ctx, mtype, argValue, reflect.ValueOf(replyv))

	argsReplyPools.Put(mtype.ArgType, argv)

	if err != nil {
		argsReplyPools.Put(mtype.ReplyType, replyv)
		return handleError(res, err)
	}

	if !req.IsOneway() {
		data, err := codec.Encode(replyv)
		argsReplyPools.Put(mtype.ReplyType, replyv)
		if err != nil {
			return handleError(res, err)
		}
		res.Payload = data
	}
	return res, nil
}
// handleError marks res as an error response and records err's message in its
// metadata under protocol.ServiceError, allocating the metadata map if needed.
// It returns res together with err unchanged.
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
	res.SetMessageStatusType(protocol.Error)

	meta := res.Metadata
	if meta == nil {
		meta = make(map[string]string)
		res.Metadata = meta
	}
	meta[protocol.ServiceError] = err.Error()

	return res, err
}
// connected is the status text sent back to a client whose HTTP CONNECT to
// the RPC path succeeded. After it, the connection speaks the rpcx protocol.
var connected = "200 Connected to rpcx"

// ServeHTTP implements an http.Handler that answers RPC requests.
//
// Only CONNECT is accepted; other methods get 405. The underlying connection is
// hijacked, acknowledged with the connected status line, tracked in
// s.activeConn, and then served by serveConn on the handler's goroutine.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 must CONNECT\n")
		return
	}

	// Not every ResponseWriter supports hijacking (HTTP/2 writers do not); an
	// unchecked type assertion would panic here.
	hj, ok := w.(http.Hijacker)
	if !ok {
		log.Info("rpc hijacking ", req.RemoteAddr, ": response writer does not support hijacking")
		http.Error(w, "rpcx: hijacking not supported", http.StatusInternalServerError)
		return
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}
	if _, err := io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n"); err != nil {
		log.Info("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		conn.Close()
		return
	}

	s.mu.Lock()
	// The handler may be mounted on any mux without serveByHTTP having
	// initialised the connection table.
	if s.activeConn == nil {
		s.activeConn = make(map[net.Conn]struct{})
	}
	s.activeConn[conn] = struct{}{}
	s.mu.Unlock()

	s.serveConn(conn)
}
// Close immediately closes the listener and every tracked connection, without
// waiting for in-flight requests (Shutdown does that first). It returns the
// error, if any, from closing the listener.
func (s *Server) Close() error {
	// doneChan, ln and activeConn are guarded by s.mu. serveConn goroutines
	// delete from activeConn concurrently, so iterating it unlocked is a race.
	s.mu.Lock()
	s.closeDoneChanLocked()
	var err error
	if s.ln != nil {
		err = s.ln.Close()
	}
	conns := make([]net.Conn, 0, len(s.activeConn))
	for c := range s.activeConn {
		conns = append(conns, c)
		delete(s.activeConn, c)
	}
	s.mu.Unlock()

	// Close connections and run plugin hooks outside the lock, so a slow Close
	// or a plugin that calls back into the server cannot block s.mu.
	for _, c := range conns {
		c.Close()
		s.Plugins.DoPostConnClose(c)
	}
	return err
}
// RegisterOnShutdown appends f to s.onShutdown, the list of hooks meant to run
// when the server shuts down. It can be used to gracefully shut down
// connections.
func (s *Server) RegisterOnShutdown(f func()) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.onShutdown = append(s.onShutdown, f)
}
// shutdownPollInterval is how often Shutdown re-checks for in-flight requests.
var shutdownPollInterval = 500 * time.Millisecond

// Shutdown gracefully shuts down the server.
//
// It first sets the shutdown flag, after which serveConn closes connections
// when it starts serving them or before reading their next request. It then
// starts the hooks registered with RegisterOnShutdown, each on its own
// goroutine. Next it polls until no request is in flight (s.HandleMsgChan is
// empty), and finally calls Close.
//
// If ctx is done first, Shutdown returns ctx.Err() without closing the server.
// The server stays in shutdown mode; call Close to release the listener.
// Otherwise it returns the error from Close (i.e. from closing the listener).
// Calls after the first are no-ops that return nil.
func (s *Server) Shutdown(ctx context.Context) error {
	if !atomic.CompareAndSwapInt32(&s.inShutdown, 0, 1) {
		return nil
	}
	log.Info("shutdown begin")

	// Run the registered shutdown hooks; previously they were recorded but
	// never invoked. Copy under the lock so later registrations do not race.
	s.mu.Lock()
	hooks := append([]func(){}, s.onShutdown...)
	s.mu.Unlock()
	for _, f := range hooks {
		go f()
	}

	ticker := time.NewTicker(shutdownPollInterval)
	defer ticker.Stop()
	for !s.checkProcessMsg() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}

	err := s.Close()
	log.Info("shutdown end")
	return err
}
// checkProcessMsg reports whether no requests are in flight, i.e. whether
// s.HandleMsgChan is empty. It logs the current count on every call.
func (s *Server) checkProcessMsg() bool {
	pending := len(s.HandleMsgChan)
	log.Info("need handle msg size:", pending)
	return pending == 0
}
// closeDoneChanLocked closes s.doneChan (creating it first if needed) unless
// it is already closed. It takes no lock itself; the close is meant to be
// guarded by s.mu held by the caller.
func (s *Server) closeDoneChanLocked() {
	ch := s.getDoneChanLocked()

	alreadyClosed := false
	select {
	case <-ch:
		alreadyClosed = true
	default:
	}

	if !alreadyClosed {
		close(ch)
	}
}
// getDoneChanLocked returns s.doneChan, creating it on first use. Unlike
// getDoneChan it does not lock; the caller is expected to hold s.mu.
func (s *Server) getDoneChanLocked() chan struct{} {
	if s.doneChan != nil {
		return s.doneChan
	}
	s.doneChan = make(chan struct{})
	return s.doneChan
}
// ip4Reg matches a dotted-quad IPv4 address whose octets are 0-255 written
// without leading zeros.
var ip4Reg = regexp.MustCompile(`^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$`)

// validIP4 reports whether the host part of ipAddress is a dotted-quad IPv4
// address. Surrounding spaces are ignored. Everything from the last ':' on is
// treated as the port and removed. If there is no ':' the whole string is the
// host; such input (including "") used to panic with a slice-bounds error.
func validIP4(ipAddress string) bool {
	ipAddress = strings.Trim(ipAddress, " ")
	if i := strings.LastIndex(ipAddress, ":"); i >= 0 {
		ipAddress = ipAddress[:i] // remove port
	}
	return ip4Reg.MatchString(ipAddress)
}