mirror of
https://github.com/lwch/natpass
synced 2025-10-09 23:10:07 +08:00
1. 增加代码注释
2. hook了disconnect消息
This commit is contained in:
@@ -18,6 +18,7 @@ type Handler struct {
|
||||
tunnels map[string][2]*client // tunnel id => endpoints
|
||||
}
|
||||
|
||||
// New create handler
|
||||
func New(cfg *global.Configure) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
@@ -26,6 +27,7 @@ func New(cfg *global.Configure) *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle main loop
|
||||
func (h *Handler) Handle(conn net.Conn) {
|
||||
c := network.NewConn(conn)
|
||||
var id string
|
||||
@@ -63,6 +65,7 @@ func (h *Handler) Handle(conn net.Conn) {
|
||||
cli.run()
|
||||
}
|
||||
|
||||
// readHandshake read handshake message and compare secret encoded from md5
|
||||
func (h *Handler) readHandshake(c *network.Conn) (string, error) {
|
||||
msg, err := c.ReadMessage(5 * time.Second)
|
||||
if err != nil {
|
||||
@@ -78,6 +81,7 @@ func (h *Handler) readHandshake(c *network.Conn) (string, error) {
|
||||
return msg.GetFrom(), nil
|
||||
}
|
||||
|
||||
// onMessage forward message
|
||||
func (h *Handler) onMessage(msg *network.Msg) {
|
||||
to := msg.GetTo()
|
||||
h.RLock()
|
||||
@@ -87,20 +91,21 @@ func (h *Handler) onMessage(msg *network.Msg) {
|
||||
logging.Error("client %s not found", to)
|
||||
return
|
||||
}
|
||||
h.msgFilter(msg)
|
||||
h.msgHook(msg)
|
||||
cli.writeMessage(msg)
|
||||
}
|
||||
|
||||
func (h *Handler) msgFilter(msg *network.Msg) {
|
||||
// msgHook hook from on message
|
||||
func (h *Handler) msgHook(msg *network.Msg) {
|
||||
from := msg.GetFrom()
|
||||
to := msg.GetTo()
|
||||
h.RLock()
|
||||
fromCli := h.clients[from]
|
||||
toCli := h.clients[to]
|
||||
h.RUnlock()
|
||||
switch msg.GetXType() {
|
||||
case network.Msg_connect_rep:
|
||||
if msg.GetCrep().GetOk() {
|
||||
h.RLock()
|
||||
fromCli := h.clients[from]
|
||||
toCli := h.clients[to]
|
||||
h.RUnlock()
|
||||
id := msg.GetCrep().GetId()
|
||||
var pair [2]*client
|
||||
if fromCli != nil {
|
||||
@@ -116,10 +121,16 @@ func (h *Handler) msgFilter(msg *network.Msg) {
|
||||
h.Unlock()
|
||||
}
|
||||
case network.Msg_disconnect:
|
||||
|
||||
if fromCli != nil {
|
||||
fromCli.removeTunnel(msg.GetXDisconnect().GetId())
|
||||
}
|
||||
if toCli != nil {
|
||||
toCli.removeTunnel(msg.GetXDisconnect().GetId())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeAll close all tunnel from client
|
||||
func (h *Handler) closeAll(cli *client) {
|
||||
tunnels := cli.getTunnels()
|
||||
for _, t := range tunnels {
|
||||
|
Reference in New Issue
Block a user