diff --git a/code/client/app/app.go b/code/client/app/app.go index 84f04f7..b940dcc 100644 --- a/code/client/app/app.go +++ b/code/client/app/app.go @@ -1,6 +1,7 @@ package app import ( + "os" rt "runtime" "github.com/kardianos/service" @@ -105,9 +106,14 @@ func (a *App) run() { }() if a.cfg.DashboardEnabled { + go func() { + a.conn.Wait() + logging.Flush() + os.Exit(1) + }() db := dashboard.New(a.cfg, a.conn, mgr, a.version) runtime.Assert(db.ListenAndServe(a.cfg.DashboardListen, a.cfg.DashboardPort)) } else { - select {} + a.conn.Wait() } } diff --git a/code/client/conn/conn.go b/code/client/conn/conn.go index 8c31b25..b493b6e 100644 --- a/code/client/conn/conn.go +++ b/code/client/conn/conn.go @@ -1,6 +1,7 @@ package conn import ( + "context" "crypto/tls" "net" "strings" @@ -24,6 +25,9 @@ type Conn struct { write chan *network.Msg lockDrop sync.RWMutex drop map[string]time.Time + // runtime + ctx context.Context + cancel context.CancelFunc } // New new connection @@ -35,9 +39,8 @@ func New(cfg *global.Configure) *Conn { write: make(chan *network.Msg, 1024), drop: make(map[string]time.Time), } - var err error - conn.conn, err = conn.tryConnect() - runtime.Assert(err) + runtime.Assert(conn.connect()) + conn.ctx, conn.cancel = context.WithCancel(context.Background()) go conn.loopRead() go conn.loopWrite() go conn.keepalive() @@ -45,7 +48,7 @@ func New(cfg *global.Configure) *Conn { return conn } -func (conn *Conn) connect() (*network.Conn, error) { +func (conn *Conn) connect() error { var dial net.Conn var err error if conn.cfg.UseSSL { @@ -55,30 +58,23 @@ func (conn *Conn) connect() (*network.Conn, error) { } if err != nil { logging.Error("dial: %v", err) - return nil, err + return err } cn := network.NewConn(dial) err = writeHandshake(cn, conn.cfg) if err != nil { logging.Error("write handshake: %v", err) - return nil, err + return err } logging.Info("%s connected", conn.cfg.Server) - return cn, nil + conn.conn = cn + return nil } -func (conn *Conn) tryConnect() (*network.Conn, error) { - var ret *network.Conn - var err error - for i := 0; i < 10; i++ { - ret, err = conn.connect() - if err == nil { - return ret, nil - } - logging.Error("connect error on %d times: %v", i+1, err) - time.Sleep(time.Second) +func (conn *Conn) close() { + if conn.conn != nil { + conn.conn.Close() } - return nil, err } func writeHandshake(conn *network.Conn, cfg *global.Configure) error { @@ -96,6 +92,8 @@ func writeHandshake(conn *network.Conn, cfg *global.Configure) error { func (conn *Conn) loopRead() { defer utils.Recover("loopRead") + defer conn.close() + defer conn.cancel() var timeout int for { msg, _, err := conn.conn.ReadMessage(conn.cfg.ReadTimeout) @@ -104,16 +102,11 @@ func (conn *Conn) loopRead() { timeout++ if timeout >= 60 { logging.Error("too many timeout times") - conn.conn, err = conn.tryConnect() - runtime.Assert(err) - timeout = 0 - continue + return } continue } logging.Error("read message: %v", err) - conn.conn, err = conn.tryConnect() - runtime.Assert(err) continue } timeout = 0 @@ -142,21 +135,28 @@ func (conn *Conn) loopRead() { conn.lockDrop.Lock() conn.drop[msg.GetLinkId()] = time.Now().Add(time.Minute) conn.lockDrop.Unlock() + case <-conn.ctx.Done(): + return } } } func (conn *Conn) loopWrite() { defer utils.Recover("loopWrite") + defer conn.close() + defer conn.cancel() for { - msg := <-conn.write + var msg *network.Msg + select { + case msg = <-conn.write: + case <-conn.ctx.Done(): + return + } msg.From = conn.cfg.ID err := conn.conn.WriteMessage(msg, conn.cfg.WriteTimeout) if err != nil { logging.Error("write message error on %s: %v", conn.cfg.ID, err) - conn.conn, err = conn.connect() - runtime.Assert(err) continue } } @@ -164,9 +164,16 @@ func (conn *Conn) loopWrite() { func (conn *Conn) keepalive() { defer utils.Recover("keepalive") + defer conn.close() + defer conn.cancel() + tk := time.NewTicker(10 * time.Second) for { - time.Sleep(10 * time.Second) - conn.SendKeepalive() + select { + case <-tk.C: + conn.SendKeepalive() + case <-conn.ctx.Done(): + return + } } } @@ -220,3 +227,8 @@ func (conn *Conn) checkDrop() { conn.lockDrop.Unlock() } } + +// Wait wait for connection closed +func (conn *Conn) Wait() { + <-conn.ctx.Done() +}