mirror of
https://github.com/lwch/natpass
synced 2025-10-04 21:12:41 +08:00
去除重连逻辑
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"os"
|
||||
rt "runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -105,9 +106,14 @@ func (a *App) run() {
|
||||
}()
|
||||
|
||||
if a.cfg.DashboardEnabled {
|
||||
go func() {
|
||||
a.conn.Wait()
|
||||
logging.Flush()
|
||||
os.Exit(1)
|
||||
}()
|
||||
db := dashboard.New(a.cfg, a.conn, mgr, a.version)
|
||||
runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort))
|
||||
} else {
|
||||
select {}
|
||||
a.conn.Wait()
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -24,6 +25,9 @@ type Conn struct {
|
||||
write chan *network.Msg
|
||||
lockDrop sync.RWMutex
|
||||
drop map[string]time.Time
|
||||
// runtime
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New new connection
|
||||
@@ -35,9 +39,8 @@ func New(cfg *global.Configure) *Conn {
|
||||
write: make(chan *network.Msg, 1024),
|
||||
drop: make(map[string]time.Time),
|
||||
}
|
||||
var err error
|
||||
conn.conn, err = conn.tryConnect()
|
||||
runtime.Assert(err)
|
||||
runtime.Assert(conn.connect())
|
||||
conn.ctx, conn.cancel = context.WithCancel(context.Background())
|
||||
go conn.loopRead()
|
||||
go conn.loopWrite()
|
||||
go conn.keepalive()
|
||||
@@ -45,7 +48,7 @@ func New(cfg *global.Configure) *Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
func (conn *Conn) connect() (*network.Conn, error) {
|
||||
func (conn *Conn) connect() error {
|
||||
var dial net.Conn
|
||||
var err error
|
||||
if conn.cfg.UseSSL {
|
||||
@@ -55,30 +58,23 @@ func (conn *Conn) connect() (*network.Conn, error) {
|
||||
}
|
||||
if err != nil {
|
||||
logging.Error("dial: %v", err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
cn := network.NewConn(dial)
|
||||
err = writeHandshake(cn, conn.cfg)
|
||||
if err != nil {
|
||||
logging.Error("write handshake: %v", err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
logging.Info("%s connected", conn.cfg.Server)
|
||||
return cn, nil
|
||||
conn.conn = cn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) tryConnect() (*network.Conn, error) {
|
||||
var ret *network.Conn
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
ret, err = conn.connect()
|
||||
if err == nil {
|
||||
return ret, nil
|
||||
}
|
||||
logging.Error("connect error on %d times: %v", i+1, err)
|
||||
time.Sleep(time.Second)
|
||||
func (conn *Conn) close() {
|
||||
if conn.conn != nil {
|
||||
conn.conn.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
||||
@@ -96,6 +92,8 @@ func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
||||
|
||||
func (conn *Conn) loopRead() {
|
||||
defer utils.Recover("loopRead")
|
||||
defer conn.close()
|
||||
defer conn.cancel()
|
||||
var timeout int
|
||||
for {
|
||||
msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout)
|
||||
@@ -104,16 +102,11 @@ func (conn *Conn) loopRead() {
|
||||
timeout++
|
||||
if timeout >= 60 {
|
||||
logging.Error("too many timeout times")
|
||||
conn.conn, err = conn.tryConnect()
|
||||
runtime.Assert(err)
|
||||
timeout = 0
|
||||
continue
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
logging.Error("read message: %v", err)
|
||||
conn.conn, err = conn.tryConnect()
|
||||
runtime.Assert(err)
|
||||
continue
|
||||
}
|
||||
timeout = 0
|
||||
@@ -142,21 +135,28 @@ func (conn *Conn) loopRead() {
|
||||
conn.lockDrop.Lock()
|
||||
conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute)
|
||||
conn.lockDrop.Unlock()
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) loopWrite() {
|
||||
defer utils.Recover("loopWrite")
|
||||
defer conn.close()
|
||||
defer conn.cancel()
|
||||
for {
|
||||
msg := <-conn.write
|
||||
var msg *network.Msg
|
||||
select {
|
||||
case msg = <-conn.write:
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
msg.From = conn.cfg.ID
|
||||
err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout)
|
||||
if err != nil {
|
||||
logging.Error("write message error on %s: %v",
|
||||
conn.cfg.ID, err)
|
||||
conn.conn, err = conn.connect()
|
||||
runtime.Assert(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -164,9 +164,16 @@ func (conn *Conn) loopWrite() {
|
||||
|
||||
func (conn *Conn) keepalive() {
|
||||
defer utils.Recover("keepalive")
|
||||
defer conn.close()
|
||||
defer conn.cancel()
|
||||
tk := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
time.Sleep(10 * time.Second)
|
||||
conn.SendKeepalive()
|
||||
select {
|
||||
case <-tk.C:
|
||||
conn.SendKeepalive()
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,3 +227,8 @@ func (conn *Conn) checkDrop() {
|
||||
conn.lockDrop.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Wait wait for connection closed
|
||||
func (conn *Conn) Wait() {
|
||||
<-conn.ctx.Done()
|
||||
}
|
||||
|
Reference in New Issue
Block a user