mirror of
https://github.com/lwch/natpass
synced 2025-11-01 07:42:32 +08:00
实现链接断开时的disconnect消息发送逻辑
This commit is contained in:
@@ -3,15 +3,18 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"natpass/code/network"
|
"natpass/code/network"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lwch/logging"
|
"github.com/lwch/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
|
sync.RWMutex
|
||||||
parent *Handler
|
parent *Handler
|
||||||
id string
|
id string
|
||||||
c *network.Conn
|
c *network.Conn
|
||||||
|
tunnels map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(parent *Handler, id string, conn *network.Conn) *client {
|
func newClient(parent *Handler, id string, conn *network.Conn) *client {
|
||||||
@@ -19,6 +22,7 @@ func newClient(parent *Handler, id string, conn *network.Conn) *client {
|
|||||||
parent: parent,
|
parent: parent,
|
||||||
id: id,
|
id: id,
|
||||||
c: conn,
|
c: conn,
|
||||||
|
tunnels: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,3 +43,35 @@ func (c *client) run() {
|
|||||||
func (c *client) writeMessage(msg *network.Msg) error {
|
func (c *client) writeMessage(msg *network.Msg) error {
|
||||||
return c.c.WriteMessage(msg, time.Second)
|
return c.c.WriteMessage(msg, time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *client) addTunnel(id string) {
|
||||||
|
c.Lock()
|
||||||
|
c.tunnels[id] = struct{}{}
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) getTunnels() []string {
|
||||||
|
ret := make([]string, 0, len(c.tunnels))
|
||||||
|
c.RLock()
|
||||||
|
for tn := range c.tunnels {
|
||||||
|
ret = append(ret, tn)
|
||||||
|
}
|
||||||
|
c.RUnlock()
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) close(id string) {
|
||||||
|
var msg network.Msg
|
||||||
|
msg.From = "server"
|
||||||
|
msg.To = c.id
|
||||||
|
msg.XType = network.Msg_disconnect
|
||||||
|
msg.Payload = &network.Msg_XDisconnect{
|
||||||
|
XDisconnect: &network.Disconnect{
|
||||||
|
Id: id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.c.WriteMessage(&msg, time.Second)
|
||||||
|
c.Lock()
|
||||||
|
delete(c.tunnels, id)
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,13 +14,15 @@ import (
|
|||||||
type Handler struct {
|
type Handler struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
cfg *global.Configure
|
cfg *global.Configure
|
||||||
clients map[string]*client
|
clients map[string]*client // client id => client
|
||||||
|
tunnels map[string][2]*client // tunnel id => endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg *global.Configure) *Handler {
|
func New(cfg *global.Configure) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
clients: make(map[string]*client),
|
clients: make(map[string]*client),
|
||||||
|
tunnels: make(map[string][2]*client),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,6 +58,8 @@ func (h *Handler) Handle(conn net.Conn) {
|
|||||||
h.clients[cli.id] = cli
|
h.clients[cli.id] = cli
|
||||||
h.Unlock()
|
h.Unlock()
|
||||||
|
|
||||||
|
defer h.closeAll(cli)
|
||||||
|
|
||||||
cli.run()
|
cli.run()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,5 +87,56 @@ func (h *Handler) onMessage(msg *network.Msg) {
|
|||||||
logging.Error("client %s not found", to)
|
logging.Error("client %s not found", to)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.msgFilter(msg)
|
||||||
cli.writeMessage(msg)
|
cli.writeMessage(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) msgFilter(msg *network.Msg) {
|
||||||
|
from := msg.GetFrom()
|
||||||
|
to := msg.GetTo()
|
||||||
|
switch msg.GetXType() {
|
||||||
|
case network.Msg_connect_rep:
|
||||||
|
if msg.GetCrep().GetOk() {
|
||||||
|
h.RLock()
|
||||||
|
fromCli := h.clients[from]
|
||||||
|
toCli := h.clients[to]
|
||||||
|
h.RUnlock()
|
||||||
|
id := msg.GetCrep().GetId()
|
||||||
|
var pair [2]*client
|
||||||
|
if fromCli != nil {
|
||||||
|
fromCli.addTunnel(id)
|
||||||
|
pair[0] = fromCli
|
||||||
|
}
|
||||||
|
if toCli != nil {
|
||||||
|
toCli.addTunnel(id)
|
||||||
|
pair[1] = toCli
|
||||||
|
}
|
||||||
|
h.Lock()
|
||||||
|
h.tunnels[id] = pair
|
||||||
|
h.Unlock()
|
||||||
|
}
|
||||||
|
case network.Msg_disconnect:
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) closeAll(cli *client) {
|
||||||
|
tunnels := cli.getTunnels()
|
||||||
|
for _, t := range tunnels {
|
||||||
|
h.RLock()
|
||||||
|
pair := h.tunnels[t]
|
||||||
|
h.RUnlock()
|
||||||
|
if pair[0] != nil {
|
||||||
|
pair[0].close(t)
|
||||||
|
}
|
||||||
|
if pair[1] != nil {
|
||||||
|
pair[1].close(t)
|
||||||
|
}
|
||||||
|
h.Lock()
|
||||||
|
delete(h.tunnels, t)
|
||||||
|
h.Unlock()
|
||||||
|
}
|
||||||
|
h.Lock()
|
||||||
|
delete(h.clients, cli.id)
|
||||||
|
h.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user