mirror of
https://github.com/lwch/natpass
synced 2025-10-05 05:16:50 +08:00
235 lines
4.8 KiB
Go
235 lines
4.8 KiB
Go
package conn
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lwch/logging"
|
|
"github.com/lwch/natpass/code/client/global"
|
|
"github.com/lwch/natpass/code/network"
|
|
"github.com/lwch/natpass/code/utils"
|
|
"github.com/lwch/runtime"
|
|
)
|
|
|
|
// Conn connection
|
|
type Conn struct {
|
|
sync.RWMutex
|
|
cfg *global.Configure
|
|
conn *network.Conn
|
|
read map[string]chan *network.Msg // link id => channel
|
|
unknownRead chan *network.Msg // read message without link
|
|
write chan *network.Msg
|
|
lockDrop sync.RWMutex
|
|
drop map[string]time.Time
|
|
// runtime
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// New new connection
|
|
func New(cfg *global.Configure) *Conn {
|
|
conn := &Conn{
|
|
cfg: cfg,
|
|
read: make(map[string]chan *network.Msg),
|
|
unknownRead: make(chan *network.Msg, 1024),
|
|
write: make(chan *network.Msg, 1024),
|
|
drop: make(map[string]time.Time),
|
|
}
|
|
runtime.Assert(conn.connect())
|
|
conn.ctx, conn.cancel = context.WithCancel(context.Background())
|
|
go conn.loopRead()
|
|
go conn.loopWrite()
|
|
go conn.keepalive()
|
|
go conn.checkDrop()
|
|
return conn
|
|
}
|
|
|
|
func (conn *Conn) connect() error {
|
|
var dial net.Conn
|
|
var err error
|
|
if conn.cfg.UseSSL {
|
|
dial, err = tls.Dial("tcp", conn.cfg.Server, nil)
|
|
} else {
|
|
dial, err = net.Dial("tcp", conn.cfg.Server)
|
|
}
|
|
if err != nil {
|
|
logging.Error("dial: %v", err)
|
|
return err
|
|
}
|
|
cn := network.NewConn(dial)
|
|
err = writeHandshake(cn, conn.cfg)
|
|
if err != nil {
|
|
logging.Error("write handshake: %v", err)
|
|
return err
|
|
}
|
|
logging.Info("%s connected", conn.cfg.Server)
|
|
conn.conn = cn
|
|
return nil
|
|
}
|
|
|
|
func (conn *Conn) close() {
|
|
if conn.conn != nil {
|
|
conn.conn.Close()
|
|
}
|
|
}
|
|
|
|
func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
|
var msg network.Msg
|
|
msg.XType = network.Msg_handshake
|
|
msg.From = cfg.ID
|
|
msg.To = "server"
|
|
msg.Payload = &network.Msg_Hsp{
|
|
Hsp: &network.HandshakePayload{
|
|
Enc: cfg.Enc[:],
|
|
},
|
|
}
|
|
return conn.WriteMessage(&msg, 5*time.Second)
|
|
}
|
|
|
|
func (conn *Conn) loopRead() {
|
|
defer utils.Recover("loopRead")
|
|
defer conn.close()
|
|
defer conn.cancel()
|
|
var timeout int
|
|
for {
|
|
msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "i/o timeout") {
|
|
timeout++
|
|
if timeout >= 60 {
|
|
logging.Error("too many timeout times")
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
logging.Error("read message: %v", err)
|
|
continue
|
|
}
|
|
timeout = 0
|
|
if msg.GetXType() == network.Msg_keepalive {
|
|
continue
|
|
}
|
|
logging.Debug("read message %s(%s) from %s",
|
|
msg.GetXType().String(), msg.GetLinkId(), msg.GetFrom())
|
|
linkID := msg.GetLinkId()
|
|
conn.lockDrop.RLock()
|
|
_, drop := conn.drop[linkID]
|
|
conn.lockDrop.RUnlock()
|
|
if drop {
|
|
continue
|
|
}
|
|
conn.RLock()
|
|
ch := conn.read[linkID]
|
|
conn.RUnlock()
|
|
if ch == nil {
|
|
ch = conn.unknownRead
|
|
}
|
|
select {
|
|
case ch <- msg:
|
|
case <-time.After(conn.cfg.ReadTimeout):
|
|
logging.Error("drop message: %s", msg.GetXType().String())
|
|
conn.lockDrop.Lock()
|
|
conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute)
|
|
conn.lockDrop.Unlock()
|
|
case <-conn.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (conn *Conn) loopWrite() {
|
|
defer utils.Recover("loopWrite")
|
|
defer conn.close()
|
|
defer conn.cancel()
|
|
for {
|
|
var msg *network.Msg
|
|
select {
|
|
case msg = <-conn.write:
|
|
case <-conn.ctx.Done():
|
|
return
|
|
}
|
|
msg.From = conn.cfg.ID
|
|
err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout)
|
|
if err != nil {
|
|
logging.Error("write message error on %s: %v",
|
|
conn.cfg.ID, err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func (conn *Conn) keepalive() {
|
|
defer utils.Recover("keepalive")
|
|
defer conn.close()
|
|
defer conn.cancel()
|
|
tk := time.NewTicker(10 * time.Second)
|
|
for {
|
|
select {
|
|
case <-tk.C:
|
|
conn.SendKeepalive()
|
|
case <-conn.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// AddLink attach read message
|
|
func (conn *Conn) AddLink(id string) {
|
|
logging.Info("add link %s", id)
|
|
conn.Lock()
|
|
if _, ok := conn.read[id]; !ok {
|
|
conn.read[id] = make(chan *network.Msg, 10)
|
|
}
|
|
conn.Unlock()
|
|
}
|
|
|
|
// Reset reset message next read
|
|
func (conn *Conn) Reset(id string, msg *network.Msg) {
|
|
conn.RLock()
|
|
ch := conn.read[id]
|
|
conn.RUnlock()
|
|
ch <- msg
|
|
}
|
|
|
|
// ChanRead get read channel from link id
|
|
func (conn *Conn) ChanRead(id string) <-chan *network.Msg {
|
|
conn.RLock()
|
|
defer conn.RUnlock()
|
|
return conn.read[id]
|
|
}
|
|
|
|
// ChanUnknown get channel of unknown link id
|
|
func (conn *Conn) ChanUnknown() <-chan *network.Msg {
|
|
return conn.unknownRead
|
|
}
|
|
|
|
func (conn *Conn) checkDrop() {
|
|
for {
|
|
time.Sleep(time.Second)
|
|
|
|
drops := make([]string, 0, len(conn.drop))
|
|
conn.lockDrop.RLock()
|
|
for k, t := range conn.drop {
|
|
if time.Now().After(t) {
|
|
drops = append(drops, k)
|
|
}
|
|
}
|
|
conn.lockDrop.RUnlock()
|
|
|
|
conn.lockDrop.Lock()
|
|
for _, id := range drops {
|
|
delete(conn.drop, id)
|
|
}
|
|
conn.lockDrop.Unlock()
|
|
}
|
|
}
|
|
|
|
// Wait wait for connection closed
|
|
func (conn *Conn) Wait() {
|
|
<-conn.ctx.Done()
|
|
}
|