mirror of
https://github.com/eolinker/apinto
synced 2025-09-26 21:01:19 +08:00
528 lines
12 KiB
Go
528 lines
12 KiB
Go
/*
|
||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||
* contributor license agreements. See the NOTICE file distributed with
|
||
* this work for additional information regarding copyright ownership.
|
||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||
* (the "License"); you may not use this file except in compliance with
|
||
* the License. You may obtain a copy of the License at
|
||
*
|
||
* http://www.apache.org/licenses/LICENSE-2.0
|
||
*
|
||
* Unless required by applicable law or agreed to in writing, software
|
||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
* See the License for the specific language governing permissions and
|
||
* limitations under the License.
|
||
*/
|
||
|
||
package getty
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/eolinker/eosc/log"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
import (
|
||
gxnet "github.com/dubbogo/gost/net"
|
||
gxsync "github.com/dubbogo/gost/sync"
|
||
gxtime "github.com/dubbogo/gost/time"
|
||
|
||
"github.com/gorilla/websocket"
|
||
|
||
perrors "github.com/pkg/errors"
|
||
|
||
uatomic "go.uber.org/atomic"
|
||
)
|
||
|
||
var (
|
||
errSelfConnect = perrors.New("connect self!")
|
||
serverFastFailTimeout = time.Second * 1
|
||
|
||
serverID uatomic.Int32
|
||
)
|
||
|
||
// Server interface
|
||
type Server interface {
|
||
EndPoint
|
||
}
|
||
|
||
// StreamServer is like tcp/websocket/wss server
|
||
type StreamServer interface {
|
||
Server
|
||
// Listener get the network listener
|
||
Listener() net.Listener
|
||
}
|
||
|
||
// PacketServer is like udp listen endpoint
|
||
type PacketServer interface {
|
||
Server
|
||
// PacketConn get the network listener
|
||
PacketConn() net.PacketConn
|
||
}
|
||
|
||
type server struct {
|
||
ServerOptions
|
||
|
||
// endpoint ID
|
||
endPointID EndPointID
|
||
|
||
// net
|
||
pktListener net.PacketConn
|
||
streamListener net.Listener
|
||
lock sync.Mutex // for server
|
||
endPointType EndPointType
|
||
server *http.Server // for ws or wss server
|
||
sync.Once
|
||
done chan struct{}
|
||
wg sync.WaitGroup
|
||
}
|
||
|
||
func (s *server) init(opts ...ServerOption) {
|
||
for _, opt := range opts {
|
||
opt(&(s.ServerOptions))
|
||
}
|
||
}
|
||
|
||
func newServer(t EndPointType, opts ...ServerOption) *server {
|
||
s := &server{
|
||
endPointID: serverID.Add(1),
|
||
endPointType: t,
|
||
done: make(chan struct{}),
|
||
}
|
||
|
||
s.init(opts...)
|
||
|
||
return s
|
||
}
|
||
|
||
// NewTCPServer builds a tcp server.
|
||
func NewTCPServer(opts ...ServerOption) Server {
|
||
return newServer(TCP_SERVER, opts...)
|
||
}
|
||
|
||
// NewUDPEndPoint builds a unconnected udp server.
|
||
func NewUDPEndPoint(opts ...ServerOption) Server {
|
||
return newServer(UDP_ENDPOINT, opts...)
|
||
}
|
||
|
||
// NewWSServer builds a websocket server.
|
||
func NewWSServer(opts ...ServerOption) Server {
|
||
return newServer(WS_SERVER, opts...)
|
||
}
|
||
|
||
// NewWSSServer builds a secure websocket server.
|
||
func NewWSSServer(opts ...ServerOption) Server {
|
||
s := newServer(WSS_SERVER, opts...)
|
||
|
||
if s.addr == "" || s.cert == "" || s.privateKey == "" {
|
||
panic(fmt.Sprintf("@addr:%s, @cert:%s, @privateKey:%s, @caCert:%s",
|
||
s.addr, s.cert, s.privateKey, s.caCert))
|
||
}
|
||
|
||
return s
|
||
}
|
||
|
||
func (s *server) ID() int32 {
|
||
return s.endPointID
|
||
}
|
||
|
||
func (s *server) EndPointType() EndPointType {
|
||
return s.endPointType
|
||
}
|
||
|
||
func (s *server) stop() {
|
||
select {
|
||
case <-s.done:
|
||
return
|
||
default:
|
||
s.Once.Do(func() {
|
||
close(s.done)
|
||
s.lock.Lock()
|
||
if s.server != nil {
|
||
ctx, cancel := context.WithTimeout(context.Background(), serverFastFailTimeout)
|
||
if err := s.server.Shutdown(ctx); err != nil {
|
||
// if the log output is "shutdown ctx: context deadline exceeded", it means that
|
||
// there are still some active connections.
|
||
log.Errorf("server shutdown ctx:%s error:%v", ctx, err)
|
||
}
|
||
cancel()
|
||
}
|
||
s.server = nil
|
||
s.lock.Unlock()
|
||
if s.streamListener != nil {
|
||
// let the server exit asap when got error from RunEventLoop.
|
||
s.streamListener.Close()
|
||
s.streamListener = nil
|
||
}
|
||
if s.pktListener != nil {
|
||
s.pktListener.Close()
|
||
s.pktListener = nil
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func (s *server) GetTaskPool() gxsync.GenericTaskPool {
|
||
return s.tPool
|
||
}
|
||
|
||
func (s *server) IsClosed() bool {
|
||
select {
|
||
case <-s.done:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// net.ipv4.tcp_max_syn_backlog
|
||
// net.ipv4.tcp_timestamps
|
||
// net.ipv4.tcp_tw_recycle
|
||
func (s *server) listenTCP() error {
|
||
var (
|
||
err error
|
||
streamListener net.Listener
|
||
)
|
||
|
||
if s.listener == nil {
|
||
|
||
if len(s.addr) == 0 || !strings.Contains(s.addr, ":") {
|
||
streamListener, err = gxnet.ListenOnTCPRandomPort(s.addr)
|
||
if err != nil {
|
||
return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr)
|
||
}
|
||
} else {
|
||
if s.sslEnabled {
|
||
if sslConfig, buildTlsConfErr := s.tlsConfigBuilder.BuildTlsConfig(); buildTlsConfErr == nil && sslConfig != nil {
|
||
streamListener, err = tls.Listen("tcp", s.addr, sslConfig)
|
||
}
|
||
} else {
|
||
streamListener, err = net.Listen("tcp", s.addr)
|
||
}
|
||
if err != nil {
|
||
return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr)
|
||
}
|
||
}
|
||
|
||
} else {
|
||
streamListener = s.listener
|
||
}
|
||
|
||
s.streamListener = streamListener
|
||
s.addr = s.streamListener.Addr().String()
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *server) listenUDP() error {
|
||
var (
|
||
err error
|
||
localAddr *net.UDPAddr
|
||
pktListener *net.UDPConn
|
||
)
|
||
|
||
if len(s.addr) == 0 || !strings.Contains(s.addr, ":") {
|
||
pktListener, err = gxnet.ListenOnUDPRandomPort(s.addr)
|
||
if err != nil {
|
||
return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr)
|
||
}
|
||
} else {
|
||
localAddr, err = net.ResolveUDPAddr("udp", s.addr)
|
||
if err != nil {
|
||
return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr)
|
||
}
|
||
pktListener, err = net.ListenUDP("udp", localAddr)
|
||
if err != nil {
|
||
return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr)
|
||
}
|
||
}
|
||
|
||
s.pktListener = pktListener
|
||
s.addr = s.pktListener.LocalAddr().String()
|
||
|
||
return nil
|
||
}
|
||
|
||
// Listen announces on the local network address.
|
||
func (s *server) listen() error {
|
||
switch s.endPointType {
|
||
case TCP_SERVER, WS_SERVER, WSS_SERVER:
|
||
return perrors.WithStack(s.listenTCP())
|
||
case UDP_ENDPOINT:
|
||
return perrors.WithStack(s.listenUDP())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *server) accept(newSession NewSessionCallback) (Session, error) {
|
||
conn, err := s.streamListener.Accept()
|
||
if err != nil {
|
||
return nil, perrors.WithStack(err)
|
||
}
|
||
if gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
|
||
log.Warnf("conn.localAddr{%s} == conn.RemoteAddr{%s}", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||
return nil, perrors.WithStack(errSelfConnect)
|
||
}
|
||
|
||
ss := newTCPSession(conn, s)
|
||
err = newSession(ss)
|
||
if err != nil {
|
||
conn.Close()
|
||
return nil, perrors.WithStack(err)
|
||
}
|
||
|
||
return ss, nil
|
||
}
|
||
|
||
func (s *server) runTCPEventLoop(newSession NewSessionCallback) {
|
||
s.wg.Add(1)
|
||
go func() {
|
||
defer s.wg.Done()
|
||
var (
|
||
err error
|
||
client Session
|
||
delay time.Duration
|
||
)
|
||
for {
|
||
if s.IsClosed() {
|
||
log.Infof("server{%s} stop accepting client connect request.", s.addr)
|
||
return
|
||
}
|
||
if delay != 0 {
|
||
<-gxtime.After(delay)
|
||
}
|
||
client, err = s.accept(newSession)
|
||
if err != nil {
|
||
var netErr net.Error
|
||
if errors.As(perrors.Cause(err), &netErr) && netErr.Timeout() {
|
||
if delay == 0 {
|
||
delay = 5 * time.Millisecond
|
||
} else {
|
||
delay *= 2
|
||
}
|
||
if max := 1 * time.Second; delay > max {
|
||
delay = max
|
||
}
|
||
continue
|
||
}
|
||
log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.Cause(err))
|
||
break
|
||
}
|
||
delay = 0
|
||
client.(*session).run()
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (s *server) runUDPEventLoop(newSession NewSessionCallback) {
|
||
s.wg.Add(1)
|
||
go func() {
|
||
defer s.wg.Done()
|
||
var (
|
||
err error
|
||
conn *net.UDPConn
|
||
ss Session
|
||
)
|
||
|
||
conn = s.pktListener.(*net.UDPConn)
|
||
ss = newUDPSession(conn, s)
|
||
if err = newSession(ss); err != nil {
|
||
conn.Close()
|
||
return
|
||
}
|
||
|
||
ss.(*session).run()
|
||
}()
|
||
}
|
||
|
||
type wsHandler struct {
|
||
http.ServeMux
|
||
server *server
|
||
newSession NewSessionCallback
|
||
upgrader websocket.Upgrader
|
||
}
|
||
|
||
func newWSHandler(server *server, newSession NewSessionCallback) *wsHandler {
|
||
return &wsHandler{
|
||
server: server,
|
||
newSession: newSession,
|
||
upgrader: websocket.Upgrader{
|
||
// in default, ReadBufferSize & WriteBufferSize is 4k
|
||
// HandshakeTimeout: server.HTTPTimeout,
|
||
CheckOrigin: func(_ *http.Request) bool { return true }, // allow connections from any origin
|
||
EnableCompression: true,
|
||
},
|
||
}
|
||
}
|
||
|
||
func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != "GET" {
|
||
// w.WriteHeader(http.StatusMethodNotAllowed)
|
||
http.Error(w, "Method not allowed", 405)
|
||
return
|
||
}
|
||
|
||
if s.server.IsClosed() {
|
||
http.Error(w, "HTTP server is closed(code:500-11).", 500)
|
||
log.Warnf("server{%s} stop acceptting client connect request.", s.server.addr)
|
||
return
|
||
}
|
||
|
||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
log.Warnf("upgrader.Upgrader(http.Request{%#v}) = error:%+v", r, err)
|
||
return
|
||
}
|
||
if conn.RemoteAddr().String() == conn.LocalAddr().String() {
|
||
log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||
return
|
||
}
|
||
// conn.SetReadLimit(int64(handler.maxMsgLen))
|
||
ss := newWSSession(conn, s.server)
|
||
err = s.newSession(ss)
|
||
if err != nil {
|
||
conn.Close()
|
||
log.Warnf("server{%s}.newSession(ss{%#v}) = err {%s}", s.server.addr, ss, err)
|
||
return
|
||
}
|
||
if ss.(*session).maxMsgLen > 0 {
|
||
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
|
||
}
|
||
ss.(*session).run()
|
||
}
|
||
|
||
// runWSEventLoop serve websocket client request
|
||
// @newSession: new websocket connection callback
|
||
func (s *server) runWSEventLoop(newSession NewSessionCallback) {
|
||
s.wg.Add(1)
|
||
go func() {
|
||
defer s.wg.Done()
|
||
var (
|
||
err error
|
||
handler *wsHandler
|
||
server *http.Server
|
||
)
|
||
handler = newWSHandler(s, newSession)
|
||
handler.HandleFunc(s.path, handler.serveWSRequest)
|
||
server = &http.Server{
|
||
Addr: s.addr,
|
||
Handler: handler,
|
||
// ReadTimeout: server.HTTPTimeout,
|
||
// WriteTimeout: server.HTTPTimeout,
|
||
}
|
||
s.lock.Lock()
|
||
s.server = server
|
||
s.lock.Unlock()
|
||
err = server.Serve(s.streamListener)
|
||
if err != nil {
|
||
log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
|
||
}
|
||
}()
|
||
}
|
||
|
||
// serve websocket client request
|
||
// RunWSSEventLoop serve websocket client request
|
||
func (s *server) runWSSEventLoop(newSession NewSessionCallback) {
|
||
s.wg.Add(1)
|
||
go func() {
|
||
var (
|
||
err error
|
||
certPem []byte
|
||
certificate tls.Certificate
|
||
certPool *x509.CertPool
|
||
config *tls.Config
|
||
handler *wsHandler
|
||
server *http.Server
|
||
)
|
||
defer s.wg.Done()
|
||
|
||
if certificate, err = tls.LoadX509KeyPair(s.cert, s.privateKey); err != nil {
|
||
panic(fmt.Sprintf("tls.LoadX509KeyPair(certs{%s}, privateKey{%s}) = err:%+v",
|
||
s.cert, s.privateKey, perrors.WithStack(err)))
|
||
}
|
||
config = &tls.Config{
|
||
InsecureSkipVerify: true, // do not verify peer certs
|
||
ClientAuth: tls.NoClientCert,
|
||
NextProtos: []string{"http/1.1"},
|
||
Certificates: []tls.Certificate{certificate},
|
||
}
|
||
|
||
if s.caCert != "" {
|
||
certPem, err = os.ReadFile(s.caCert)
|
||
if err != nil {
|
||
panic(fmt.Errorf("os.ReadFile(certFile{%s}) = err:%+v", s.caCert, perrors.WithStack(err)))
|
||
}
|
||
certPool = x509.NewCertPool()
|
||
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
|
||
panic("failed to parse root certificate file")
|
||
}
|
||
config.ClientCAs = certPool
|
||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||
config.InsecureSkipVerify = false
|
||
}
|
||
|
||
handler = newWSHandler(s, newSession)
|
||
handler.HandleFunc(s.path, handler.serveWSRequest)
|
||
server = &http.Server{
|
||
Addr: s.addr,
|
||
Handler: handler,
|
||
// ReadTimeout: server.HTTPTimeout,
|
||
// WriteTimeout: server.HTTPTimeout,
|
||
}
|
||
server.SetKeepAlivesEnabled(true)
|
||
s.lock.Lock()
|
||
s.server = server
|
||
s.lock.Unlock()
|
||
err = server.Serve(tls.NewListener(s.streamListener, config))
|
||
if err != nil {
|
||
log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
|
||
panic(err)
|
||
}
|
||
}()
|
||
}
|
||
|
||
// RunEventLoop serves client request.
|
||
// @newSession: new connection callback
|
||
func (s *server) RunEventLoop(newSession NewSessionCallback) {
|
||
if err := s.listen(); err != nil {
|
||
panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err)))
|
||
}
|
||
|
||
switch s.endPointType {
|
||
case TCP_SERVER:
|
||
s.runTCPEventLoop(newSession)
|
||
case UDP_ENDPOINT:
|
||
s.runUDPEventLoop(newSession)
|
||
case WS_SERVER:
|
||
s.runWSEventLoop(newSession)
|
||
case WSS_SERVER:
|
||
s.runWSSEventLoop(newSession)
|
||
default:
|
||
panic(fmt.Sprintf("illegal server type %s", s.endPointType.String()))
|
||
}
|
||
}
|
||
|
||
func (s *server) Listener() net.Listener {
|
||
return s.streamListener
|
||
}
|
||
|
||
func (s *server) PacketConn() net.PacketConn {
|
||
return s.pktListener
|
||
}
|
||
|
||
func (s *server) Close() {
|
||
s.stop()
|
||
s.wg.Wait()
|
||
}
|