修改数据包格式

This commit is contained in:
lwch
2021-08-26 14:23:30 +08:00
parent b5649e1a33
commit ecfbc0d14d
10 changed files with 225 additions and 149 deletions

View File

@@ -5,7 +5,6 @@ import (
"natpass/code/network"
"natpass/code/server/global"
"net"
"strings"
"sync"
"time"
@@ -16,7 +15,7 @@ import (
type Handler struct {
cfg *global.Configure
lockClients sync.RWMutex
clients map[string]*client // client id => client
clients map[string]*clients // client id => client
lockLinks sync.RWMutex
links map[string][2]*client // link id => endpoints
idx int
@@ -26,7 +25,7 @@ type Handler struct {
func New(cfg *global.Configure) *Handler {
return &Handler{
cfg: cfg,
clients: make(map[string]*client),
clients: make(map[string]*clients),
links: make(map[string][2]*client),
idx: 0,
}
@@ -36,6 +35,7 @@ func New(cfg *global.Configure) *Handler {
func (h *Handler) Handle(conn net.Conn) {
c := network.NewConn(conn)
var id string
var idx uint32
defer func() {
if len(id) > 0 {
logging.Info("%s disconnected", id)
@@ -44,7 +44,7 @@ func (h *Handler) Handle(conn net.Conn) {
}()
var err error
for i := 0; i < 10; i++ {
id, err = h.readHandshake(c)
id, idx, err = h.readHandshake(c)
if err != nil {
if err == errInvalidHandshake {
logging.Error("invalid handshake from %s", c.RemoteAddr().String())
@@ -60,75 +60,67 @@ func (h *Handler) Handle(conn net.Conn) {
}
logging.Info("%s connected", id)
// split id and index
trimID := id
n := strings.LastIndex(id, "-")
if n != -1 {
trimID = id[:n]
}
clients := h.tryGetClients(id)
cli := clients.new(idx, c)
cli := newClient(h, id, trimID, c)
h.lockClients.Lock()
h.clients[cli.id] = cli
h.lockClients.Unlock()
defer h.closeAll(cli)
defer h.closeClient(cli)
cli.run()
}
func (h *Handler) connsByTrimID(id string) []*client {
ret := make([]*client, 0, 10)
h.lockClients.RLock()
for _, cli := range h.clients {
if cli.trimID == id {
ret = append(ret, cli)
}
func (h *Handler) tryGetClients(id string) *clients {
h.lockClients.Lock()
defer h.lockClients.Unlock()
clients := h.clients[id]
if clients != nil {
return clients
}
h.lockClients.RUnlock()
return ret
clients = newClients(h, id)
h.clients[id] = clients
return clients
}
func (h *Handler) getClient(linkID, targetID string) *client {
// readHandshake read handshake message and compare secret encoded from md5
func (h *Handler) readHandshake(c *network.Conn) (string, uint32, error) {
msg, err := c.ReadMessage(5 * time.Second)
if err != nil {
return "", 0, err
}
if msg.GetXType() != network.Msg_handshake {
return "", 0, errNotHandshake
}
n := bytes.Compare(msg.GetHsp().GetEnc(), h.cfg.Enc[:])
if n != 0 {
return "", 0, errInvalidHandshake
}
return msg.GetFrom(), msg.GetFromIdx(), nil
}
func (h *Handler) getClient(linkID, to string, toIdx uint32) *client {
h.lockLinks.RLock()
pair := h.links[linkID]
h.lockLinks.RUnlock()
if pair[0] != nil && pair[0].trimID == targetID {
if pair[0] != nil && pair[0].idx == toIdx {
return pair[0]
}
if pair[1] != nil && pair[1].trimID == targetID {
if pair[1] != nil && pair[1].idx == toIdx {
return pair[1]
}
conns := h.connsByTrimID(targetID)
if len(conns) == 0 {
h.lockClients.RLock()
clients := h.clients[to]
h.lockClients.RUnlock()
if clients == nil {
return nil
}
conn := conns[h.idx%len(conns)]
h.idx++
return conn
return clients.next()
}
// readHandshake read handshake message and compare secret encoded from md5
func (h *Handler) readHandshake(c *network.Conn) (string, error) {
msg, err := c.ReadMessage(5 * time.Second)
if err != nil {
return "", err
}
if msg.GetXType() != network.Msg_handshake {
return "", errNotHandshake
}
n := bytes.Compare(msg.GetHsp().GetEnc(), h.cfg.Enc[:])
if n != 0 {
return "", errInvalidHandshake
}
return msg.GetFrom(), nil
}
// onMessage forward message
func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg) {
to := msg.GetTo()
toIdx := msg.GetToIdx()
var linkID string
switch msg.GetXType() {
case network.Msg_connect_req:
@@ -142,9 +134,9 @@ func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg)
default:
return
}
cli := h.getClient(linkID, to)
cli := h.getClient(linkID, to, toIdx)
if cli == nil {
logging.Error("client %s not found", to)
logging.Error("client %s-%d not found", to, toIdx)
return
}
h.msgHook(msg, from, cli)
@@ -180,11 +172,13 @@ func (h *Handler) msgHook(msg *network.Msg, from, to *client) {
delete(h.links, id)
h.lockLinks.Unlock()
}
msg.From = from.trimID
msg.From = from.parent.id
msg.FromIdx = from.idx
msg.To = to.parent.id
msg.ToIdx = to.idx
}
// closeAll close all links from client
func (h *Handler) closeAll(cli *client) {
func (h *Handler) closeClient(cli *client) {
links := cli.getLinks()
for _, t := range links {
h.lockLinks.RLock()
@@ -200,7 +194,13 @@ func (h *Handler) closeAll(cli *client) {
delete(h.links, t)
h.lockLinks.Unlock()
}
h.lockClients.Lock()
delete(h.clients, cli.id)
h.lockClients.Unlock()
h.lockClients.RLock()
clients := h.clients[cli.parent.id]
h.lockClients.RUnlock()
clients.close(cli.idx)
if len(clients.data) == 0 {
h.lockClients.Lock()
delete(h.clients, clients.id)
h.lockClients.Unlock()
}
}