mirror of
https://github.com/lwch/natpass
synced 2025-10-05 05:16:50 +08:00
去除重连逻辑
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
rt "runtime"
|
rt "runtime"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -105,9 +106,14 @@ func (a *App) run() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if a.cfg.DashboardEnabled {
|
if a.cfg.DashboardEnabled {
|
||||||
|
go func() {
|
||||||
|
a.conn.Wait()
|
||||||
|
logging.Flush()
|
||||||
|
os.Exit(1)
|
||||||
|
}()
|
||||||
db := dashboard.New(a.cfg, a.conn, mgr, a.version)
|
db := dashboard.New(a.cfg, a.conn, mgr, a.version)
|
||||||
runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort))
|
runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort))
|
||||||
} else {
|
} else {
|
||||||
select {}
|
a.conn.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package conn
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,6 +25,9 @@ type Conn struct {
|
|||||||
write chan *network.Msg
|
write chan *network.Msg
|
||||||
lockDrop sync.RWMutex
|
lockDrop sync.RWMutex
|
||||||
drop map[string]time.Time
|
drop map[string]time.Time
|
||||||
|
// runtime
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// New new connection
|
// New new connection
|
||||||
@@ -35,9 +39,8 @@ func New(cfg *global.Configure) *Conn {
|
|||||||
write: make(chan *network.Msg, 1024),
|
write: make(chan *network.Msg, 1024),
|
||||||
drop: make(map[string]time.Time),
|
drop: make(map[string]time.Time),
|
||||||
}
|
}
|
||||||
var err error
|
runtime.Assert(conn.connect())
|
||||||
conn.conn, err = conn.tryConnect()
|
conn.ctx, conn.cancel = context.WithCancel(context.Background())
|
||||||
runtime.Assert(err)
|
|
||||||
go conn.loopRead()
|
go conn.loopRead()
|
||||||
go conn.loopWrite()
|
go conn.loopWrite()
|
||||||
go conn.keepalive()
|
go conn.keepalive()
|
||||||
@@ -45,7 +48,7 @@ func New(cfg *global.Configure) *Conn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) connect() (*network.Conn, error) {
|
func (conn *Conn) connect() error {
|
||||||
var dial net.Conn
|
var dial net.Conn
|
||||||
var err error
|
var err error
|
||||||
if conn.cfg.UseSSL {
|
if conn.cfg.UseSSL {
|
||||||
@@ -55,30 +58,23 @@ func (conn *Conn) connect() (*network.Conn, error) {
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error("dial: %v", err)
|
logging.Error("dial: %v", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
cn := network.NewConn(dial)
|
cn := network.NewConn(dial)
|
||||||
err = writeHandshake(cn, conn.cfg)
|
err = writeHandshake(cn, conn.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error("write handshake: %v", err)
|
logging.Error("write handshake: %v", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
logging.Info("%s connected", conn.cfg.Server)
|
logging.Info("%s connected", conn.cfg.Server)
|
||||||
return cn, nil
|
conn.conn = cn
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) tryConnect() (*network.Conn, error) {
|
func (conn *Conn) close() {
|
||||||
var ret *network.Conn
|
if conn.conn != nil {
|
||||||
var err error
|
conn.conn.Close()
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
ret, err = conn.connect()
|
|
||||||
if err == nil {
|
|
||||||
return ret, nil
|
|
||||||
}
|
}
|
||||||
logging.Error("connect error on %d times: %v", i+1, err)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
||||||
@@ -96,6 +92,8 @@ func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
|
|||||||
|
|
||||||
func (conn *Conn) loopRead() {
|
func (conn *Conn) loopRead() {
|
||||||
defer utils.Recover("loopRead")
|
defer utils.Recover("loopRead")
|
||||||
|
defer conn.close()
|
||||||
|
defer conn.cancel()
|
||||||
var timeout int
|
var timeout int
|
||||||
for {
|
for {
|
||||||
msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout)
|
msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout)
|
||||||
@@ -104,16 +102,11 @@ func (conn *Conn) loopRead() {
|
|||||||
timeout++
|
timeout++
|
||||||
if timeout >= 60 {
|
if timeout >= 60 {
|
||||||
logging.Error("too many timeout times")
|
logging.Error("too many timeout times")
|
||||||
conn.conn, err = conn.tryConnect()
|
return
|
||||||
runtime.Assert(err)
|
|
||||||
timeout = 0
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logging.Error("read message: %v", err)
|
logging.Error("read message: %v", err)
|
||||||
conn.conn, err = conn.tryConnect()
|
|
||||||
runtime.Assert(err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
timeout = 0
|
timeout = 0
|
||||||
@@ -142,21 +135,28 @@ func (conn *Conn) loopRead() {
|
|||||||
conn.lockDrop.Lock()
|
conn.lockDrop.Lock()
|
||||||
conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute)
|
conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute)
|
||||||
conn.lockDrop.Unlock()
|
conn.lockDrop.Unlock()
|
||||||
|
case <-conn.ctx.Done():
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) loopWrite() {
|
func (conn *Conn) loopWrite() {
|
||||||
defer utils.Recover("loopWrite")
|
defer utils.Recover("loopWrite")
|
||||||
|
defer conn.close()
|
||||||
|
defer conn.cancel()
|
||||||
for {
|
for {
|
||||||
msg := <-conn.write
|
var msg *network.Msg
|
||||||
|
select {
|
||||||
|
case msg = <-conn.write:
|
||||||
|
case <-conn.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
msg.From = conn.cfg.ID
|
msg.From = conn.cfg.ID
|
||||||
err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout)
|
err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error("write message error on %s: %v",
|
logging.Error("write message error on %s: %v",
|
||||||
conn.cfg.ID, err)
|
conn.cfg.ID, err)
|
||||||
conn.conn, err = conn.connect()
|
|
||||||
runtime.Assert(err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -164,9 +164,16 @@ func (conn *Conn) loopWrite() {
|
|||||||
|
|
||||||
func (conn *Conn) keepalive() {
|
func (conn *Conn) keepalive() {
|
||||||
defer utils.Recover("keepalive")
|
defer utils.Recover("keepalive")
|
||||||
|
defer conn.close()
|
||||||
|
defer conn.cancel()
|
||||||
|
tk := time.NewTicker(10 * time.Second)
|
||||||
for {
|
for {
|
||||||
time.Sleep(10 * time.Second)
|
select {
|
||||||
|
case <-tk.C:
|
||||||
conn.SendKeepalive()
|
conn.SendKeepalive()
|
||||||
|
case <-conn.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,3 +227,8 @@ func (conn *Conn) checkDrop() {
|
|||||||
conn.lockDrop.Unlock()
|
conn.lockDrop.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait wait for connection closed
|
||||||
|
func (conn *Conn) Wait() {
|
||||||
|
<-conn.ctx.Done()
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user