去除重连逻辑

This commit is contained in:
lwch
2022-07-06 14:55:01 +08:00
parent 77e581cdd6
commit 9554ccd331
2 changed files with 48 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
package app package app
import ( import (
"os"
rt "runtime" rt "runtime"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -105,9 +106,14 @@ func (a *App) run() {
}() }()
if a.cfg.DashboardEnabled { if a.cfg.DashboardEnabled {
go func() {
a.conn.Wait()
logging.Flush()
os.Exit(1)
}()
db := dashboard.New(a.cfg, a.conn, mgr, a.version) db := dashboard.New(a.cfg, a.conn, mgr, a.version)
runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort)) runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort))
} else { } else {
select {} a.conn.Wait()
} }
} }

View File

@@ -1,6 +1,7 @@
package conn package conn
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"strings" "strings"
@@ -24,6 +25,9 @@ type Conn struct {
write chan *network.Msg write chan *network.Msg
lockDrop sync.RWMutex lockDrop sync.RWMutex
drop map[string]time.Time drop map[string]time.Time
// runtime
ctx context.Context
cancel context.CancelFunc
} }
// New new connection // New new connection
@@ -35,9 +39,8 @@ func New(cfg *global.Configure) *Conn {
write: make(chan *network.Msg, 1024), write: make(chan *network.Msg, 1024),
drop: make(map[string]time.Time), drop: make(map[string]time.Time),
} }
var err error runtime.Assert(conn.connect())
conn.conn, err = conn.tryConnect() conn.ctx, conn.cancel = context.WithCancel(context.Background())
runtime.Assert(err)
go conn.loopRead() go conn.loopRead()
go conn.loopWrite() go conn.loopWrite()
go conn.keepalive() go conn.keepalive()
@@ -45,7 +48,7 @@ func New(cfg *global.Configure) *Conn {
return conn return conn
} }
func (conn *Conn) connect() (*network.Conn, error) { func (conn *Conn) connect() error {
var dial net.Conn var dial net.Conn
var err error var err error
if conn.cfg.UseSSL { if conn.cfg.UseSSL {
@@ -55,30 +58,23 @@ func (conn *Conn) connect() (*network.Conn, error) {
} }
if err != nil { if err != nil {
logging.Error("dial: %v", err) logging.Error("dial: %v", err)
return nil, err return err
} }
cn := network.NewConn(dial) cn := network.NewConn(dial)
err = writeHandshake(cn, conn.cfg) err = writeHandshake(cn, conn.cfg)
if err != nil { if err != nil {
logging.Error("write handshake: %v", err) logging.Error("write handshake: %v", err)
return nil, err return err
} }
logging.Info("%s connected", conn.cfg.Server) logging.Info("%s connected", conn.cfg.Server)
return cn, nil conn.conn = cn
return nil
} }
func (conn *Conn) tryConnect() (*network.Conn, error) { func (conn *Conn) close() {
var ret *network.Conn if conn.conn != nil {
var err error conn.conn.Close()
for i := 0; i < 10; i++ {
ret, err = conn.connect()
if err == nil {
return ret, nil
}
logging.Error("connect error on %d times: %v", i+1, err)
time.Sleep(time.Second)
} }
return nil, err
} }
func writeHandshake(conn *network.Conn, cfg *global.Configure) error { func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
@@ -96,6 +92,8 @@ func writeHandshake(conn *network.Conn, cfg *global.Configure) error {
func (conn *Conn) loopRead() { func (conn *Conn) loopRead() {
defer utils.Recover("loopRead") defer utils.Recover("loopRead")
defer conn.close()
defer conn.cancel()
var timeout int var timeout int
for { for {
msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout) msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout)
@@ -104,16 +102,11 @@ func (conn *Conn) loopRead() {
timeout++ timeout++
if timeout >= 60 { if timeout >= 60 {
logging.Error("too many timeout times") logging.Error("too many timeout times")
conn.conn, err = conn.tryConnect() return
runtime.Assert(err)
timeout = 0
continue
} }
continue continue
} }
logging.Error("read message: %v", err) logging.Error("read message: %v", err)
conn.conn, err = conn.tryConnect()
runtime.Assert(err)
continue continue
} }
timeout = 0 timeout = 0
@@ -142,21 +135,28 @@ func (conn *Conn) loopRead() {
conn.lockDrop.Lock() conn.lockDrop.Lock()
conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute) conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute)
conn.lockDrop.Unlock() conn.lockDrop.Unlock()
case <-conn.ctx.Done():
return
} }
} }
} }
func (conn *Conn) loopWrite() { func (conn *Conn) loopWrite() {
defer utils.Recover("loopWrite") defer utils.Recover("loopWrite")
defer conn.close()
defer conn.cancel()
for { for {
msg := <-conn.write var msg *network.Msg
select {
case msg = <-conn.write:
case <-conn.ctx.Done():
return
}
msg.From = conn.cfg.ID msg.From = conn.cfg.ID
err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout) err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout)
if err != nil { if err != nil {
logging.Error("write message error on %s: %v", logging.Error("write message error on %s: %v",
conn.cfg.ID, err) conn.cfg.ID, err)
conn.conn, err = conn.connect()
runtime.Assert(err)
continue continue
} }
} }
@@ -164,9 +164,16 @@ func (conn *Conn) loopWrite() {
func (conn *Conn) keepalive() { func (conn *Conn) keepalive() {
defer utils.Recover("keepalive") defer utils.Recover("keepalive")
defer conn.close()
defer conn.cancel()
tk := time.NewTicker(10 * time.Second)
for { for {
time.Sleep(10 * time.Second) select {
conn.SendKeepalive() case <-tk.C:
conn.SendKeepalive()
case <-conn.ctx.Done():
return
}
} }
} }
@@ -220,3 +227,8 @@ func (conn *Conn) checkDrop() {
conn.lockDrop.Unlock() conn.lockDrop.Unlock()
} }
} }
// Wait wait for connection closed
func (conn *Conn) Wait() {
<-conn.ctx.Done()
}