新增dubbo-getty

This commit is contained in:
zhangzeyi
2023-02-14 16:23:14 +08:00
parent 7ee59e8a45
commit a4fd712e7e
17 changed files with 3958 additions and 183 deletions

View File

@@ -4,7 +4,7 @@ certificate: # 证书根目录
client:
advertise_urls: # open api 服务的广播地址
- http://192.168.198.166:9400
- http://192.168.198.168:9400
certificate: # 征对 https 的证书配置
- cert:
key:
@@ -12,12 +12,12 @@ client:
- http://0.0.0.0:9400 # open api 服务的监听地址
gateway: # 网关服务配置
# advertise_urls: # 广播地址
# - tcp://192.168.198.166:8081
# - tcp://192.168.198.168:8081
listen_urls: # 监听地址
- tcp://0.0.0.0:8099
peer: # 节点通信配置
advertise_urls:
- tcp://192.168.198.166:9401
- tcp://192.168.198.168:9401
certificate:
- cert:
key:

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
hessian "github.com/apache/dubbo-go-hessian2"
dubbo_server "github.com/eolinker/apinto/app/dubbo-test/dubbo-server"
http_dubbo "github.com/eolinker/apinto/app/dubbo-test/http-dubbo"
"github.com/eolinker/apinto/utils"
"time"
@@ -14,7 +13,7 @@ import (
var errClientReadTimeout = errors.New("maybe the client read timeout or fail to decode tcp stream in Writer.Write")
func main() {
go dubbo_server.StartDubboServer()
//go dubbo_server.StartDubboServer()
//time.Sleep(time.Second)
//
@@ -29,7 +28,7 @@ func main() {
//cn.zzy.
addr := "192.168.198.167:8099"
addr = "192.168.198.167:20001"
addr = "127.0.0.1:4999"
addr = "127.0.0.1:20001"
resp, err := http_dubbo.ProxyToDubbo(addr, "api.UserService", "GetUser", time.Second*3, types, valuesList)
if err != nil {
fmt.Println(err)

View File

@@ -1,127 +0,0 @@
package manager
import (
"dubbo.apache.org/dubbo-go/v3/protocol/dubbo"
"dubbo.apache.org/dubbo-go/v3/remoting"
gxbytes "github.com/dubbogo/gost/bytes"
"github.com/eolinker/eosc/log"
"github.com/pkg/errors"
"io"
"net"
"sync"
"sync/atomic"
)
const maxReadBufLen = 4 * 1024
type dubbo2Server struct {
lock *sync.Mutex
conn []net.Conn
stop int32
}
func NewDubbo2Server() *dubbo2Server {
return &dubbo2Server{
lock: new(sync.Mutex),
}
}
func (d *dubbo2Server) Handler(port int, conn net.Conn) {
d.lock.Lock()
defer d.lock.Unlock()
if atomic.LoadInt32(&d.stop) == 1 {
conn.Close()
return
}
d.conn = append(d.conn, conn)
d.task(port, conn)
}
func (d *dubbo2Server) ShutDown() {
d.lock.Lock()
defer d.lock.Unlock()
atomic.StoreInt32(&d.stop, 1)
defer atomic.StoreInt32(&d.stop, 0)
for _, conn := range d.conn {
conn.Close()
}
d.conn = nil
}
func (d *dubbo2Server) task(port int, conn net.Conn) {
go func() {
var (
pkgLen int
buf []byte
err error
ok bool
netError net.Error
)
pktBuf := gxbytes.NewBuffer(nil)
for {
var bufLen = 0
for {
reader := io.Reader(conn)
buf = pktBuf.WriteNextBegin(maxReadBufLen)
bufLen, err = reader.Read(buf)
if err != nil {
if netError, ok = errors.Cause(err).(net.Error); ok && netError.Timeout() {
break
}
if errors.Cause(err) == io.EOF {
log.Infof("session.conn read EOF, client send over, session exit")
err = nil
if bufLen != 0 {
log.Infof("session.conn read EOF, while the bufLen(%d) is non-zero.", bufLen)
break
}
return
}
log.Errorf("[session.conn.read] = error:%+v", errors.WithStack(err))
return
}
break
}
if 0 != bufLen {
go func() {
pktBuf.WriteNextEnd(bufLen)
for {
if pktBuf.Len() <= 0 {
break
}
codec := dubbo.DubboCodec{}
var pkg *remoting.DecodeResult
pkg, pkgLen, err = codec.Decode(pktBuf.Bytes())
if err == nil && pkgLen > maxReadBufLen {
err = errors.Errorf("pkgLen %d > session max message len %d", pkgLen, maxReadBufLen)
}
pktBuf.Next(pkgLen)
if err != nil {
break
}
if pkg == nil {
break
}
manager.Handler(port, conn, pkg)
}
}()
}
}
}()
}

View File

@@ -2,6 +2,7 @@ package manager
import (
"github.com/eolinker/apinto/drivers/router"
getty "github.com/eolinker/apinto/dubbo-getty/server"
"github.com/eolinker/apinto/plugin"
"github.com/eolinker/eosc/common/bean"
"github.com/eolinker/eosc/eocontext"
@@ -18,14 +19,8 @@ func init() {
serverHandler := func(port int, listener net.Listener) {
for {
conn, err := listener.Accept()
if err != nil {
log.Errorf("dubbo-manger listener.Accept err=%v", err)
}
go manager.connHandler.Handler(port, conn)
}
server := getty.NewServer(manager.Handler, getty.WithListenerServer(listener))
server.Start()
}
router.Register(router.Dubbo2, serverHandler)

View File

@@ -1,15 +1,11 @@
package manager
import (
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/protocol/dubbo/impl"
"dubbo.apache.org/dubbo-go/v3/remoting"
dubbo2_context "github.com/eolinker/apinto/node/dubbo2-context"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"github.com/eolinker/apinto/router"
eoscContext "github.com/eolinker/eosc/eocontext"
"github.com/eolinker/eosc/log"
"net"
"sync"
"sync/atomic"
)
@@ -19,7 +15,6 @@ var _ IManger = (*dubboManger)(nil)
type IManger interface {
Set(id string, port int, serviceName, methodName string, rule []AppendRule, handler router.IRouterHandler) error
Delete(id string)
Handler(port int, conn net.Conn, result *remoting.DecodeResult)
}
func (d *dubboManger) SetGlobalFilters(globalFilters *eoscContext.IChainPro) {
@@ -30,7 +25,6 @@ func NewManager() *dubboManger {
return &dubboManger{
matcher: nil,
routersData: new(RouterData),
connHandler: NewDubbo2Server(),
globalFilters: atomic.Pointer[eoscContext.IChainPro]{},
}
}
@@ -39,7 +33,6 @@ type dubboManger struct {
lock sync.RWMutex
matcher router.IMatcher
routersData IRouterData
connHandler *dubbo2Server
globalFilters atomic.Pointer[eoscContext.IChainPro]
}
@@ -72,39 +65,18 @@ func (d *dubboManger) Delete(id string) {
return
}
func (d *dubboManger) Handler(port int, conn net.Conn, result *remoting.DecodeResult) {
if result.IsRequest {
req := result.Result.(*remoting.Request)
dubboPackage := impl.NewDubboPackage(nil)
dubboPackage.Header = impl.DubboHeader{
SerialID: req.SerialID,
Type: impl.PackageRequest,
ID: req.ID,
}
if invoc, ok := req.Data.(*protocol.Invocation); ok {
invocation := *invoc
dubboPackage.Service.Path = invocation.GetAttachmentWithDefaultValue(constant.PathKey, "")
dubboPackage.Service.Interface = invocation.GetAttachmentWithDefaultValue(constant.InterfaceKey, "")
dubboPackage.Service.Version = invocation.GetAttachmentWithDefaultValue(constant.VersionKey, "")
dubboPackage.Service.Group = invocation.GetAttachmentWithDefaultValue(constant.GroupKey, "")
dubboPackage.Service.Method = invocation.MethodName()
}
context := dubbo2_context.NewContext(dubboPackage, port, conn)
match, has := d.matcher.Match(port, context.HeaderReader())
if !has {
//todo 怎样处理 conn.Write() ???
} else {
log.Debug("match has:", port)
match.ServeHTTP(context)
}
}
func (d *dubboManger) Handler(req *invocation.RPCInvocation) protocol.RPCResult {
// context := dubbo2_context.NewContext(dubboPackage, port, conn)
//
// match, has := d.matcher.Match(port, context.HeaderReader())
// if !has {
// //todo 怎样处理 conn.Write() ???
// } else {
// log.Debug("match has:", port)
// match.ServeHTTP(context)
// }
//
//}
return protocol.RPCResult{}
}

482
dubbo-getty/client.go Normal file
View File

@@ -0,0 +1,482 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/eolinker/eosc/log"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
import (
"github.com/dubbogo/gost/bytes"
"github.com/dubbogo/gost/net"
gxsync "github.com/dubbogo/gost/sync"
gxtime "github.com/dubbogo/gost/time"
"github.com/gorilla/websocket"
perrors "github.com/pkg/errors"
)
const (
reconnectInterval = 3e8 // 300ms
connectInterval = 5e8 // 500ms
connectTimeout = 3e9
maxTimes = 10
)
var (
sessionClientKey = "session-client-owner"
connectPingPackage = []byte("connect-ping")
clientID = EndPointID(0)
)
type Client interface {
EndPoint
}
type client struct {
ClientOptions
// endpoint ID
endPointID EndPointID
// net
sync.Mutex
endPointType EndPointType
newSession NewSessionCallback
ssMap map[Session]struct{}
sync.Once
done chan struct{}
wg sync.WaitGroup
}
func (c *client) init(opts ...ClientOption) {
for _, opt := range opts {
opt(&(c.ClientOptions))
}
}
func newClient(t EndPointType, opts ...ClientOption) *client {
c := &client{
endPointID: atomic.AddInt32(&clientID, 1),
endPointType: t,
done: make(chan struct{}),
}
c.init(opts...)
if c.number <= 0 || c.addr == "" {
panic(fmt.Sprintf("client type:%s, @connNum:%d, @serverAddr:%s", t, c.number, c.addr))
}
c.ssMap = make(map[Session]struct{}, c.number)
return c
}
// NewTCPClient builds a tcp client.
func NewTCPClient(opts ...ClientOption) Client {
return newClient(TCP_CLIENT, opts...)
}
// NewUDPClient builds a connected udp client
func NewUDPClient(opts ...ClientOption) Client {
return newClient(UDP_CLIENT, opts...)
}
// NewWSClient builds a ws client.
func NewWSClient(opts ...ClientOption) Client {
c := newClient(WS_CLIENT, opts...)
if !strings.HasPrefix(c.addr, "ws://") {
panic(fmt.Sprintf("the prefix @serverAddr:%s is not ws://", c.addr))
}
return c
}
// NewWSSClient function builds a wss client.
func NewWSSClient(opts ...ClientOption) Client {
c := newClient(WSS_CLIENT, opts...)
if c.cert == "" {
panic(fmt.Sprintf("@cert:%s", c.cert))
}
if !strings.HasPrefix(c.addr, "wss://") {
panic(fmt.Sprintf("the prefix @serverAddr:%s is not wss://", c.addr))
}
return c
}
func (c *client) ID() EndPointID {
return c.endPointID
}
func (c *client) EndPointType() EndPointType {
return c.endPointType
}
func (c *client) dialTCP() Session {
var (
err error
conn net.Conn
)
for {
if c.IsClosed() {
return nil
}
if c.sslEnabled {
if sslConfig, buildTlsConfErr := c.tlsConfigBuilder.BuildTlsConfig(); buildTlsConfErr == nil && sslConfig != nil {
d := &net.Dialer{Timeout: connectTimeout}
conn, err = tls.DialWithDialer(d, "tcp", c.addr, sslConfig)
}
} else {
conn, err = net.DialTimeout("tcp", c.addr, connectTimeout)
}
if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
conn.Close()
err = errSelfConnect
}
if err == nil {
return newTCPSession(conn, c)
}
log.Infof("net.DialTimeout(addr:%s, timeout:%v) = error:%+v", c.addr, connectTimeout, perrors.WithStack(err))
<-gxtime.After(connectInterval)
}
}
func (c *client) dialUDP() Session {
var (
err error
conn *net.UDPConn
localAddr *net.UDPAddr
peerAddr *net.UDPAddr
length int
bufp *[]byte
buf []byte
)
bufp = gxbytes.GetBytes(128)
defer gxbytes.PutBytes(bufp)
buf = *bufp
localAddr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
peerAddr, _ = net.ResolveUDPAddr("udp", c.addr)
for {
if c.IsClosed() {
return nil
}
conn, err = net.DialUDP("udp", localAddr, peerAddr)
if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
conn.Close()
err = errSelfConnect
}
if err != nil {
log.Warnf("net.DialTimeout(addr:%s, timeout:%v) = error:%+v", c.addr, perrors.WithStack(err))
<-gxtime.After(connectInterval)
continue
}
// check connection alive by write/read action
conn.SetWriteDeadline(time.Now().Add(1e9))
if length, err = conn.Write(connectPingPackage[:]); err != nil {
conn.Close()
log.Warnf("conn.Write(%s) = {length:%d, err:%+v}", string(connectPingPackage), length, perrors.WithStack(err))
<-gxtime.After(connectInterval)
continue
}
conn.SetReadDeadline(time.Now().Add(1e9))
length, err = conn.Read(buf)
if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Timeout() {
err = nil
}
if err != nil {
log.Infof("conn{%#v}.Read() = {length:%d, err:%+v}", conn, length, perrors.WithStack(err))
conn.Close()
<-gxtime.After(connectInterval)
continue
}
return newUDPSession(conn, c)
}
}
func (c *client) dialWS() Session {
var (
err error
dialer websocket.Dialer
conn *websocket.Conn
ss Session
)
dialer.EnableCompression = true
for {
if c.IsClosed() {
return nil
}
conn, _, err = dialer.Dial(c.addr, nil)
log.Infof("websocket.dialer.Dial(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
conn.Close()
err = errSelfConnect
}
if err == nil {
ss = newWSSession(conn, c)
if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
}
return ss
}
log.Infof("websocket.dialer.Dial(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
<-gxtime.After(connectInterval)
}
}
func (c *client) dialWSS() Session {
var (
err error
root *x509.Certificate
roots []*x509.Certificate
certPool *x509.CertPool
config *tls.Config
dialer websocket.Dialer
conn *websocket.Conn
ss Session
)
dialer.EnableCompression = true
config = &tls.Config{
InsecureSkipVerify: true,
}
if c.cert != "" {
certPEMBlock, err := os.ReadFile(c.cert)
if err != nil {
panic(fmt.Sprintf("os.ReadFile(cert:%s) = error:%+v", c.cert, perrors.WithStack(err)))
}
var cert tls.Certificate
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0] = cert
}
certPool = x509.NewCertPool()
for _, c := range config.Certificates {
roots, err = x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
panic(fmt.Sprintf("error parsing server's root cert: %+v\n", perrors.WithStack(err)))
}
for _, root = range roots {
certPool.AddCert(root)
}
}
config.InsecureSkipVerify = true
config.RootCAs = certPool
// dialer.EnableCompression = true
dialer.TLSClientConfig = config
for {
if c.IsClosed() {
return nil
}
conn, _, err = dialer.Dial(c.addr, nil)
if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
conn.Close()
err = errSelfConnect
}
if err == nil {
ss = newWSSession(conn, c)
if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
}
ss.SetName(defaultWSSSessionName)
return ss
}
log.Infof("websocket.dialer.Dial(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
<-gxtime.After(connectInterval)
}
}
func (c *client) dial() Session {
switch c.endPointType {
case TCP_CLIENT:
return c.dialTCP()
case UDP_CLIENT:
return c.dialUDP()
case WS_CLIENT:
return c.dialWS()
case WSS_CLIENT:
return c.dialWSS()
}
return nil
}
func (c *client) GetTaskPool() gxsync.GenericTaskPool {
return c.tPool
}
func (c *client) sessionNum() int {
var num int
c.Lock()
for s := range c.ssMap {
if s.IsClosed() {
delete(c.ssMap, s)
}
}
num = len(c.ssMap)
c.Unlock()
return num
}
func (c *client) connect() {
var (
err error
ss Session
)
for {
ss = c.dial()
if ss == nil {
// client has been closed
break
}
err = c.newSession(ss)
if err == nil {
ss.(*session).run()
c.Lock()
if c.ssMap == nil {
c.Unlock()
break
}
c.ssMap[ss] = struct{}{}
c.Unlock()
ss.SetAttribute(sessionClientKey, c)
break
}
// don't distinguish between tcp connection and websocket connection. Because
// gorilla/websocket/conn.go:(Conn)Close also invoke net.Conn.Close()
ss.Conn().Close()
}
}
// there are two methods to keep connection pool. the first approach is like
// redigo's lazy connection pool(https://github.com/gomodule/redigo/blob/master/redis/pool.go:),
// in which you should apply testOnBorrow to check alive of the connection.
// the second way is as follows. @RunEventLoop detects the aliveness of the connection
// in regular time interval.
// the active method maybe overburden the cpu slightly.
// however, you can get a active tcp connection very quickly.
func (c *client) RunEventLoop(newSession NewSessionCallback) {
c.Lock()
c.newSession = newSession
c.Unlock()
c.reConnect()
}
// a for-loop connect to make sure the connection pool is valid
func (c *client) reConnect() {
var num, max, times, interval int
max = c.number
interval = c.reconnectInterval
if interval == 0 {
interval = reconnectInterval
}
for {
if c.IsClosed() {
log.Warnf("client{peer:%s} goroutine exit now.", c.addr)
break
}
num = c.sessionNum()
if max <= num {
break
}
c.connect()
times++
if maxTimes < times {
times = maxTimes
}
<-gxtime.After(time.Duration(int64(times) * int64(interval)))
}
}
func (c *client) stop() {
select {
case <-c.done:
return
default:
c.Once.Do(func() {
close(c.done)
c.Lock()
for s := range c.ssMap {
s.RemoveAttribute(sessionClientKey)
s.Close()
}
c.ssMap = nil
c.Unlock()
})
}
}
func (c *client) IsClosed() bool {
select {
case <-c.done:
return true
default:
return false
}
}
func (c *client) Close() {
c.stop()
c.wg.Wait()
}

647
dubbo-getty/connection.go Normal file
View File

@@ -0,0 +1,647 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"compress/flate"
"crypto/tls"
"fmt"
"github.com/eolinker/eosc/log"
"io"
"net"
"sync"
"time"
)
import (
"github.com/golang/snappy"
"github.com/gorilla/websocket"
perrors "github.com/pkg/errors"
uatomic "go.uber.org/atomic"
)
var (
launchTime = time.Now()
connID uatomic.Uint32
)
// Connection wrap some connection params and operations
type Connection interface {
ID() uint32
SetCompressType(CompressType)
LocalAddr() string
RemoteAddr() string
incReadPkgNum()
incWritePkgNum()
// UpdateActive update session's active time
UpdateActive()
// GetActive get session's active time
GetActive() time.Time
readTimeout() time.Duration
// SetReadTimeout sets deadline for the future read calls.
SetReadTimeout(time.Duration)
writeTimeout() time.Duration
// SetWriteTimeout sets deadline for the future read calls.
SetWriteTimeout(time.Duration)
send(interface{}) (int, error)
// don't distinguish between tcp connection and websocket connection. Because
// gorilla/websocket/conn.go:(Conn)Close also invoke net.Conn.Close
close(int)
// set related session
setSession(Session)
}
// ///////////////////////////////////////
// getty connection
// ///////////////////////////////////////
type gettyConn struct {
id uint32
compress CompressType
padding1 uint8
padding2 uint16
readBytes uatomic.Uint32 // read bytes
writeBytes uatomic.Uint32 // write bytes
readPkgNum uatomic.Uint32 // send pkg number
writePkgNum uatomic.Uint32 // recv pkg number
active uatomic.Int64 // last active, in milliseconds
rTimeout uatomic.Duration // network current limiting
wTimeout uatomic.Duration
rLastDeadline uatomic.Time // last network read time
wLastDeadline uatomic.Time // last network write time
local string // local address
peer string // peer address
ss Session
}
func (c *gettyConn) ID() uint32 {
return c.id
}
func (c *gettyConn) LocalAddr() string {
return c.local
}
func (c *gettyConn) RemoteAddr() string {
return c.peer
}
func (c *gettyConn) incReadPkgNum() {
c.readPkgNum.Add(1)
}
func (c *gettyConn) incWritePkgNum() {
c.writePkgNum.Add(1)
}
func (c *gettyConn) UpdateActive() {
c.active.Store(int64(time.Since(launchTime)))
}
func (c *gettyConn) GetActive() time.Time {
return launchTime.Add(time.Duration(c.active.Load()))
}
func (c *gettyConn) send(interface{}) (int, error) {
return 0, nil
}
func (c *gettyConn) close(int) {}
func (c gettyConn) readTimeout() time.Duration {
return c.rTimeout.Load()
}
func (c *gettyConn) setSession(ss Session) {
c.ss = ss
}
// SetReadTimeout Pls do not set read deadline for websocket connection. AlexStocks 20180310
// gorilla/websocket/conn.go:NextReader will always fail when got a timeout error.
//
// Pls do not set read deadline when using compression. AlexStocks 20180314.
func (c *gettyConn) SetReadTimeout(rTimeout time.Duration) {
if rTimeout < 1 {
panic("@rTimeout < 1")
}
c.rTimeout.Store(rTimeout)
if c.wTimeout.Load() == 0 {
c.wTimeout.Store(rTimeout)
}
}
func (c gettyConn) writeTimeout() time.Duration {
return c.wTimeout.Load()
}
// SetWriteTimeout Pls do not set write deadline for websocket connection. AlexStocks 20180310
// gorilla/websocket/conn.go:NextWriter will always fail when got a timeout error.
//
// Pls do not set write deadline when using compression. AlexStocks 20180314.
func (c *gettyConn) SetWriteTimeout(wTimeout time.Duration) {
if wTimeout < 1 {
panic("@wTimeout < 1")
}
c.wTimeout.Store(wTimeout)
if c.rTimeout.Load() == 0 {
c.rTimeout.Store(wTimeout)
}
}
/////////////////////////////////////////
// getty tcp connection
/////////////////////////////////////////
type gettyTCPConn struct {
gettyConn
reader io.Reader
writer io.Writer
conn net.Conn
}
// create gettyTCPConn
func newGettyTCPConn(conn net.Conn) *gettyTCPConn {
if conn == nil {
panic("newGettyTCPConn(conn):@conn is nil")
}
var localAddr, peerAddr string
// check conn.LocalAddr or conn.RemoteAddr is nil to defeat panic on 2016/09/27
if conn.LocalAddr() != nil {
localAddr = conn.LocalAddr().String()
}
if conn.RemoteAddr() != nil {
peerAddr = conn.RemoteAddr().String()
}
return &gettyTCPConn{
conn: conn,
reader: io.Reader(conn),
writer: io.Writer(conn),
gettyConn: gettyConn{
id: connID.Add(1),
rTimeout: *uatomic.NewDuration(netIOTimeout),
wTimeout: *uatomic.NewDuration(netIOTimeout),
local: localAddr,
peer: peerAddr,
compress: CompressNone,
},
}
}
// for zip compress
type writeFlusher struct {
flusher *flate.Writer
lock sync.Mutex
}
func (t *writeFlusher) Write(p []byte) (int, error) {
var (
n int
err error
)
t.lock.Lock()
defer t.lock.Unlock()
n, err = t.flusher.Write(p)
if err != nil {
return n, perrors.WithStack(err)
}
if err := t.flusher.Flush(); err != nil {
return 0, perrors.WithStack(err)
}
return n, nil
}
// SetCompressType set compress type(tcp: zip/snappy, websocket:zip)
func (t *gettyTCPConn) SetCompressType(c CompressType) {
switch c {
case CompressNone, CompressZip, CompressBestSpeed, CompressBestCompression, CompressHuffman:
ioReader := io.Reader(t.conn)
t.reader = flate.NewReader(ioReader)
ioWriter := io.Writer(t.conn)
w, err := flate.NewWriter(ioWriter, int(c))
if err != nil {
panic(fmt.Sprintf("flate.NewReader(flate.DefaultCompress) = err(%s)", err))
}
t.writer = &writeFlusher{flusher: w}
case CompressSnappy:
ioReader := io.Reader(t.conn)
t.reader = snappy.NewReader(ioReader)
ioWriter := io.Writer(t.conn)
t.writer = snappy.NewBufferedWriter(ioWriter)
default:
panic(fmt.Sprintf("illegal comparess type %d", c))
}
t.compress = c
}
// tcp connection read
func (t *gettyTCPConn) recv(p []byte) (int, error) {
var (
err error
currentTime time.Time
length int
)
// set read timeout deadline
if t.compress == CompressNone && t.rTimeout.Load() > 0 {
// Optimization: update read deadline only if more than 25%
// of the last read deadline exceeded.
// See https://github.com/golang/go/issues/15133 for details.
currentTime = time.Now()
if currentTime.Sub(t.rLastDeadline.Load()) > t.rTimeout.Load()>>2 {
if err = t.conn.SetReadDeadline(currentTime.Add(t.rTimeout.Load())); err != nil {
// just a timeout error
return 0, perrors.WithStack(err)
}
t.rLastDeadline.Store(currentTime)
}
}
length, err = t.reader.Read(p)
t.readBytes.Add(uint32(length))
return length, perrors.WithStack(err)
}
// tcp connection write
func (t *gettyTCPConn) send(pkg interface{}) (int, error) {
var (
err error
currentTime time.Time
ok bool
p []byte
length int
lg int64
)
if t.compress == CompressNone && t.wTimeout.Load() > 0 {
// Optimization: update write deadline only if more than 25%
// of the last write deadline exceeded.
// See https://github.com/golang/go/issues/15133 for details.
currentTime = time.Now()
if currentTime.Sub(t.wLastDeadline.Load()) > t.wTimeout.Load()>>2 {
if err = t.conn.SetWriteDeadline(currentTime.Add(t.wTimeout.Load())); err != nil {
return 0, perrors.WithStack(err)
}
t.wLastDeadline.Store(currentTime)
}
}
if buffers, ok := pkg.([][]byte); ok {
netBuf := net.Buffers(buffers)
lg, err = netBuf.WriteTo(t.conn)
if err == nil {
t.writeBytes.Add((uint32)(lg))
t.writePkgNum.Add((uint32)(len(buffers)))
}
log.Debug("localAddr: %s, remoteAddr:%s, now:%s, length:%d, err:%s",
t.conn.LocalAddr(), t.conn.RemoteAddr(), currentTime, length, err)
return int(lg), perrors.WithStack(err)
}
if p, ok = pkg.([]byte); ok {
length, err = t.writer.Write(p)
if err == nil {
t.writeBytes.Add((uint32)(len(p)))
t.writePkgNum.Add(1)
}
log.Debug("localAddr: %s, remoteAddr:%s, now:%s, length:%d, err:%v",
t.conn.LocalAddr(), t.conn.RemoteAddr(), currentTime, length, err)
return length, perrors.WithStack(err)
}
return 0, perrors.Errorf("illegal @pkg{%#v} type", pkg)
}
// close tcp connection
func (t *gettyTCPConn) close(waitSec int) {
// if tcpConn, ok := t.conn.(*net.TCPConn); ok {
// tcpConn.SetLinger(0)
// }
if t.conn != nil {
if writer, ok := t.writer.(*snappy.Writer); ok {
if err := writer.Close(); err != nil {
log.Errorf("snappy.Writer.Close() = error:%+v", err)
}
}
if conn, ok := t.conn.(*net.TCPConn); ok {
_ = conn.SetLinger(waitSec)
_ = conn.Close()
} else {
_ = t.conn.(*tls.Conn).Close()
}
t.conn = nil
}
}
// ///////////////////////////////////////
// getty udp connection
// ///////////////////////////////////////
type UDPContext struct {
Pkg interface{}
PeerAddr *net.UDPAddr
}
func (c UDPContext) String() string {
return fmt.Sprintf("{pkg:%#v, peer addr:%s}", c.Pkg, c.PeerAddr)
}
type gettyUDPConn struct {
gettyConn
compressType CompressType
conn *net.UDPConn // for server
}
// create gettyUDPConn
func newGettyUDPConn(conn *net.UDPConn) *gettyUDPConn {
if conn == nil {
panic("newGettyUDPConn(conn):@conn is nil")
}
var localAddr, peerAddr string
if conn.LocalAddr() != nil {
localAddr = conn.LocalAddr().String()
}
if conn.RemoteAddr() != nil {
// connected udp
peerAddr = conn.RemoteAddr().String()
}
return &gettyUDPConn{
conn: conn,
gettyConn: gettyConn{
id: connID.Add(1),
rTimeout: *uatomic.NewDuration(netIOTimeout),
wTimeout: *uatomic.NewDuration(netIOTimeout),
local: localAddr,
peer: peerAddr,
compress: CompressNone,
},
}
}
func (u *gettyUDPConn) SetCompressType(c CompressType) {
switch c {
case CompressNone, CompressZip, CompressBestSpeed, CompressBestCompression, CompressHuffman, CompressSnappy:
u.compressType = c
default:
panic(fmt.Sprintf("illegal comparess type %d", c))
}
}
// udp connection read
func (u *gettyUDPConn) recv(p []byte) (int, *net.UDPAddr, error) {
if u.rTimeout.Load() > 0 {
// Optimization: update read deadline only if more than 25%
// of the last read deadline exceeded.
// See https://github.com/golang/go/issues/15133 for details.
currentTime := time.Now()
if currentTime.Sub(u.rLastDeadline.Load()) > u.rTimeout.Load()>>2 {
if err := u.conn.SetReadDeadline(currentTime.Add(u.rTimeout.Load())); err != nil {
return 0, nil, perrors.WithStack(err)
}
u.rLastDeadline.Store(currentTime)
}
}
length, addr, err := u.conn.ReadFromUDP(p) // connected udp also can get return @addr
log.Debug("ReadFromUDP(p:%d) = {length:%d, peerAddr:%s, error:%v}", len(p), length, addr, err)
if err == nil {
u.readBytes.Add(uint32(length))
}
return length, addr, perrors.WithStack(err)
}
// write udp packet, @ctx should be of type UDPContext
func (u *gettyUDPConn) send(udpCtx interface{}) (int, error) {
var (
err error
currentTime time.Time
length int
ok bool
ctx UDPContext
buf []byte
peerAddr *net.UDPAddr
)
if ctx, ok = udpCtx.(UDPContext); !ok {
return 0, perrors.Errorf("illegal @udpCtx{%s} type, @udpCtx type:%T", udpCtx, udpCtx)
}
if buf, ok = ctx.Pkg.([]byte); !ok {
return 0, perrors.Errorf("illegal @udpCtx.Pkg{%#v} type", udpCtx)
}
if u.ss.EndPoint().EndPointType() == UDP_ENDPOINT {
peerAddr = ctx.PeerAddr
if peerAddr == nil {
return 0, ErrNullPeerAddr
}
}
if u.wTimeout.Load() > 0 {
// Optimization: update write deadline only if more than 25%
// of the last write deadline exceeded.
// See https://github.com/golang/go/issues/15133 for details.
currentTime = time.Now()
if currentTime.Sub(u.wLastDeadline.Load()) > u.wTimeout.Load()>>2 {
if err = u.conn.SetWriteDeadline(currentTime.Add(u.wTimeout.Load())); err != nil {
return 0, perrors.WithStack(err)
}
u.wLastDeadline.Store(currentTime)
}
}
if length, _, err = u.conn.WriteMsgUDP(buf, nil, peerAddr); err == nil {
u.writeBytes.Add((uint32)(len(buf)))
u.writePkgNum.Add(1)
}
log.Debug("WriteMsgUDP(peerAddr:%s) = {length:%d, error:%v}", peerAddr, length, err)
return length, perrors.WithStack(err)
}
// close udp connection
func (u *gettyUDPConn) close(_ int) {
if u.conn != nil {
u.conn.Close()
u.conn = nil
}
}
// ///////////////////////////////////////
// getty websocket connection
// ///////////////////////////////////////
type gettyWSConn struct {
gettyConn
conn *websocket.Conn
}
// create websocket connection
func newGettyWSConn(conn *websocket.Conn) *gettyWSConn {
if conn == nil {
panic("newGettyWSConn(conn):@conn is nil")
}
var localAddr, peerAddr string
// check conn.LocalAddr or conn.RemoetAddr is nil to defeat panic on 2016/09/27
if conn.LocalAddr() != nil {
localAddr = conn.LocalAddr().String()
}
if conn.RemoteAddr() != nil {
peerAddr = conn.RemoteAddr().String()
}
gettyWSConn := &gettyWSConn{
conn: conn,
gettyConn: gettyConn{
id: connID.Add(1),
rTimeout: *uatomic.NewDuration(netIOTimeout),
wTimeout: *uatomic.NewDuration(netIOTimeout),
local: localAddr,
peer: peerAddr,
compress: CompressNone,
},
}
conn.EnableWriteCompression(false)
conn.SetPingHandler(gettyWSConn.handlePing)
conn.SetPongHandler(gettyWSConn.handlePong)
return gettyWSConn
}
// SetCompressType set compress type
func (w *gettyWSConn) SetCompressType(c CompressType) {
switch c {
case CompressNone, CompressZip, CompressBestSpeed, CompressBestCompression, CompressHuffman:
w.conn.EnableWriteCompression(true)
w.conn.SetCompressionLevel(int(c))
default:
panic(fmt.Sprintf("illegal comparess type %d", c))
}
w.compress = c
}
func (w *gettyWSConn) handlePing(message string) error {
err := w.writePong([]byte(message))
if err == websocket.ErrCloseSent {
err = nil
} else if e, ok := err.(net.Error); ok && e.Temporary() {
err = nil
}
if err == nil {
w.UpdateActive()
}
return perrors.WithStack(err)
}
func (w *gettyWSConn) handlePong(string) error {
w.UpdateActive()
return nil
}
// websocket connection read
func (w *gettyWSConn) recv() ([]byte, error) {
// Pls do not set read deadline when using ReadMessage. AlexStocks 20180310
// gorilla/websocket/conn.go:NextReader will always fail when got a timeout error.
_, b, e := w.conn.ReadMessage() // the first return value is message type.
if e == nil {
w.readBytes.Add((uint32)(len(b)))
} else {
if websocket.IsUnexpectedCloseError(e, websocket.CloseGoingAway) {
log.Warnf("websocket unexpected close error: %v", e)
}
}
return b, perrors.WithStack(e)
}
func (w *gettyWSConn) updateWriteDeadline() error {
var (
err error
currentTime time.Time
)
if w.wTimeout.Load() > 0 {
// Optimization: update write deadline only if more than 25%
// of the last write deadline exceeded.
// See https://github.com/golang/go/issues/15133 for details.
currentTime = time.Now()
if currentTime.Sub(w.wLastDeadline.Load()) > w.wTimeout.Load()>>2 {
if err = w.conn.SetWriteDeadline(currentTime.Add(w.wTimeout.Load())); err != nil {
return perrors.WithStack(err)
}
w.wLastDeadline.Store(currentTime)
}
}
return nil
}
// websocket connection write
func (w *gettyWSConn) send(pkg interface{}) (int, error) {
var (
err error
ok bool
p []byte
)
if p, ok = pkg.([]byte); !ok {
return 0, perrors.Errorf("illegal @pkg{%#v} type", pkg)
}
w.updateWriteDeadline()
if err = w.conn.WriteMessage(websocket.BinaryMessage, p); err == nil {
w.writeBytes.Add((uint32)(len(p)))
w.writePkgNum.Add(1)
}
return len(p), perrors.WithStack(err)
}
func (w *gettyWSConn) writePing() error {
w.updateWriteDeadline()
return perrors.WithStack(w.conn.WriteMessage(websocket.PingMessage, []byte{}))
}
func (w *gettyWSConn) writePong(message []byte) error {
w.updateWriteDeadline()
return perrors.WithStack(w.conn.WriteMessage(websocket.PongMessage, message))
}
// close websocket connection
func (w *gettyWSConn) close(waitSec int) {
w.updateWriteDeadline()
w.conn.WriteMessage(websocket.CloseMessage, []byte("bye-bye!!!"))
conn := w.conn.UnderlyingConn()
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetLinger(waitSec)
} else if wsConn, ok := conn.(*tls.Conn); ok {
wsConn.CloseWrite()
}
w.conn.Close()
}

81
dubbo-getty/const.go Normal file
View File

@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"compress/flate"
"strconv"
)
type (
EndPointID = int32
EndPointType int32
)
const (
UDP_ENDPOINT EndPointType = 0
UDP_CLIENT EndPointType = 1
TCP_CLIENT EndPointType = 2
WS_CLIENT EndPointType = 3
WSS_CLIENT EndPointType = 4
TCP_SERVER EndPointType = 7
WS_SERVER EndPointType = 8
WSS_SERVER EndPointType = 9
)
var EndPointType_name = map[int32]string{
0: "UDP_ENDPOINT",
1: "UDP_CLIENT",
2: "TCP_CLIENT",
3: "WS_CLIENT",
4: "WSS_CLIENT",
7: "TCP_SERVER",
8: "WS_SERVER",
9: "WSS_SERVER",
}
var EndPointType_value = map[string]int32{
"UDP_ENDPOINT": 0,
"UDP_CLIENT": 1,
"TCP_CLIENT": 2,
"WS_CLIENT": 3,
"WSS_CLIENT": 4,
"TCP_SERVER": 7,
"WS_SERVER": 8,
"WSS_SERVER": 9,
}
func (x EndPointType) String() string {
s, ok := EndPointType_name[int32(x)]
if ok {
return s
}
return strconv.Itoa(int(x))
}
type CompressType int
const (
CompressNone CompressType = flate.NoCompression // 0
CompressZip = flate.DefaultCompression // -1
CompressBestSpeed = flate.BestSpeed // 1
CompressBestCompression = flate.BestCompression // 9
CompressHuffman = flate.HuffmanOnly // -2
CompressSnappy = 10
)

110
dubbo-getty/getty.go Normal file
View File

@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
gxsync "github.com/dubbogo/gost/sync"
perrors "github.com/pkg/errors"
)
var (
ErrSessionClosed = perrors.New("session Already Closed")
ErrSessionBlocked = perrors.New("session Full Blocked")
ErrNullPeerAddr = perrors.New("peer address is nil")
)
// NewSessionCallback will be invoked when server accepts a new client connection or client connects to server successfully.
// If there are too many client connections or u do not want to connect a server again, u can return non-nil error. And
// then getty will close the new session.
type NewSessionCallback func(Session) error
// Reader is used to unmarshal a complete pkg from buffer
type Reader interface {
// Read Parse tcp/udp/websocket pkg from buffer and if possible return a complete pkg.
// When receiving a tcp network streaming segment, there are 4 cases as following:
// case 1: a error found in the streaming segment;
// case 2: can not unmarshal a pkg header from the streaming segment;
// case 3: unmarshal a pkg header but can not unmarshal a pkg from the streaming segment;
// case 4: just unmarshal a pkg from the streaming segment;
// case 5: unmarshal more than one pkg from the streaming segment;
//
// The return value is (nil, 0, error) as case 1.
// The return value is (nil, 0, nil) as case 2.
// The return value is (nil, pkgLen, nil) as case 3.
// The return value is (pkg, pkgLen, nil) as case 4.
// The handleTcpPackage may invoke func Read many times as case 5.
Read(Session, []byte) (interface{}, int, error)
}
// Writer is used to marshal pkg and write to session
type Writer interface {
// Write if @Session is udpGettySession, the second parameter is UDPContext.
Write(Session, interface{}) ([]byte, error)
}
// ReadWriter interface use for handle application packages
type ReadWriter interface {
Reader
Writer
}
// EventListener is used to process pkg that received from remote session
type EventListener interface {
// OnOpen invoked when session opened
// If the return error is not nil, @Session will be closed.
OnOpen(Session) error
// OnClose invoked when session closed.
OnClose(Session)
// OnError invoked when got error.
OnError(Session, error)
// OnCron invoked periodically, its period can be set by (Session)SetCronPeriod
OnCron(Session)
// OnMessage invoked when getty received a package. Pls attention that do not handle long time
// logic processing in this func. You'd better set the package's maximum length.
// If the message's length is greater than it, u should should return err in
// Reader{Read} and getty will close this connection soon.
//
// If ur logic processing in this func will take a long time, u should start a goroutine
// pool(like working thread pool in cpp) to handle the processing asynchronously. Or u
// can do the logic processing in other asynchronous way.
// !!!In short, ur OnMessage callback func should return asap.
//
// If this is a udp event listener, the second parameter type is UDPContext.
OnMessage(Session, interface{})
}
// EndPoint represents the identity of the client/server
type EndPoint interface {
// ID get EndPoint ID
ID() EndPointID
// EndPointType get endpoint type
EndPointType() EndPointType
// RunEventLoop run event loop and serves client request.
RunEventLoop(newSession NewSessionCallback)
// IsClosed check the endpoint has been closed
IsClosed() bool
// Close close the endpoint and free its resource
Close()
// GetTaskPool get task pool implemented by dubbogo/gost
GetTaskPool() gxsync.GenericTaskPool
}

178
dubbo-getty/options.go Normal file
View File

@@ -0,0 +1,178 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
gxsync "github.com/dubbogo/gost/sync"
"net"
)
type ServerOption func(*ServerOptions)
type ServerOptions struct {
addr string
// tls
sslEnabled bool
tlsConfigBuilder TlsConfigBuilder
// websocket
path string
cert string
privateKey string
caCert string
// task queue
tPool gxsync.GenericTaskPool
listener net.Listener
}
// WithLocalAddress @addr server listen address.
func WithLocalAddress(addr string) ServerOption {
return func(o *ServerOptions) {
o.addr = addr
}
}
// WithWebsocketServerPath @path: websocket request url path
func WithWebsocketServerPath(path string) ServerOption {
return func(o *ServerOptions) {
o.path = path
}
}
// WithWebsocketServerCert @cert: server certificate file
func WithWebsocketServerCert(cert string) ServerOption {
return func(o *ServerOptions) {
o.cert = cert
}
}
func WithListenerServerCert(listener net.Listener) ServerOption {
return func(o *ServerOptions) {
o.listener = listener
}
}
// WithWebsocketServerPrivateKey @key: server private key(contains its public key)
func WithWebsocketServerPrivateKey(key string) ServerOption {
return func(o *ServerOptions) {
o.privateKey = key
}
}
// WithWebsocketServerRootCert @cert is the root certificate file to verify the legitimacy of server
func WithWebsocketServerRootCert(cert string) ServerOption {
return func(o *ServerOptions) {
o.caCert = cert
}
}
// WithServerTaskPool @pool server task pool.
func WithServerTaskPool(pool gxsync.GenericTaskPool) ServerOption {
return func(o *ServerOptions) {
o.tPool = pool
}
}
// WithServerSslEnabled enable use tls
func WithServerSslEnabled(sslEnabled bool) ServerOption {
return func(o *ServerOptions) {
o.sslEnabled = sslEnabled
}
}
// WithServerTlsConfigBuilder sslConfig is tls config
func WithServerTlsConfigBuilder(tlsConfigBuilder TlsConfigBuilder) ServerOption {
return func(o *ServerOptions) {
o.tlsConfigBuilder = tlsConfigBuilder
}
}
/////////////////////////////////////////
// Client Options
/////////////////////////////////////////
type ClientOption func(*ClientOptions)
type ClientOptions struct {
addr string
number int
reconnectInterval int // reConnect Interval
// tls
sslEnabled bool
tlsConfigBuilder TlsConfigBuilder
// the cert file of wss server which may contain server domain, server ip, the starting effective date, effective
// duration, the hash alg, the len of the private key.
// wss client will use it.
cert string
// task queue
tPool gxsync.GenericTaskPool
}
// WithServerAddress @addr is server address.
func WithServerAddress(addr string) ClientOption {
return func(o *ClientOptions) {
o.addr = addr
}
}
// WithReconnectInterval @reconnectInterval is server address.
func WithReconnectInterval(reconnectInterval int) ClientOption {
return func(o *ClientOptions) {
if 0 < reconnectInterval {
o.reconnectInterval = reconnectInterval
}
}
}
// WithClientTaskPool @pool client task pool.
func WithClientTaskPool(pool gxsync.GenericTaskPool) ClientOption {
return func(o *ClientOptions) {
o.tPool = pool
}
}
// WithConnectionNumber @num is connection number.
func WithConnectionNumber(num int) ClientOption {
return func(o *ClientOptions) {
if 0 < num {
o.number = num
}
}
}
// WithRootCertificateFile @certs is client certificate file. it can be empty.
func WithRootCertificateFile(cert string) ClientOption {
return func(o *ClientOptions) {
o.cert = cert
}
}
// WithClientSslEnabled enable use tls
func WithClientSslEnabled(sslEnabled bool) ClientOption {
return func(o *ClientOptions) {
o.sslEnabled = sslEnabled
}
}
// WithClientTlsConfigBuilder sslConfig is tls config
func WithClientTlsConfigBuilder(tlsConfigBuilder TlsConfigBuilder) ClientOption {
return func(o *ClientOptions) {
o.tlsConfigBuilder = tlsConfigBuilder
}
}

524
dubbo-getty/server.go Normal file
View File

@@ -0,0 +1,524 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/eolinker/eosc/log"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"time"
)
import (
gxnet "github.com/dubbogo/gost/net"
gxsync "github.com/dubbogo/gost/sync"
gxtime "github.com/dubbogo/gost/time"
"github.com/gorilla/websocket"
perrors "github.com/pkg/errors"
uatomic "go.uber.org/atomic"
)
var (
errSelfConnect = perrors.New("connect self!")
serverFastFailTimeout = time.Second * 1
serverID uatomic.Int32
)
// Server interface
type Server interface {
EndPoint
}
// StreamServer is like tcp/websocket/wss server
type StreamServer interface {
Server
// Listener get the network listener
Listener() net.Listener
}
// PacketServer is like udp listen endpoint
type PacketServer interface {
Server
// PacketConn get the network listener
PacketConn() net.PacketConn
}
type server struct {
ServerOptions
// endpoint ID
endPointID EndPointID
// net
pktListener net.PacketConn
streamListener net.Listener
lock sync.Mutex // for server
endPointType EndPointType
server *http.Server // for ws or wss server
sync.Once
done chan struct{}
wg sync.WaitGroup
}
func (s *server) init(opts ...ServerOption) {
for _, opt := range opts {
opt(&(s.ServerOptions))
}
}
func newServer(t EndPointType, opts ...ServerOption) *server {
s := &server{
endPointID: serverID.Add(1),
endPointType: t,
done: make(chan struct{}),
}
s.init(opts...)
return s
}
// NewTCPServer builds a tcp server.
func NewTCPServer(opts ...ServerOption) Server {
return newServer(TCP_SERVER, opts...)
}
// NewUDPEndPoint builds a unconnected udp server.
func NewUDPEndPoint(opts ...ServerOption) Server {
return newServer(UDP_ENDPOINT, opts...)
}
// NewWSServer builds a websocket server.
func NewWSServer(opts ...ServerOption) Server {
return newServer(WS_SERVER, opts...)
}
// NewWSSServer builds a secure websocket server.
func NewWSSServer(opts ...ServerOption) Server {
s := newServer(WSS_SERVER, opts...)
if s.addr == "" || s.cert == "" || s.privateKey == "" {
panic(fmt.Sprintf("@addr:%s, @cert:%s, @privateKey:%s, @caCert:%s",
s.addr, s.cert, s.privateKey, s.caCert))
}
return s
}
func (s *server) ID() int32 {
return s.endPointID
}
func (s *server) EndPointType() EndPointType {
return s.endPointType
}
func (s *server) stop() {
select {
case <-s.done:
return
default:
s.Once.Do(func() {
close(s.done)
s.lock.Lock()
if s.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), serverFastFailTimeout)
if err := s.server.Shutdown(ctx); err != nil {
// if the log output is "shutdown ctx: context deadline exceeded" it means that
// there are still some active connections.
log.Errorf("server shutdown ctx:%s error:%v", ctx, err)
}
cancel()
}
s.server = nil
s.lock.Unlock()
if s.streamListener != nil {
// let the server exit asap when got error from RunEventLoop.
s.streamListener.Close()
s.streamListener = nil
}
if s.pktListener != nil {
s.pktListener.Close()
s.pktListener = nil
}
})
}
}
func (s *server) GetTaskPool() gxsync.GenericTaskPool {
return s.tPool
}
func (s *server) IsClosed() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// net.ipv4.tcp_max_syn_backlog
// net.ipv4.tcp_timestamps
// net.ipv4.tcp_tw_recycle
func (s *server) listenTCP() error {
var (
err error
streamListener net.Listener
)
if s.listener == nil {
if len(s.addr) == 0 || !strings.Contains(s.addr, ":") {
streamListener, err = gxnet.ListenOnTCPRandomPort(s.addr)
if err != nil {
return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr)
}
} else {
if s.sslEnabled {
if sslConfig, buildTlsConfErr := s.tlsConfigBuilder.BuildTlsConfig(); buildTlsConfErr == nil && sslConfig != nil {
streamListener, err = tls.Listen("tcp", s.addr, sslConfig)
}
} else {
streamListener, err = net.Listen("tcp", s.addr)
}
if err != nil {
return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr)
}
}
} else {
streamListener = s.listener
}
s.streamListener = streamListener
s.addr = s.streamListener.Addr().String()
return nil
}
func (s *server) listenUDP() error {
var (
err error
localAddr *net.UDPAddr
pktListener *net.UDPConn
)
if len(s.addr) == 0 || !strings.Contains(s.addr, ":") {
pktListener, err = gxnet.ListenOnUDPRandomPort(s.addr)
if err != nil {
return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr)
}
} else {
localAddr, err = net.ResolveUDPAddr("udp", s.addr)
if err != nil {
return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr)
}
pktListener, err = net.ListenUDP("udp", localAddr)
if err != nil {
return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr)
}
}
s.pktListener = pktListener
s.addr = s.pktListener.LocalAddr().String()
return nil
}
// Listen announces on the local network address.
func (s *server) listen() error {
switch s.endPointType {
case TCP_SERVER, WS_SERVER, WSS_SERVER:
return perrors.WithStack(s.listenTCP())
case UDP_ENDPOINT:
return perrors.WithStack(s.listenUDP())
}
return nil
}
func (s *server) accept(newSession NewSessionCallback) (Session, error) {
conn, err := s.streamListener.Accept()
if err != nil {
return nil, perrors.WithStack(err)
}
if gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String())
return nil, perrors.WithStack(errSelfConnect)
}
ss := newTCPSession(conn, s)
err = newSession(ss)
if err != nil {
conn.Close()
return nil, perrors.WithStack(err)
}
return ss, nil
}
func (s *server) runTCPEventLoop(newSession NewSessionCallback) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var (
err error
client Session
delay time.Duration
)
for {
if s.IsClosed() {
log.Infof("server{%s} stop accepting client connect request.", s.addr)
return
}
if delay != 0 {
<-gxtime.After(delay)
}
client, err = s.accept(newSession)
if err != nil {
if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Temporary() {
if delay == 0 {
delay = 5 * time.Millisecond
} else {
delay *= 2
}
if max := 1 * time.Second; delay > max {
delay = max
}
continue
}
log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.WithStack(err))
continue
}
delay = 0
client.(*session).run()
}
}()
}
func (s *server) runUDPEventLoop(newSession NewSessionCallback) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var (
err error
conn *net.UDPConn
ss Session
)
conn = s.pktListener.(*net.UDPConn)
ss = newUDPSession(conn, s)
if err = newSession(ss); err != nil {
conn.Close()
panic(err.Error())
}
ss.(*session).run()
}()
}
type wsHandler struct {
http.ServeMux
server *server
newSession NewSessionCallback
upgrader websocket.Upgrader
}
func newWSHandler(server *server, newSession NewSessionCallback) *wsHandler {
return &wsHandler{
server: server,
newSession: newSession,
upgrader: websocket.Upgrader{
// in default, ReadBufferSize & WriteBufferSize is 4k
// HandshakeTimeout: server.HTTPTimeout,
CheckOrigin: func(_ *http.Request) bool { return true }, // allow connections from any origin
EnableCompression: true,
},
}
}
func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
// w.WriteHeader(http.StatusMethodNotAllowed)
http.Error(w, "Method not allowed", 405)
return
}
if s.server.IsClosed() {
http.Error(w, "HTTP server is closed(code:500-11).", 500)
log.Warnf("server{%s} stop acceptting client connect request.", s.server.addr)
return
}
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Warnf("upgrader.Upgrader(http.Request{%#v}) = error:%+v", r, err)
return
}
if conn.RemoteAddr().String() == conn.LocalAddr().String() {
log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String())
return
}
// conn.SetReadLimit(int64(handler.maxMsgLen))
ss := newWSSession(conn, s.server)
err = s.newSession(ss)
if err != nil {
conn.Close()
log.Warnf("server{%s}.newSession(ss{%#v}) = err {%s}", s.server.addr, ss, err)
return
}
if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
}
ss.(*session).run()
}
// runWSEventLoop serve websocket client request
// @newSession: new websocket connection callback
func (s *server) runWSEventLoop(newSession NewSessionCallback) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var (
err error
handler *wsHandler
server *http.Server
)
handler = newWSHandler(s, newSession)
handler.HandleFunc(s.path, handler.serveWSRequest)
server = &http.Server{
Addr: s.addr,
Handler: handler,
// ReadTimeout: server.HTTPTimeout,
// WriteTimeout: server.HTTPTimeout,
}
s.lock.Lock()
s.server = server
s.lock.Unlock()
err = server.Serve(s.streamListener)
if err != nil {
log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
}
}()
}
// serve websocket client request
// RunWSSEventLoop serve websocket client request
func (s *server) runWSSEventLoop(newSession NewSessionCallback) {
s.wg.Add(1)
go func() {
var (
err error
certPem []byte
certificate tls.Certificate
certPool *x509.CertPool
config *tls.Config
handler *wsHandler
server *http.Server
)
defer s.wg.Done()
if certificate, err = tls.LoadX509KeyPair(s.cert, s.privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(certs{%s}, privateKey{%s}) = err:%+v",
s.cert, s.privateKey, perrors.WithStack(err)))
}
config = &tls.Config{
InsecureSkipVerify: true, // do not verify peer certs
ClientAuth: tls.NoClientCert,
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{certificate},
}
if s.caCert != "" {
certPem, err = ioutil.ReadFile(s.caCert)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.caCert, perrors.WithStack(err)))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file")
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAndVerifyClientCert
config.InsecureSkipVerify = false
}
handler = newWSHandler(s, newSession)
handler.HandleFunc(s.path, handler.serveWSRequest)
server = &http.Server{
Addr: s.addr,
Handler: handler,
// ReadTimeout: server.HTTPTimeout,
// WriteTimeout: server.HTTPTimeout,
}
server.SetKeepAlivesEnabled(true)
s.lock.Lock()
s.server = server
s.lock.Unlock()
err = server.Serve(tls.NewListener(s.streamListener, config))
if err != nil {
log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
panic(err)
}
}()
}
// RunEventLoop serves client request.
// @newSession: new connection callback
func (s *server) RunEventLoop(newSession NewSessionCallback) {
if err := s.listen(); err != nil {
panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err)))
}
switch s.endPointType {
case TCP_SERVER:
s.runTCPEventLoop(newSession)
case UDP_ENDPOINT:
s.runUDPEventLoop(newSession)
case WS_SERVER:
s.runWSEventLoop(newSession)
case WSS_SERVER:
s.runWSSEventLoop(newSession)
default:
panic(fmt.Sprintf("illegal server type %s", s.endPointType.String()))
}
}
func (s *server) Listener() net.Listener {
return s.streamListener
}
func (s *server) PacketConn() net.PacketConn {
return s.pktListener
}
func (s *server) Close() {
s.stop()
s.wg.Wait()
}

View File

@@ -0,0 +1,258 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"time"
)
import (
perrors "github.com/pkg/errors"
)
import (
"dubbo.apache.org/dubbo-go/v3/config"
)
const (
TCPReadWriteTimeoutMinValue = time.Second * 1
)
type (
// GettySessionParam is session configuration for getty
GettySessionParam struct {
CompressEncoding bool `default:"false" yaml:"compress-encoding" json:"compress-encoding,omitempty"`
TcpNoDelay bool `default:"true" yaml:"tcp-no-delay" json:"tcp-no-delay,omitempty"`
TcpKeepAlive bool `default:"true" yaml:"tcp-keep-alive" json:"tcp-keep-alive,omitempty"`
KeepAlivePeriod string `default:"180s" yaml:"keep-alive-period" json:"keep-alive-period,omitempty"`
keepAlivePeriod time.Duration
TcpRBufSize int `default:"262144" yaml:"tcp-r-buf-size" json:"tcp-r-buf-size,omitempty"`
TcpWBufSize int `default:"65536" yaml:"tcp-w-buf-size" json:"tcp-w-buf-size,omitempty"`
TcpReadTimeout string `default:"1s" yaml:"tcp-read-timeout" json:"tcp-read-timeout,omitempty"`
tcpReadTimeout time.Duration
TcpWriteTimeout string `default:"5s" yaml:"tcp-write-timeout" json:"tcp-write-timeout,omitempty"`
tcpWriteTimeout time.Duration
WaitTimeout string `default:"7s" yaml:"wait-timeout" json:"wait-timeout,omitempty"`
waitTimeout time.Duration
MaxMsgLen int `default:"1024" yaml:"max-msg-len" json:"max-msg-len,omitempty"`
SessionName string `default:"rpc" yaml:"session-name" json:"session-name,omitempty"`
}
// ServerConfig holds supported types by the multiconfig package
ServerConfig struct {
SSLEnabled bool
// heartbeat
HeartbeatPeriod string `default:"60s" yaml:"heartbeat-period" json:"heartbeat-period,omitempty"`
heartbeatPeriod time.Duration
// heartbeat timeout
HeartbeatTimeout string `default:"5s" yaml:"heartbeat-timeout" json:"heartbeat-timeout,omitempty"`
heartbeatTimeout time.Duration
// session
SessionTimeout string `default:"60s" yaml:"session-timeout" json:"session-timeout,omitempty"`
sessionTimeout time.Duration
SessionNumber int `default:"1000" yaml:"session-number" json:"session-number,omitempty"`
// gr pool
GrPoolSize int `default:"0" yaml:"gr-pool-size" json:"gr-pool-size,omitempty"`
QueueLen int `default:"0" yaml:"queue-len" json:"queue-len,omitempty"`
QueueNumber int `default:"0" yaml:"queue-number" json:"queue-number,omitempty"`
// session tcp parameters
GettySessionParam GettySessionParam `required:"true" yaml:"getty-session-param" json:"getty-session-param,omitempty"`
}
// ClientConfig holds supported types by the multi config package
ClientConfig struct {
ReconnectInterval int `default:"0" yaml:"reconnect-interval" json:"reconnect-interval,omitempty"`
// session pool
ConnectionNum int `default:"16" yaml:"connection-number" json:"connection-number,omitempty"`
// heartbeat
HeartbeatPeriod string `default:"60s" yaml:"heartbeat-period" json:"heartbeat-period,omitempty"`
heartbeatPeriod time.Duration
// heartbeat timeout
HeartbeatTimeout string `default:"5s" yaml:"heartbeat-timeout" json:"heartbeat-timeout,omitempty"`
heartbeatTimeout time.Duration
// session
SessionTimeout string `default:"60s" yaml:"session-timeout" json:"session-timeout,omitempty"`
sessionTimeout time.Duration
// gr pool
GrPoolSize int `default:"0" yaml:"gr-pool-size" json:"gr-pool-size,omitempty"`
QueueLen int `default:"0" yaml:"queue-len" json:"queue-len,omitempty"`
QueueNumber int `default:"0" yaml:"queue-number" json:"queue-number,omitempty"`
// session tcp parameters
GettySessionParam GettySessionParam `required:"true" yaml:"getty-session-param" json:"getty-session-param,omitempty"`
}
)
// GetDefaultClientConfig gets client default configuration
func GetDefaultClientConfig() *ClientConfig {
defaultClientConfig := &ClientConfig{
ReconnectInterval: 0,
ConnectionNum: 16,
HeartbeatPeriod: "30s",
SessionTimeout: "180s",
GrPoolSize: 200,
QueueLen: 64,
QueueNumber: 10,
GettySessionParam: GettySessionParam{
CompressEncoding: false,
TcpNoDelay: true,
TcpKeepAlive: true,
KeepAlivePeriod: "180s",
TcpRBufSize: 262144,
TcpWBufSize: 65536,
TcpReadTimeout: "1s",
TcpWriteTimeout: "5s",
WaitTimeout: "1s",
MaxMsgLen: 102400,
SessionName: "client",
},
}
_ = defaultClientConfig.CheckValidity()
return defaultClientConfig
}
// GetDefaultServerConfig gets server default configuration
func GetDefaultServerConfig() *ServerConfig {
defaultServerConfig := &ServerConfig{
SessionTimeout: "180s",
SessionNumber: 700,
GrPoolSize: 120,
QueueNumber: 6,
QueueLen: 64,
GettySessionParam: GettySessionParam{
CompressEncoding: false,
TcpNoDelay: true,
TcpKeepAlive: true,
KeepAlivePeriod: "180s",
TcpRBufSize: 262144,
TcpWBufSize: 65536,
TcpReadTimeout: "1s",
TcpWriteTimeout: "5s",
WaitTimeout: "1s",
MaxMsgLen: 102400,
SessionName: "server",
},
}
_ = defaultServerConfig.CheckValidity()
return defaultServerConfig
}
// CheckValidity confirm getty session params
func (c *GettySessionParam) CheckValidity() error {
var err error
if c.keepAlivePeriod, err = time.ParseDuration(c.KeepAlivePeriod); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(KeepAlivePeriod{%#v})", c.KeepAlivePeriod)
}
if c.tcpReadTimeout, err = parseTcpTimeoutDuration(c.TcpReadTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(TcpReadTimeout{%#v})", c.TcpReadTimeout)
}
if c.tcpWriteTimeout, err = parseTcpTimeoutDuration(c.TcpWriteTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(TcpWriteTimeout{%#v})", c.TcpWriteTimeout)
}
if c.waitTimeout, err = time.ParseDuration(c.WaitTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(WaitTimeout{%#v})", c.WaitTimeout)
}
return nil
}
func parseTcpTimeoutDuration(timeStr string) (time.Duration, error) {
result, err := time.ParseDuration(timeStr)
if err != nil {
return 0, err
}
if result < TCPReadWriteTimeoutMinValue {
return TCPReadWriteTimeoutMinValue, nil
}
return result, nil
}
// CheckValidity confirm client params.
func (c *ClientConfig) CheckValidity() error {
var err error
c.ReconnectInterval = c.ReconnectInterval * 1e6
if c.heartbeatPeriod, err = time.ParseDuration(c.HeartbeatPeriod); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(HeartbeatPeroid{%#v})", c.HeartbeatPeriod)
}
if c.heartbeatPeriod >= time.Duration(config.MaxWheelTimeSpan) {
return perrors.WithMessagef(err, "heartbeat-period %s should be less than %s",
c.HeartbeatPeriod, time.Duration(config.MaxWheelTimeSpan))
}
if len(c.HeartbeatTimeout) == 0 {
c.heartbeatTimeout = 60 * time.Second
} else if c.heartbeatTimeout, err = time.ParseDuration(c.HeartbeatTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(HeartbeatTimeout{%#v})", c.HeartbeatTimeout)
}
if c.sessionTimeout, err = time.ParseDuration(c.SessionTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(SessionTimeout{%#v})", c.SessionTimeout)
}
return perrors.WithStack(c.GettySessionParam.CheckValidity())
}
// CheckValidity confirm server params
func (c *ServerConfig) CheckValidity() error {
var err error
if len(c.HeartbeatPeriod) == 0 {
c.heartbeatPeriod = 60 * time.Second
} else if c.heartbeatPeriod, err = time.ParseDuration(c.HeartbeatPeriod); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(HeartbeatPeroid{%#v})", c.HeartbeatPeriod)
}
if c.heartbeatPeriod >= time.Duration(config.MaxWheelTimeSpan) {
return perrors.WithMessagef(err, "heartbeat-period %s should be less than %s",
c.HeartbeatPeriod, time.Duration(config.MaxWheelTimeSpan))
}
if len(c.HeartbeatTimeout) == 0 {
c.heartbeatTimeout = 60 * time.Second
} else if c.heartbeatTimeout, err = time.ParseDuration(c.HeartbeatTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(HeartbeatTimeout{%#v})", c.HeartbeatTimeout)
}
if c.sessionTimeout, err = time.ParseDuration(c.SessionTimeout); err != nil {
return perrors.WithMessagef(err, "time.ParseDuration(SessionTimeout{%#v})", c.SessionTimeout)
}
if c.sessionTimeout >= time.Duration(config.MaxWheelTimeSpan) {
return perrors.WithMessagef(err, "session-timeout %s should be less than %s",
c.SessionTimeout, time.Duration(config.MaxWheelTimeSpan))
}
return perrors.WithStack(c.GettySessionParam.CheckValidity())
}

View File

@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"crypto/tls"
"dubbo.apache.org/dubbo-go/v3/protocol/dubbo"
"fmt"
"net"
)
import (
"github.com/eolinker/apinto/dubbo-getty"
gxsync "github.com/dubbogo/gost/sync"
perrors "github.com/pkg/errors"
)
import (
"dubbo.apache.org/dubbo-go/v3/common/logger"
"dubbo.apache.org/dubbo-go/v3/config"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"dubbo.apache.org/dubbo-go/v3/remoting"
)
type ServerOption func(*Server)
// Server define getty server
type Server struct {
conf ServerConfig
addr string
codec remoting.Codec
tcpServer getty.Server
rpcHandler *RpcServerHandler
requestHandler func(*invocation.RPCInvocation) protocol.RPCResult
listener net.Listener
}
func WithAddrServer(addr string) ServerOption {
return func(server *Server) {
server.addr = addr
}
}
func WithListenerServer(listener net.Listener) ServerOption {
return func(server *Server) {
server.listener = listener
}
}
func WithConfigServer(conf ServerConfig) ServerOption {
return func(server *Server) {
server.conf = conf
}
}
const (
CronPeriod = 20e9
)
func init() {
codec := &dubbo.DubboCodec{}
remoting.RegistryCodec("dubbo", codec)
}
// NewServer create a new Server
func NewServer(handlers func(*invocation.RPCInvocation) protocol.RPCResult, serverOption ...ServerOption) *Server {
serverConfig := GetDefaultServerConfig()
s := &Server{
conf: *serverConfig,
codec: remoting.GetCodec("dubbo"),
requestHandler: handlers,
}
for _, f := range serverOption {
f(s)
}
s.rpcHandler = NewRpcServerHandler(s.conf.SessionNumber, s.conf.sessionTimeout, s)
return s
}
func (s *Server) newSession(session getty.Session) error {
var (
ok bool
tcpConn *net.TCPConn
err error
)
conf := s.conf
if conf.GettySessionParam.CompressEncoding {
session.SetCompressType(getty.CompressZip)
}
if _, ok = session.Conn().(*tls.Conn); ok {
session.SetName(conf.GettySessionParam.SessionName)
session.SetMaxMsgLen(conf.GettySessionParam.MaxMsgLen)
session.SetPkgHandler(NewRpcServerPackageHandler(s))
session.SetEventListener(s.rpcHandler)
session.SetReadTimeout(conf.GettySessionParam.tcpReadTimeout)
session.SetWriteTimeout(conf.GettySessionParam.tcpWriteTimeout)
session.SetCronPeriod((int)(conf.heartbeatPeriod.Nanoseconds() / 1e6))
session.SetWaitTime(conf.GettySessionParam.waitTimeout)
logger.Debugf("server accepts new session:%s\n", session.Stat())
return nil
}
if _, ok = session.Conn().(*net.TCPConn); !ok {
panic(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection\n", session.Stat(), session.Conn()))
}
if _, ok = session.Conn().(*tls.Conn); !ok {
if tcpConn, ok = session.Conn().(*net.TCPConn); !ok {
return perrors.New(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection", session.Stat(), session.Conn()))
}
if err = tcpConn.SetNoDelay(conf.GettySessionParam.TcpNoDelay); err != nil {
return err
}
if err = tcpConn.SetKeepAlive(conf.GettySessionParam.TcpKeepAlive); err != nil {
return err
}
if conf.GettySessionParam.TcpKeepAlive {
if err = tcpConn.SetKeepAlivePeriod(conf.GettySessionParam.keepAlivePeriod); err != nil {
return err
}
}
if err = tcpConn.SetReadBuffer(conf.GettySessionParam.TcpRBufSize); err != nil {
return err
}
if err = tcpConn.SetWriteBuffer(conf.GettySessionParam.TcpWBufSize); err != nil {
return err
}
}
conf.GettySessionParam.MaxMsgLen = 128 * 1024
session.SetMaxMsgLen(conf.GettySessionParam.MaxMsgLen)
session.SetPkgHandler(NewRpcServerPackageHandler(s))
session.SetEventListener(s.rpcHandler)
session.SetReadTimeout(conf.GettySessionParam.tcpReadTimeout)
session.SetWriteTimeout(conf.GettySessionParam.tcpWriteTimeout)
session.SetCronPeriod(CronPeriod)
session.SetWaitTime(conf.GettySessionParam.waitTimeout)
logger.Debugf("server accepts new session: %s", session.Stat())
return nil
}
// Start dubbo server.
func (s *Server) Start() {
var (
addr string
tcpServer getty.Server
)
var serverOpts []getty.ServerOption
addr = s.addr
if addr != "" {
serverOpts = append(serverOpts, getty.WithLocalAddress(addr))
}
if s.listener != nil {
serverOpts = append(serverOpts, getty.WithListenerServerCert(s.listener))
}
if s.conf.SSLEnabled {
serverOpts = append(serverOpts, getty.WithServerSslEnabled(s.conf.SSLEnabled),
getty.WithServerTlsConfigBuilder(config.GetServerTlsConfigBuilder()))
}
serverOpts = append(serverOpts, getty.WithServerTaskPool(gxsync.NewTaskPoolSimple(s.conf.GrPoolSize)))
tcpServer = getty.NewTCPServer(serverOpts...)
tcpServer.RunEventLoop(s.newSession)
logger.Debugf("s bind addr{%s} ok!", s.addr)
s.tcpServer = tcpServer
}
// Stop dubbo server
func (s *Server) Stop() {
s.tcpServer.Close()
}

View File

@@ -0,0 +1,257 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"sync"
"sync/atomic"
"time"
)
import (
"github.com/eolinker/apinto/dubbo-getty"
hessian "github.com/apache/dubbo-go-hessian2"
gxtime "github.com/dubbogo/gost/time"
perrors "github.com/pkg/errors"
)
import (
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/common/logger"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"dubbo.apache.org/dubbo-go/v3/remoting"
)
const (
WritePkg_Timeout = 5 * time.Second // TODO: WritePkg_Timeout will entry *.yml
)
var (
errTooManySessions = perrors.New("too many sessions")
errHeartbeatReadTimeout = perrors.New("heartbeat read timeout")
)
type rpcSession struct {
session getty.Session
reqNum int32
}
func (s *rpcSession) AddReqNum(num int32) {
atomic.AddInt32(&s.reqNum, num)
}
func (s *rpcSession) GetReqNum() int32 {
return atomic.LoadInt32(&s.reqNum)
}
type RpcServerHandler struct {
maxSessionNum int
sessionTimeout time.Duration
sessionMap map[getty.Session]*rpcSession
rwlock sync.RWMutex
server *Server
timeoutTimes int
}
func NewRpcServerHandler(maxSessionNum int, sessionTimeout time.Duration, serverP *Server) *RpcServerHandler {
return &RpcServerHandler{
maxSessionNum: maxSessionNum,
sessionTimeout: sessionTimeout,
sessionMap: make(map[getty.Session]*rpcSession),
server: serverP,
}
}
// OnOpen call server session opened, add the session to getty server session list. also onOpen
// will check the max getty server session number
func (h *RpcServerHandler) OnOpen(session getty.Session) error {
var err error
h.rwlock.RLock()
if h.maxSessionNum <= len(h.sessionMap) {
err = errTooManySessions
}
h.rwlock.RUnlock()
if err != nil {
return perrors.WithStack(err)
}
logger.Infof("got session:%s", session.Stat())
h.rwlock.Lock()
h.sessionMap[session] = &rpcSession{session: session}
h.rwlock.Unlock()
return nil
}
// OnError the getty server session has errored, so remove the session from the getty server session list
func (h *RpcServerHandler) OnError(session getty.Session, err error) {
logger.Infof("session{%s} got error{%v}, will be closed.", session.Stat(), err)
h.rwlock.Lock()
delete(h.sessionMap, session)
h.rwlock.Unlock()
}
// OnClose close the session, remove it from the getty server list
func (h *RpcServerHandler) OnClose(session getty.Session) {
logger.Infof("session{%s} is closing......", session.Stat())
h.rwlock.Lock()
delete(h.sessionMap, session)
h.rwlock.Unlock()
}
// OnMessage get request from getty client, update the session reqNum and reply response to client
func (h *RpcServerHandler) OnMessage(session getty.Session, pkg interface{}) {
h.rwlock.Lock()
if _, ok := h.sessionMap[session]; ok {
h.sessionMap[session].reqNum++
}
h.rwlock.Unlock()
decodeResult, drOK := pkg.(*remoting.DecodeResult)
if !drOK || decodeResult == ((*remoting.DecodeResult)(nil)) {
logger.Errorf("illegal package{%#v}", pkg)
return
}
if !decodeResult.IsRequest {
res := decodeResult.Result.(*remoting.Response)
if res.Event {
logger.Debugf("get rpc heartbeat response{%#v}", res)
if res.Error != nil {
logger.Errorf("rpc heartbeat response{error: %#v}", res.Error)
}
res.Handle()
return
}
logger.Errorf("illegal package but not heartbeat. {%#v}", pkg)
return
}
req := decodeResult.Result.(*remoting.Request)
resp := remoting.NewResponse(req.ID, req.Version)
resp.Status = hessian.Response_OK
resp.Event = req.Event
resp.SerialID = req.SerialID
resp.Version = "2.0.2"
// heartbeat
if req.Event {
logger.Debugf("get rpc heartbeat request{%#v}", resp)
reply(session, resp)
return
}
invoc, ok := req.Data.(*invocation.RPCInvocation)
if !ok {
panic("create invocation occur some exception for the type is not suitable one.")
}
attachments := invoc.Attachments()
attachments[constant.LocalAddr] = session.LocalAddr()
attachments[constant.RemoteAddr] = session.RemoteAddr()
result := h.server.requestHandler(invoc)
if !req.TwoWay {
return
}
resp.Result = result
reply(session, resp)
}
// OnCron check the session health periodic. if the session's sessionTimeout has reached, just close the session
func (h *RpcServerHandler) OnCron(session getty.Session) {
var (
flag bool
active time.Time
)
h.rwlock.RLock()
if _, ok := h.sessionMap[session]; ok {
active = session.GetActive()
if h.sessionTimeout.Nanoseconds() < time.Since(active).Nanoseconds() {
flag = true
logger.Warnf("session{%s} timeout{%s}, reqNum{%d}",
session.Stat(), time.Since(active).String(), h.sessionMap[session].reqNum)
}
}
h.rwlock.RUnlock()
if flag {
h.rwlock.Lock()
delete(h.sessionMap, session)
h.rwlock.Unlock()
session.Close()
}
heartbeatCallBack := func(err error) {
if err != nil {
logger.Warnf("failed to send heartbeat, error{%v}", err)
if h.timeoutTimes >= 3 {
h.rwlock.Lock()
delete(h.sessionMap, session)
h.rwlock.Unlock()
session.Close()
return
}
h.timeoutTimes++
return
}
h.timeoutTimes = 0
}
if err := heartbeat(session, h.server.conf.heartbeatTimeout, heartbeatCallBack); err != nil {
logger.Warnf("failed to send heartbeat, error{%v}", err)
}
}
func reply(session getty.Session, resp *remoting.Response) {
if totalLen, sendLen, err := session.WritePkg(resp, WritePkg_Timeout); err != nil {
if sendLen != 0 && totalLen != sendLen {
logger.Warnf("start to close the session at replying because %d of %d bytes data is sent success. err:%+v", sendLen, totalLen, err)
go session.Close()
}
logger.Errorf("WritePkg error: %#v, %#v", perrors.WithStack(err), resp)
}
}
func heartbeat(session getty.Session, timeout time.Duration, callBack func(err error)) error {
req := remoting.NewRequest("2.0.2")
req.TwoWay = true
req.Event = true
resp := remoting.NewPendingResponse(req.ID)
remoting.AddPendingResponse(resp)
totalLen, sendLen, err := session.WritePkg(req, -1)
if sendLen != 0 && totalLen != sendLen {
logger.Warnf("start to close the session at heartbeat because %d of %d bytes data is sent success. err:%+v", sendLen, totalLen, err)
go session.Close()
}
go func() {
var err1 error
select {
case <-gxtime.After(timeout):
err1 = errHeartbeatReadTimeout
case <-resp.Done:
err1 = resp.Err
}
callBack(err1)
}()
return perrors.WithStack(err)
}

View File

@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"reflect"
)
import (
"github.com/eolinker/apinto/dubbo-getty"
perrors "github.com/pkg/errors"
)
import (
"dubbo.apache.org/dubbo-go/v3/common/logger"
"dubbo.apache.org/dubbo-go/v3/protocol/dubbo/impl"
"dubbo.apache.org/dubbo-go/v3/remoting"
)
// RpcServerPackageHandler Read data from client and Write data to client
type RpcServerPackageHandler struct {
server *Server
}
func NewRpcServerPackageHandler(server *Server) *RpcServerPackageHandler {
return &RpcServerPackageHandler{server: server}
}
// Read data from client. if the package size from client is larger than 4096 byte, client will read 4096 byte
// and send to client each time. the Read can assemble it.
func (p *RpcServerPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) {
req, length, err := (p.server.codec).Decode(data)
if err != nil {
err = perrors.WithStack(err)
}
if req == ((*remoting.DecodeResult)(nil)) {
return nil, length, err
}
if req.Result == ((*remoting.Request)(nil)) || req.Result == ((*remoting.Response)(nil)) {
return nil, length, err // as getty rule
}
return req, length, err
}
// Write send the data to client
func (p *RpcServerPackageHandler) Write(ss getty.Session, pkg interface{}) ([]byte, error) {
res, ok := pkg.(*remoting.Response)
maxBufLength := p.server.conf.GettySessionParam.MaxMsgLen + impl.HEADER_LENGTH
if ok {
buf, err := (p.server.codec).EncodeResponse(res)
bufLength := buf.Len()
if bufLength > maxBufLength {
logger.Errorf("Data length %d too large, max payload %d", bufLength-impl.HEADER_LENGTH, p.server.conf.GettySessionParam.MaxMsgLen)
return nil, perrors.Errorf("Data length %d too large, max payload %d", bufLength-impl.HEADER_LENGTH, p.server.conf.GettySessionParam.MaxMsgLen)
}
if err != nil {
logger.Warnf("binary.Write(res{%#v}) = err{%#v}", res, perrors.WithStack(err))
return nil, perrors.WithStack(err)
}
return buf.Bytes(), nil
}
req, ok := pkg.(*remoting.Request)
if ok {
buf, err := (p.server.codec).EncodeRequest(req)
bufLength := buf.Len()
if bufLength > maxBufLength {
logger.Errorf("Data length %d too large, max payload %d", bufLength-impl.HEADER_LENGTH, p.server.conf.GettySessionParam.MaxMsgLen)
return nil, perrors.Errorf("Data length %d too large, max payload %d", bufLength-impl.HEADER_LENGTH, p.server.conf.GettySessionParam.MaxMsgLen)
}
if err != nil {
logger.Warnf("binary.Write(req{%#v}) = err{%#v}", res, perrors.WithStack(err))
return nil, perrors.WithStack(err)
}
return buf.Bytes(), nil
}
logger.Errorf("illegal pkg:%+v\n, it is %+v", pkg, reflect.TypeOf(pkg))
return nil, perrors.New("invalid rpc response")
}

992
dubbo-getty/session.go Normal file
View File

@@ -0,0 +1,992 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"bytes"
"context"
"fmt"
"github.com/eolinker/eosc/log"
"io"
"net"
"runtime"
"sync"
"time"
)
import (
gxbytes "github.com/dubbogo/gost/bytes"
gxcontext "github.com/dubbogo/gost/context"
gxtime "github.com/dubbogo/gost/time"
"github.com/gorilla/websocket"
perrors "github.com/pkg/errors"
uatomic "go.uber.org/atomic"
)
const (
maxReadBufLen = 4 * 1024
netIOTimeout = 1e9 // 1s
period = 60 * 1e9 // 1 minute
pendingDuration = 3e9
// MaxWheelTimeSpan 900s, 15 minute
MaxWheelTimeSpan = 900e9
maxPacketLen = 16 * 1024
defaultSessionName = "session"
defaultTCPSessionName = "tcp-session"
defaultUDPSessionName = "udp-session"
defaultWSSessionName = "ws-session"
defaultWSSSessionName = "wss-session"
outputFormat = "session %s, Read Bytes: %d, Write Bytes: %d, Read Pkgs: %d, Write Pkgs: %d"
)
var defaultTimerWheel *gxtime.TimerWheel
func init() {
gxtime.InitDefaultTimerWheel()
defaultTimerWheel = gxtime.GetDefaultTimerWheel()
}
// Session wrap connection between the server and the client
type Session interface {
Connection
Reset()
Conn() net.Conn
Stat() string
IsClosed() bool
// EndPoint get endpoint type
EndPoint() EndPoint
SetMaxMsgLen(int)
SetName(string)
SetEventListener(EventListener)
SetPkgHandler(ReadWriter)
SetReader(Reader)
SetWriter(Writer)
SetCronPeriod(int)
SetWaitTime(time.Duration)
GetAttribute(interface{}) interface{}
SetAttribute(interface{}, interface{})
RemoveAttribute(interface{})
// WritePkg the Writer will invoke this function. Pls attention that if timeout is less than 0, WritePkg will send @pkg asap.
// for udp session, the first parameter should be UDPContext.
// totalBytesLength: @pkg stream bytes length after encoding @pkg.
// sendBytesLength: stream bytes length that sent out successfully.
// err: maybe it has illegal data, encoding error, or write out system error.
WritePkg(pkg interface{}, timeout time.Duration) (totalBytesLength int, sendBytesLength int, err error)
WriteBytes([]byte) (int, error)
WriteBytesArray(...[]byte) (int, error)
Close()
}
// getty base session
type session struct {
name string
endPoint EndPoint
// net read Write
Connection
listener EventListener
// codec
reader Reader // @reader should be nil when @conn is a gettyWSConn object.
writer Writer
// handle logic
maxMsgLen int32
// heartbeat
period time.Duration
// done
wait time.Duration
once *sync.Once
done chan struct{}
// attribute
attrs *gxcontext.ValuesContext
// goroutines sync
grNum uatomic.Int32
lock sync.RWMutex
packetLock sync.RWMutex
}
func newSession(endPoint EndPoint, conn Connection) *session {
ss := &session{
name: defaultSessionName,
endPoint: endPoint,
Connection: conn,
maxMsgLen: maxReadBufLen,
period: period,
once: &sync.Once{},
done: make(chan struct{}),
wait: pendingDuration,
attrs: gxcontext.NewValuesContext(context.Background()),
}
ss.Connection.setSession(ss)
ss.SetWriteTimeout(netIOTimeout)
ss.SetReadTimeout(netIOTimeout)
return ss
}
func newTCPSession(conn net.Conn, endPoint EndPoint) Session {
c := newGettyTCPConn(conn)
session := newSession(endPoint, c)
session.name = defaultTCPSessionName
return session
}
func newUDPSession(conn *net.UDPConn, endPoint EndPoint) Session {
c := newGettyUDPConn(conn)
session := newSession(endPoint, c)
session.name = defaultUDPSessionName
return session
}
func newWSSession(conn *websocket.Conn, endPoint EndPoint) Session {
c := newGettyWSConn(conn)
session := newSession(endPoint, c)
session.name = defaultWSSessionName
return session
}
func (s *session) Reset() {
*s = session{
name: defaultSessionName,
once: &sync.Once{},
done: make(chan struct{}),
period: period,
wait: pendingDuration,
attrs: gxcontext.NewValuesContext(context.Background()),
}
}
func (s *session) Conn() net.Conn {
if tc, ok := s.Connection.(*gettyTCPConn); ok {
return tc.conn
}
if uc, ok := s.Connection.(*gettyUDPConn); ok {
return uc.conn
}
if wc, ok := s.Connection.(*gettyWSConn); ok {
return wc.conn.UnderlyingConn()
}
return nil
}
func (s *session) EndPoint() EndPoint {
return s.endPoint
}
func (s *session) gettyConn() *gettyConn {
if tc, ok := s.Connection.(*gettyTCPConn); ok {
return &(tc.gettyConn)
}
if uc, ok := s.Connection.(*gettyUDPConn); ok {
return &(uc.gettyConn)
}
if wc, ok := s.Connection.(*gettyWSConn); ok {
return &(wc.gettyConn)
}
return nil
}
// Stat get the connect statistic data
func (s *session) Stat() string {
var conn *gettyConn
if conn = s.gettyConn(); conn == nil {
return ""
}
return fmt.Sprintf(
outputFormat,
s.sessionToken(),
conn.readBytes.Load(),
conn.writeBytes.Load(),
conn.readPkgNum.Load(),
conn.writePkgNum.Load(),
)
}
// IsClosed check whether the session has been closed.
func (s *session) IsClosed() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// SetMaxMsgLen set maximum package length of every package in (EventListener)OnMessage(@pkgs)
func (s *session) SetMaxMsgLen(length int) {
s.lock.Lock()
defer s.lock.Unlock()
s.maxMsgLen = int32(length)
}
// SetName set session name
func (s *session) SetName(name string) {
s.lock.Lock()
defer s.lock.Unlock()
s.name = name
}
// SetEventListener set event listener
func (s *session) SetEventListener(listener EventListener) {
s.lock.Lock()
defer s.lock.Unlock()
s.listener = listener
}
// SetPkgHandler set package handler
func (s *session) SetPkgHandler(handler ReadWriter) {
s.lock.Lock()
defer s.lock.Unlock()
s.reader = handler
s.writer = handler
}
func (s *session) SetReader(reader Reader) {
s.lock.Lock()
defer s.lock.Unlock()
s.reader = reader
}
func (s *session) SetWriter(writer Writer) {
s.lock.Lock()
defer s.lock.Unlock()
s.writer = writer
}
// SetCronPeriod period is in millisecond. Websocket session will send ping frame automatically every peroid.
func (s *session) SetCronPeriod(period int) {
if period < 1 {
panic("@period < 1")
}
s.lock.Lock()
defer s.lock.Unlock()
s.period = time.Duration(period) * time.Millisecond
}
// SetWaitTime set maximum wait time when session got error or got exit signal
func (s *session) SetWaitTime(waitTime time.Duration) {
if waitTime < 1 {
panic("@wait < 1")
}
s.lock.Lock()
defer s.lock.Unlock()
s.wait = waitTime
}
// GetAttribute get attribute of key @session:key
func (s *session) GetAttribute(key interface{}) interface{} {
s.lock.RLock()
if s.attrs == nil {
s.lock.RUnlock()
return nil
}
ret, flag := s.attrs.Get(key)
s.lock.RUnlock()
if !flag {
return nil
}
return ret
}
// SetAttribute set attribute of key @session:key
func (s *session) SetAttribute(key interface{}, value interface{}) {
s.lock.Lock()
if s.attrs != nil {
s.attrs.Set(key, value)
}
s.lock.Unlock()
}
// RemoveAttribute remove attribute of key @session:key
func (s *session) RemoveAttribute(key interface{}) {
s.lock.Lock()
if s.attrs != nil {
s.attrs.Delete(key)
}
s.lock.Unlock()
}
func (s *session) sessionToken() string {
if s.IsClosed() || s.Connection == nil {
return "session-closed"
}
return fmt.Sprintf("{%s:%s:%d:%s<->%s}",
s.name, s.EndPoint().EndPointType(), s.ID(), s.LocalAddr(), s.RemoteAddr())
}
func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, error) {
if pkg == nil {
return 0, 0, fmt.Errorf("@pkg is nil")
}
if s.IsClosed() {
return 0, 0, ErrSessionClosed
}
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
rBuf := make([]byte, size)
rBuf = rBuf[:runtime.Stack(rBuf, false)]
log.Errorf("[session.WritePkg] panic session %s: err=%s\n%s", s.sessionToken(), r, rBuf)
}
}()
pkgBytes, err := s.writer.Write(s, pkg)
if err != nil {
log.Warnf("%s, [session.WritePkg] session.writer.Write(@pkg:%#v) = error:%+v", s.Stat(), pkg, err)
return len(pkgBytes), 0, perrors.WithStack(err)
}
var udpCtxPtr *UDPContext
if udpCtx, ok := pkg.(UDPContext); ok {
udpCtxPtr = &udpCtx
} else if udpCtxP, ok := pkg.(*UDPContext); ok {
udpCtxPtr = udpCtxP
}
if udpCtxPtr != nil {
udpCtxPtr.Pkg = pkgBytes
pkg = *udpCtxPtr
} else {
pkg = pkgBytes
}
s.packetLock.RLock()
defer s.packetLock.RUnlock()
if 0 < timeout {
s.Connection.SetWriteTimeout(timeout)
}
var succssCount int
succssCount, err = s.Connection.send(pkg)
if err != nil {
log.Warnf("%s, [session.WritePkg] @s.Connection.Write(pkg:%#v) = err:%+v", s.Stat(), pkg, err)
return len(pkgBytes), succssCount, perrors.WithStack(err)
}
return len(pkgBytes), succssCount, nil
}
// WriteBytes for codecs
func (s *session) WriteBytes(pkg []byte) (int, error) {
if s.IsClosed() {
return 0, ErrSessionClosed
}
leftPackageSize, totalSize, writeSize := len(pkg), len(pkg), 0
if leftPackageSize > maxPacketLen {
s.packetLock.Lock()
defer s.packetLock.Unlock()
} else {
s.packetLock.RLock()
defer s.packetLock.RUnlock()
}
for leftPackageSize > maxPacketLen {
_, err := s.Connection.send(pkg[writeSize:(writeSize + maxPacketLen)])
if err != nil {
return writeSize, perrors.Wrapf(err, "s.Connection.Write(pkg len:%d)", len(pkg))
}
leftPackageSize -= maxPacketLen
writeSize += maxPacketLen
}
if leftPackageSize == 0 {
return writeSize, nil
}
_, err := s.Connection.send(pkg[writeSize:])
if err != nil {
return writeSize, perrors.Wrapf(err, "s.Connection.Write(pkg len:%d)", len(pkg))
}
return totalSize, nil
}
// WriteBytesArray Write multiple packages at once. so we invoke write sys.call just one time.
func (s *session) WriteBytesArray(pkgs ...[]byte) (int, error) {
if s.IsClosed() {
return 0, ErrSessionClosed
}
if len(pkgs) == 1 {
return s.WriteBytes(pkgs[0])
}
// reduce syscall and memcopy for multiple packages
if _, ok := s.Connection.(*gettyTCPConn); ok {
s.packetLock.RLock()
defer s.packetLock.RUnlock()
lg, err := s.Connection.send(pkgs)
if err != nil {
return 0, perrors.Wrapf(err, "s.Connection.Write(pkgs num:%d)", len(pkgs))
}
return lg, nil
}
// get len
var (
l int
wlg int
err error
length int
arrp *[]byte
arr []byte
)
length = 0
for i := 0; i < len(pkgs); i++ {
length += len(pkgs[i])
}
// merge the pkgs
arrp = gxbytes.AcquireBytes(length)
defer gxbytes.ReleaseBytes(arrp)
arr = *arrp
l = 0
for i := 0; i < len(pkgs); i++ {
copy(arr[l:], pkgs[i])
l += len(pkgs[i])
}
wlg, err = s.WriteBytes(arr)
if err != nil {
return 0, perrors.WithStack(err)
}
num := len(pkgs) - 1
for i := 0; i < num; i++ {
s.incWritePkgNum()
}
return wlg, nil
}
func heartbeat(_ gxtime.TimerID, _ time.Time, arg interface{}) error {
ss, _ := arg.(*session)
if ss == nil || ss.IsClosed() {
return ErrSessionClosed
}
f := func() {
wsConn, wsFlag := ss.Connection.(*gettyWSConn)
if wsFlag {
err := wsConn.writePing()
if err != nil {
log.Warnf("wsConn.writePing() = error:%+v", perrors.WithStack(err))
}
}
ss.listener.OnCron(ss)
}
// if enable task pool, run @f asynchronously.
if taskPool := ss.EndPoint().GetTaskPool(); taskPool != nil {
taskPool.AddTaskAlways(f)
return nil
}
f()
return nil
}
// func (s *session) RunEventLoop() {
func (s *session) run() {
if s.Connection == nil || s.listener == nil || s.writer == nil {
errStr := fmt.Sprintf("session{name:%s, conn:%#v, listener:%#v, writer:%#v}",
s.name, s.Connection, s.listener, s.writer)
log.Error(errStr)
panic(errStr)
}
// call session opened
s.UpdateActive()
if err := s.listener.OnOpen(s); err != nil {
log.Errorf("[OnOpen] session %s, error: %#v", s.Stat(), err)
s.Close()
return
}
if _, err := defaultTimerWheel.AddTimer(heartbeat, gxtime.TimerLoop, s.period, s); err != nil {
panic(fmt.Sprintf("failed to add session %s to defaultTimerWheel err:%v", s.Stat(), err))
}
s.grNum.Add(1)
// start read gr
go s.handlePackage()
}
func (s *session) addTask(pkg interface{}) {
f := func() {
s.listener.OnMessage(s, pkg)
s.incReadPkgNum()
}
if taskPool := s.EndPoint().GetTaskPool(); taskPool != nil {
taskPool.AddTaskAlways(f)
return
}
f()
}
func (s *session) handlePackage() {
var err error
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
rBuf := make([]byte, size)
rBuf = rBuf[:runtime.Stack(rBuf, false)]
log.Errorf("[session.handlePackage] panic session %s: err=%s\n%s", s.sessionToken(), r, rBuf)
}
grNum := s.grNum.Add(-1)
log.Infof("%s, [session.handlePackage] gr will exit now, left gr num %d", s.sessionToken(), grNum)
s.stop()
if err != nil {
log.Errorf("%s, [session.handlePackage] error:%+v", s.sessionToken(), perrors.WithStack(err))
if s != nil || s.listener != nil {
s.listener.OnError(s, err)
}
}
s.listener.OnClose(s)
s.gc()
}()
if _, ok := s.Connection.(*gettyTCPConn); ok {
if s.reader == nil {
errStr := fmt.Sprintf("session{name:%s, conn:%#v, reader:%#v}", s.name, s.Connection, s.reader)
log.Error(errStr)
panic(errStr)
}
err = s.handleTCPPackage()
} else if _, ok := s.Connection.(*gettyWSConn); ok {
err = s.handleWSPackage()
} else if _, ok := s.Connection.(*gettyUDPConn); ok {
err = s.handleUDPPackage()
} else {
panic(fmt.Sprintf("unknown type session{%#v}", s))
}
}
// get package from tcp stream(packet)
func (s *session) handleTCPPackage() error {
var (
ok bool
err error
netError net.Error
conn *gettyTCPConn
exit bool
bufLen int
pkgLen int
buf []byte
pktBuf *gxbytes.Buffer
pkg interface{}
)
pktBuf = gxbytes.NewBuffer(nil)
conn = s.Connection.(*gettyTCPConn)
for {
if s.IsClosed() {
err = nil
// do not handle the left stream in pktBuf and exit asap.
// it is impossible packing a package by the left stream.
break
}
bufLen = 0
for {
// for clause for the network timeout condition check
// s.conn.SetReadTimeout(time.Now().Add(s.rTimeout))
buf = pktBuf.WriteNextBegin(maxReadBufLen)
bufLen, err = conn.recv(buf)
if err != nil {
if netError, ok = perrors.Cause(err).(net.Error); ok && netError.Timeout() {
break
}
if perrors.Cause(err) == io.EOF {
log.Infof("%s, session.conn read EOF, client send over, session exit", s.sessionToken())
err = nil
exit = true
if bufLen != 0 {
// as https://github.com/apache/dubbo-getty/issues/77#issuecomment-939652203
// this branch is impossible. Even if it happens, the bufLen will be zero and the error
// is io.EOF when getty continues to read the socket.
exit = false
log.Infof("%s, session.conn read EOF, while the bufLen(%d) is non-zero.", s.sessionToken())
}
break
}
log.Errorf("%s, [session.conn.read] = error:%+v", s.sessionToken(), perrors.WithStack(err))
exit = true
}
break
}
if 0 != bufLen {
pktBuf.WriteNextEnd(bufLen)
for {
if pktBuf.Len() <= 0 {
break
}
pkg, pkgLen, err = s.reader.Read(s, pktBuf.Bytes())
// for case 3/case 4
if err == nil && s.maxMsgLen > 0 && pkgLen > int(s.maxMsgLen) {
err = perrors.Errorf("pkgLen %d > session max message len %d", pkgLen, s.maxMsgLen)
}
// handle case 1
if err != nil {
log.Warnf("%s, [session.handleTCPPackage] = len{%d}, error:%+v",
s.sessionToken(), pkgLen, perrors.WithStack(err))
exit = true
break
}
// handle case 2/case 3
if pkg == nil {
break
}
// handle case 4
s.UpdateActive()
s.addTask(pkg)
pktBuf.Next(pkgLen)
// continue to handle case 5
}
}
if exit {
break
}
}
return perrors.WithStack(err)
}
// get package from udp packet
func (s *session) handleUDPPackage() error {
var (
ok bool
err error
netError net.Error
conn *gettyUDPConn
bufLen int
maxBufLen int
bufp *[]byte
buf []byte
addr *net.UDPAddr
pkgLen int
pkg interface{}
)
conn = s.Connection.(*gettyUDPConn)
maxBufLen = int(s.maxMsgLen + maxReadBufLen)
if int(s.maxMsgLen<<1) < bufLen {
maxBufLen = int(s.maxMsgLen << 1)
}
bufp = gxbytes.AcquireBytes(maxBufLen)
defer gxbytes.ReleaseBytes(bufp)
buf = *bufp
for {
if s.IsClosed() {
break
}
bufLen, addr, err = conn.recv(buf)
log.Debug("conn.read() = bufLen:%d, addr:%#v, err:%+v", bufLen, addr, perrors.WithStack(err))
if netError, ok = perrors.Cause(err).(net.Error); ok && netError.Timeout() {
continue
}
if err != nil {
log.Errorf("%s, [session.handleUDPPackage] = len:%d, error:%+v",
s.sessionToken(), bufLen, perrors.WithStack(err))
err = perrors.Wrapf(err, "conn.read()")
break
}
if bufLen == 0 {
log.Errorf("conn.read() = bufLen:%d, addr:%s, err:%+v", bufLen, addr, perrors.WithStack(err))
continue
}
if bufLen == len(connectPingPackage) && bytes.Equal(connectPingPackage, buf[:bufLen]) {
log.Infof("got %s connectPingPackage", addr)
continue
}
pkg, pkgLen, err = s.reader.Read(s, buf[:bufLen])
log.Debug("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, perrors.WithStack(err))
if err == nil && s.maxMsgLen > 0 && bufLen > int(s.maxMsgLen) {
err = perrors.Errorf("Message Too Long, bufLen %d, session max message len %d", bufLen, s.maxMsgLen)
}
if err != nil {
log.Warnf("%s, [session.handleUDPPackage] = len:%d, error:%+v",
s.sessionToken(), pkgLen, perrors.WithStack(err))
continue
}
if pkgLen == 0 {
log.Errorf("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, perrors.WithStack(err))
continue
}
s.UpdateActive()
s.addTask(UDPContext{Pkg: pkg, PeerAddr: addr})
}
return perrors.WithStack(err)
}
// get package from websocket stream
func (s *session) handleWSPackage() error {
var (
ok bool
err error
netError net.Error
length int
conn *gettyWSConn
pkg []byte
unmarshalPkg interface{}
)
conn = s.Connection.(*gettyWSConn)
for {
if s.IsClosed() {
break
}
pkg, err = conn.recv()
if netError, ok = perrors.Cause(err).(net.Error); ok && netError.Timeout() {
continue
}
if err != nil {
log.Warnf("%s, [session.handleWSPackage] = error:%+v",
s.sessionToken(), perrors.WithStack(err))
return perrors.WithStack(err)
}
s.UpdateActive()
if s.reader != nil {
unmarshalPkg, length, err = s.reader.Read(s, pkg)
if err == nil && s.maxMsgLen > 0 && length > int(s.maxMsgLen) {
err = perrors.Errorf("Message Too Long, length %d, session max message len %d", length, s.maxMsgLen)
}
if err != nil {
log.Warnf("%s, [session.handleWSPackage] = len:%d, error:%+v",
s.sessionToken(), length, perrors.WithStack(err))
continue
}
s.addTask(unmarshalPkg)
} else {
s.addTask(pkg)
}
}
return nil
}
func (s *session) stop() {
select {
case <-s.done: // s.done is a blocked channel. if it has not been closed, the default branch will be invoked.
return
default:
s.once.Do(func() {
// let read/Write timeout asap
now := time.Now()
if conn := s.Conn(); conn != nil {
conn.SetReadDeadline(now.Add(s.readTimeout()))
conn.SetWriteDeadline(now.Add(s.writeTimeout()))
}
close(s.done)
c := s.GetAttribute(sessionClientKey)
if clt, ok := c.(*client); ok {
clt.reConnect()
}
})
}
}
func (s *session) gc() {
var conn Connection
s.lock.Lock()
if s.attrs != nil {
s.attrs = nil
conn = s.Connection
s.Connection = nil
}
s.lock.Unlock()
go func() {
if conn != nil {
conn.close(int(s.wait))
}
}()
}
// Close will be invoked by NewSessionCallback(if return error is not nil)
// or (session)handleLoop automatically. It's thread safe.
func (s *session) Close() {
s.stop()
log.Infof("%s closed now. its current gr num is %d", s.sessionToken(), s.grNum.Load())
}
// GetActive return connection's time
func (s *session) GetActive() time.Time {
if s == nil {
return launchTime
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.GetActive()
}
return launchTime
}
// UpdateActive update connection's active time
func (s *session) UpdateActive() {
if s == nil {
return
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
s.Connection.UpdateActive()
}
}
func (s *session) ID() uint32 {
if s == nil {
return 0
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.ID()
}
return 0
}
func (s *session) LocalAddr() string {
if s == nil {
return ""
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.LocalAddr()
}
return ""
}
func (s *session) RemoteAddr() string {
if s == nil {
return ""
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.RemoteAddr()
}
return ""
}
func (s *session) incReadPkgNum() {
if s == nil {
return
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
s.Connection.incReadPkgNum()
}
}
func (s *session) incWritePkgNum() {
if s == nil {
return
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
s.Connection.incWritePkgNum()
}
}
func (s *session) send(pkg interface{}) (int, error) {
if s == nil {
return 0, nil
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.send(pkg)
}
return 0, nil
}
func (s *session) readTimeout() time.Duration {
if s == nil {
return time.Duration(0)
}
s.lock.RLock()
defer s.lock.RUnlock()
if s.Connection != nil {
return s.Connection.readTimeout()
}
return time.Duration(0)
}
func (s *session) setSession(ss Session) {
if s == nil {
return
}
s.lock.RLock()
if s.Connection != nil {
s.Connection.setSession(ss)
}
s.lock.RUnlock()
}

114
dubbo-getty/tls.go Normal file
View File

@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/eolinker/eosc/log"
"io/ioutil"
)
import (
perrors "github.com/pkg/errors"
)
// TlsConfigBuilder tls config builder interface
type TlsConfigBuilder interface {
BuildTlsConfig() (*tls.Config, error)
}
// ServerTlsConfigBuilder impl TlsConfigBuilder for server
type ServerTlsConfigBuilder struct {
ServerKeyCertChainPath string
ServerPrivateKeyPath string
ServerKeyPassword string
ServerTrustCertCollectionPath string
}
// BuildTlsConfig impl TlsConfigBuilder method
func (s *ServerTlsConfigBuilder) BuildTlsConfig() (*tls.Config, error) {
var (
err error
certPem []byte
certificate tls.Certificate
certPool *x509.CertPool
config *tls.Config
)
if certificate, err = tls.LoadX509KeyPair(s.ServerKeyCertChainPath, s.ServerPrivateKeyPath); err != nil {
log.Error(fmt.Sprintf("tls.LoadX509KeyPair(certs{%s}, privateKey{%s}) = err:%+v",
s.ServerKeyCertChainPath, s.ServerPrivateKeyPath, perrors.WithStack(err)))
return nil, err
}
config = &tls.Config{
InsecureSkipVerify: true, // do not verify peer certs
ClientAuth: tls.RequireAnyClientCert,
Certificates: []tls.Certificate{certificate},
}
if s.ServerTrustCertCollectionPath != "" {
certPem, err = ioutil.ReadFile(s.ServerTrustCertCollectionPath)
if err != nil {
log.Error(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.ServerTrustCertCollectionPath, perrors.WithStack(err)))
return nil, err
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
log.Error("failed to parse root certificate file")
return nil, err
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAnyClientCert
config.InsecureSkipVerify = false
}
return config, nil
}
// ClientTlsConfigBuilder impl TlsConfigBuilder for client
type ClientTlsConfigBuilder struct {
ClientKeyCertChainPath string
ClientPrivateKeyPath string
ClientKeyPassword string
ClientTrustCertCollectionPath string
}
// BuildTlsConfig impl TlsConfigBuilder method
func (c *ClientTlsConfigBuilder) BuildTlsConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.ClientTrustCertCollectionPath, c.ClientPrivateKeyPath)
if err != nil {
log.Error(fmt.Sprintf("Unable to load X509 Key Pair %v", err))
return nil, err
}
certBytes, err := ioutil.ReadFile(c.ClientTrustCertCollectionPath)
if err != nil {
log.Error(fmt.Sprintf("Unable to read pem file: %s", c.ClientTrustCertCollectionPath))
return nil, err
}
clientCertPool := x509.NewCertPool()
ok := clientCertPool.AppendCertsFromPEM(certBytes)
if !ok {
log.Error("failed to parse root certificate")
return nil, err
}
return &tls.Config{
RootCAs: clientCertPool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}, nil
}