修改数据包格式

This commit is contained in:
lwch
2021-08-26 14:23:30 +08:00
parent b5649e1a33
commit ecfbc0d14d
10 changed files with 225 additions and 149 deletions

View File

@@ -89,7 +89,8 @@ func main() {
var linkID string var linkID string
switch msg.GetXType() { switch msg.GetXType() {
case network.Msg_connect_req: case network.Msg_connect_req:
connect(pl, conn, msg.GetFrom(), msg.GetTo(), msg.GetCreq()) connect(pl, conn, msg.GetFrom(), msg.GetTo(),
msg.GetFromIdx(), msg.GetToIdx(), msg.GetCreq())
case network.Msg_connect_rep: case network.Msg_connect_rep:
linkID = msg.GetCrep().GetId() linkID = msg.GetCrep().GetId()
case network.Msg_disconnect: case network.Msg_disconnect:
@@ -102,7 +103,7 @@ func main() {
continue continue
} }
} }
logging.Info("connection %s exited", conn.ID) logging.Info("connection %s-%d exited", cfg.ID, conn.Idx)
time.Sleep(time.Second) time.Sleep(time.Second)
} }
}() }()
@@ -111,14 +112,14 @@ func main() {
select {} select {}
} }
func connect(pool *pool.Pool, conn *pool.Conn, from, to string, req *network.ConnectRequest) { func connect(pool *pool.Pool, conn *pool.Conn, from, to string, fromIdx, toIdx uint32, req *network.ConnectRequest) {
dial := "tcp" dial := "tcp"
if req.GetXType() == network.ConnectRequest_udp { if req.GetXType() == network.ConnectRequest_udp {
dial = "udp" dial = "udp"
} }
link, err := net.Dial(dial, fmt.Sprintf("%s:%d", req.GetAddr(), req.GetPort())) link, err := net.Dial(dial, fmt.Sprintf("%s:%d", req.GetAddr(), req.GetPort()))
if err != nil { if err != nil {
conn.SendConnectError(from, req.GetId(), err.Error()) conn.SendConnectError(from, fromIdx, req.GetId(), err.Error())
return return
} }
host, pt, _ := net.SplitHostPort(link.LocalAddr().String()) host, pt, _ := net.SplitHostPort(link.LocalAddr().String())
@@ -133,7 +134,7 @@ func connect(pool *pool.Pool, conn *pool.Conn, from, to string, req *network.Con
RemotePort: uint16(req.GetPort()), RemotePort: uint16(req.GetPort()),
}) })
lk := tunnel.NewLink(tn, req.GetId(), from, link, conn) lk := tunnel.NewLink(tn, req.GetId(), from, link, conn)
conn.SendConnectOK(from, req.GetId()) conn.SendConnectOK(from, fromIdx, req.GetId())
lk.Forward() lk.Forward()
lk.OnWork <- struct{}{} lk.OnWork <- struct{}{}
} }

View File

@@ -12,7 +12,7 @@ import (
type Conn struct { type Conn struct {
sync.RWMutex sync.RWMutex
ID string Idx uint32
parent *Pool parent *Pool
conn *network.Conn conn *network.Conn
read map[string]chan *network.Msg // link id => channel read map[string]chan *network.Msg // link id => channel
@@ -20,16 +20,16 @@ type Conn struct {
write chan *network.Msg // link id => channel write chan *network.Msg // link id => channel
} }
func newConn(parent *Pool, conn *network.Conn, id string) *Conn { func newConn(parent *Pool, conn *network.Conn, idx uint32) *Conn {
ret := &Conn{ ret := &Conn{
ID: id, Idx: idx,
parent: parent, parent: parent,
conn: conn, conn: conn,
read: make(map[string]chan *network.Msg), read: make(map[string]chan *network.Msg),
unknownRead: make(chan *network.Msg), unknownRead: make(chan *network.Msg),
write: make(chan *network.Msg), write: make(chan *network.Msg),
} }
logging.Info("new connection: %s", ret.ID) logging.Info("new connection: %s-%d", ret.parent.cfg.ID, ret.Idx)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go ret.loopRead(cancel) go ret.loopRead(cancel)
go ret.loopWrite(cancel) go ret.loopWrite(cancel)
@@ -81,8 +81,8 @@ func (conn *Conn) Close() {
close(conn.write) close(conn.write)
conn.write = nil conn.write = nil
} }
conn.parent.onClose(conn.ID) conn.parent.onClose(conn.Idx)
logging.Error("connection %s closed", conn.ID) logging.Error("connection %s-%d closed", conn.parent.cfg.ID, conn.Idx)
} }
func (conn *Conn) loopRead(cancel context.CancelFunc) { func (conn *Conn) loopRead(cancel context.CancelFunc) {
@@ -138,10 +138,12 @@ func (conn *Conn) loopWrite(cancel context.CancelFunc) {
if msg == nil { if msg == nil {
return return
} }
msg.From = conn.ID msg.From = conn.parent.cfg.ID
msg.FromIdx = conn.Idx
err := loopWrite(conn.conn, msg, conn.parent.cfg.WriteTimeout) err := loopWrite(conn.conn, msg, conn.parent.cfg.WriteTimeout)
if err != nil { if err != nil {
logging.Error("write message error on %s: %v", conn.ID, err) logging.Error("write message error on %s-%d: %v",
conn.parent.cfg.ID, conn.Idx, err)
return return
} }
} }

View File

@@ -2,7 +2,6 @@ package pool
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"natpass/code/client/global" "natpass/code/client/global"
"natpass/code/network" "natpass/code/network"
"sync" "sync"
@@ -16,16 +15,16 @@ import (
type Pool struct { type Pool struct {
sync.RWMutex sync.RWMutex
cfg *global.Configure cfg *global.Configure
conns map[string]*Conn conns map[uint32]*Conn
count int count int
idx int idx uint32
} }
// New create connection pool // New create connection pool
func New(cfg *global.Configure) *Pool { func New(cfg *global.Configure) *Pool {
return &Pool{ return &Pool{
cfg: cfg, cfg: cfg,
conns: make(map[string]*Conn, cfg.Links), conns: make(map[uint32]*Conn, cfg.Links),
count: cfg.Links, count: cfg.Links,
idx: 0, idx: 0,
} }
@@ -53,27 +52,26 @@ func (p *Pool) Get(id ...string) *Conn {
} }
if len(conns) >= p.count { if len(conns) >= p.count {
p.Lock() p.Lock()
conn := conns[p.idx%len(conns)] conn := conns[int(p.idx)%len(conns)]
p.idx++ p.idx++
p.Unlock() p.Unlock()
return conn return conn
} }
cid := fmt.Sprintf("%s-%d", p.cfg.ID, time.Now().UnixNano()) conn := p.connect(p.idx)
conn := p.connect(cid)
if conn == nil { if conn == nil {
return nil return nil
} }
c := newConn(p, conn, cid) c := newConn(p, conn, p.idx)
p.Lock() p.Lock()
p.conns[c.ID] = c p.conns[c.Idx] = c
p.idx++ p.idx++
p.Unlock() p.Unlock()
return c return c
} }
func (p *Pool) connect(id string) *network.Conn { func (p *Pool) connect(idx uint32) *network.Conn {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
logging.Error("connect error: %v", err) logging.Error("connect error: %v", err)
@@ -82,16 +80,17 @@ func (p *Pool) connect(id string) *network.Conn {
conn, err := tls.Dial("tcp", p.cfg.Server, nil) conn, err := tls.Dial("tcp", p.cfg.Server, nil)
runtime.Assert(err) runtime.Assert(err)
c := network.NewConn(conn) c := network.NewConn(conn)
err = p.writeHandshake(c, p.cfg, id) err = p.writeHandshake(c, p.cfg, idx)
runtime.Assert(err) runtime.Assert(err)
logging.Info("%s connected", p.cfg.Server) logging.Info("%s connected", p.cfg.Server)
return c return c
} }
func (p *Pool) writeHandshake(conn *network.Conn, cfg *global.Configure, id string) error { func (p *Pool) writeHandshake(conn *network.Conn, cfg *global.Configure, idx uint32) error {
var msg network.Msg var msg network.Msg
msg.XType = network.Msg_handshake msg.XType = network.Msg_handshake
msg.From = id msg.From = p.cfg.ID
msg.FromIdx = idx
msg.To = "server" msg.To = "server"
msg.Payload = &network.Msg_Hsp{ msg.Payload = &network.Msg_Hsp{
Hsp: &network.HandshakePayload{ Hsp: &network.HandshakePayload{
@@ -101,9 +100,9 @@ func (p *Pool) writeHandshake(conn *network.Conn, cfg *global.Configure, id stri
return conn.WriteMessage(&msg, 5*time.Second) return conn.WriteMessage(&msg, 5*time.Second)
} }
func (p *Pool) onClose(id string) { func (p *Pool) onClose(idx uint32) {
p.Lock() p.Lock()
delete(p.conns, id) delete(p.conns, idx)
p.Unlock() p.Unlock()
} }

View File

@@ -31,9 +31,10 @@ func (conn *Conn) SendConnectReq(id string, cfg global.Tunnel) {
} }
// SendConnectError send connect error response message // SendConnectError send connect error response message
func (conn *Conn) SendConnectError(to, id, info string) { func (conn *Conn) SendConnectError(to string, toIdx uint32, id, info string) {
var msg network.Msg var msg network.Msg
msg.To = to msg.To = to
msg.ToIdx = toIdx
msg.XType = network.Msg_connect_rep msg.XType = network.Msg_connect_rep
msg.Payload = &network.Msg_Crep{ msg.Payload = &network.Msg_Crep{
Crep: &network.ConnectResponse{ Crep: &network.ConnectResponse{
@@ -49,9 +50,10 @@ func (conn *Conn) SendConnectError(to, id, info string) {
} }
// SendConnectOK send connect success response message // SendConnectOK send connect success response message
func (conn *Conn) SendConnectOK(to, id string) { func (conn *Conn) SendConnectOK(to string, toIdx uint32, id string) {
var msg network.Msg var msg network.Msg
msg.To = to msg.To = to
msg.ToIdx = toIdx
msg.XType = network.Msg_connect_rep msg.XType = network.Msg_connect_rep
msg.Payload = &network.Msg_Crep{ msg.Payload = &network.Msg_Crep{
Crep: &network.ConnectResponse{ Crep: &network.ConnectResponse{
@@ -66,9 +68,10 @@ func (conn *Conn) SendConnectOK(to, id string) {
} }
// SendDisconnect send disconnect message // SendDisconnect send disconnect message
func (conn *Conn) SendDisconnect(to, id string) { func (conn *Conn) SendDisconnect(to string, toIdx uint32, id string) {
var msg network.Msg var msg network.Msg
msg.To = to msg.To = to
msg.ToIdx = toIdx
msg.XType = network.Msg_disconnect msg.XType = network.Msg_disconnect
msg.Payload = &network.Msg_XDisconnect{ msg.Payload = &network.Msg_XDisconnect{
XDisconnect: &network.Disconnect{ XDisconnect: &network.Disconnect{
@@ -82,7 +85,7 @@ func (conn *Conn) SendDisconnect(to, id string) {
} }
// SendData send forward data // SendData send forward data
func (conn *Conn) SendData(to, id string, data []byte) { func (conn *Conn) SendData(to string, toIdx uint32, id string, data []byte) {
dup := func(data []byte) []byte { dup := func(data []byte) []byte {
ret := make([]byte, len(data)) ret := make([]byte, len(data))
copy(ret, data) copy(ret, data)
@@ -90,6 +93,7 @@ func (conn *Conn) SendData(to, id string, data []byte) {
} }
var msg network.Msg var msg network.Msg
msg.To = to msg.To = to
msg.ToIdx = toIdx
msg.XType = network.Msg_forward msg.XType = network.Msg_forward
msg.Payload = &network.Msg_XData{ msg.Payload = &network.Msg_XData{
XData: &network.Data{ XData: &network.Data{

View File

@@ -14,6 +14,7 @@ type Link struct {
parent *Tunnel parent *Tunnel
id string // link id id string // link id
target string // target id target string // target id
targetIdx uint32 // target idx
local net.Conn local net.Conn
remote *pool.Conn remote *pool.Conn
OnWork chan struct{} OnWork chan struct{}
@@ -53,6 +54,7 @@ func (link *Link) remoteRead() {
if msg == nil { if msg == nil {
return return
} }
link.targetIdx = msg.GetFromIdx()
switch msg.GetXType() { switch msg.GetXType() {
case network.Msg_forward: case network.Msg_forward:
_, err := io.Copy(link.local, bytes.NewReader(msg.GetXData().GetData())) _, err := io.Copy(link.local, bytes.NewReader(msg.GetXData().GetData()))
@@ -85,7 +87,7 @@ func (link *Link) localRead() {
n, err := link.local.Read(buf) n, err := link.local.Read(buf)
if err != nil { if err != nil {
if !link.closeFromRemote { if !link.closeFromRemote {
link.remote.SendDisconnect(link.target, link.id) link.remote.SendDisconnect(link.target, link.targetIdx, link.id)
} }
// logging.Error("read data on tunnel %s link %s failed, err=%v", link.parent.Name, link.id, err) // logging.Error("read data on tunnel %s link %s failed, err=%v", link.parent.Name, link.id, err)
return return
@@ -94,6 +96,6 @@ func (link *Link) localRead() {
continue continue
} }
logging.Debug("link %s on tunnel %s read from local %d bytes", link.id, link.parent.Name, n) logging.Debug("link %s on tunnel %s read from local %d bytes", link.id, link.parent.Name, n)
link.remote.SendData(link.target, link.id, buf[:n]) link.remote.SendData(link.target, link.targetIdx, link.id, buf[:n])
} }
} }

View File

@@ -130,9 +130,11 @@ type Msg struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
XType MsgType `protobuf:"varint,1,opt,name=_type,json=Type,proto3,enum=network.MsgType" json:"_type,omitempty"` XType MsgType `protobuf:"varint,1,opt,name=_type,json=Type,proto3,enum=network.MsgType" json:"_type,omitempty"`
From string `protobuf:"bytes,2,opt,name=from,proto3" json:"from,omitempty"` From string `protobuf:"bytes,2,opt,name=from,proto3" json:"from,omitempty"`
To string `protobuf:"bytes,3,opt,name=to,proto3" json:"to,omitempty"` FromIdx uint32 `protobuf:"varint,3,opt,name=from_idx,json=fromIdx,proto3" json:"from_idx,omitempty"`
To string `protobuf:"bytes,4,opt,name=to,proto3" json:"to,omitempty"`
ToIdx uint32 `protobuf:"varint,5,opt,name=to_idx,json=toIdx,proto3" json:"to_idx,omitempty"`
// Types that are assignable to Payload: // Types that are assignable to Payload:
// *Msg_Hsp // *Msg_Hsp
// *Msg_Creq // *Msg_Creq
@@ -188,6 +190,13 @@ func (x *Msg) GetFrom() string {
return "" return ""
} }
func (x *Msg) GetFromIdx() uint32 {
if x != nil {
return x.FromIdx
}
return 0
}
func (x *Msg) GetTo() string { func (x *Msg) GetTo() string {
if x != nil { if x != nil {
return x.To return x.To
@@ -195,6 +204,13 @@ func (x *Msg) GetTo() string {
return "" return ""
} }
func (x *Msg) GetToIdx() uint32 {
if x != nil {
return x.ToIdx
}
return 0
}
func (m *Msg) GetPayload() isMsg_Payload { func (m *Msg) GetPayload() isMsg_Payload {
if m != nil { if m != nil {
return m.Payload return m.Payload
@@ -279,36 +295,39 @@ var file_msg_proto_rawDesc = []byte{
0x6f, 0x74, 0x6f, 0x1a, 0x0d, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x6f, 0x74, 0x6f, 0x1a, 0x0d, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x22, 0x25, 0x0a, 0x11, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x5f, 0x74, 0x6f, 0x22, 0x25, 0x0a, 0x11, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x5f,
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x63, 0x18, 0x01, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x63, 0x18, 0x01,
0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x65, 0x6e, 0x63, 0x22, 0xb0, 0x03, 0x0a, 0x03, 0x6d, 0x73, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x65, 0x6e, 0x63, 0x22, 0xe2, 0x03, 0x0a, 0x03, 0x6d, 0x73,
0x67, 0x12, 0x26, 0x0a, 0x05, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x67, 0x12, 0x26, 0x0a, 0x05, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e,
0x32, 0x11, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x6d, 0x73, 0x67, 0x2e, 0x74, 0x32, 0x11, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x6d, 0x73, 0x67, 0x2e, 0x74,
0x79, 0x70, 0x65, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x72, 0x6f, 0x79, 0x70, 0x65, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x72, 0x6f,
0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x72, 0x6f, 0x6d, 0x12, 0x0e, 0x0a, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x72, 0x6f, 0x6d, 0x12, 0x19, 0x0a,
0x02, 0x74, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x74, 0x6f, 0x12, 0x2e, 0x0a, 0x08, 0x66, 0x72, 0x6f, 0x6d, 0x5f, 0x69, 0x64, 0x78, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52,
0x03, 0x68, 0x73, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6e, 0x65, 0x74, 0x07, 0x66, 0x72, 0x6f, 0x6d, 0x49, 0x64, 0x78, 0x12, 0x0e, 0x0a, 0x02, 0x74, 0x6f, 0x18, 0x04,
0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x5f, 0x70, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x74, 0x6f, 0x12, 0x15, 0x0a, 0x06, 0x74, 0x6f, 0x5f, 0x69,
0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, 0x00, 0x52, 0x03, 0x68, 0x73, 0x70, 0x12, 0x2e, 0x0a, 0x64, 0x78, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x74, 0x6f, 0x49, 0x64, 0x78, 0x12,
0x04, 0x63, 0x72, 0x65, 0x71, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6e, 0x65, 0x2e, 0x0a, 0x03, 0x68, 0x73, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6e,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f, 0x72, 0x65, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x04, 0x63, 0x72, 0x65, 0x71, 0x12, 0x2f, 0x0a, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, 0x00, 0x52, 0x03, 0x68, 0x73, 0x70, 0x12,
0x04, 0x63, 0x72, 0x65, 0x70, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6e, 0x65, 0x2e, 0x0a, 0x04, 0x63, 0x72, 0x65, 0x71, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f, 0x72, 0x65, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x04, 0x63, 0x72, 0x65, 0x70, 0x12, 0x36, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x04, 0x63, 0x72, 0x65, 0x71, 0x12,
0x0a, 0x0b, 0x5f, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x0d, 0x20, 0x2f, 0x0a, 0x04, 0x63, 0x72, 0x65, 0x70, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e,
0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x64, 0x69, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f,
0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x44, 0x69, 0x73, 0x63, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x04, 0x63, 0x72, 0x65, 0x70,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x12, 0x24, 0x0a, 0x05, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x12, 0x36, 0x0a, 0x0b, 0x5f, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18,
0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e,
0x64, 0x61, 0x74, 0x61, 0x48, 0x00, 0x52, 0x04, 0x44, 0x61, 0x74, 0x61, 0x22, 0x63, 0x0a, 0x04, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x44, 0x69,
0x74, 0x79, 0x70, 0x65, 0x12, 0x0d, 0x0a, 0x09, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x12, 0x24, 0x0a, 0x05, 0x5f, 0x64, 0x61, 0x74,
0x65, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x61, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x10, 0x01, 0x12, 0x0f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f, 0x72, 0x65, 0x6b, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x48, 0x00, 0x52, 0x04, 0x44, 0x61, 0x74, 0x61, 0x22, 0x63,
0x71, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f, 0x72, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0d, 0x0a, 0x09, 0x68, 0x61, 0x6e, 0x64, 0x73, 0x68,
0x65, 0x70, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x61, 0x6b, 0x65, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x61, 0x6c, 0x69,
0x63, 0x74, 0x10, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x10, 0x76, 0x65, 0x10, 0x01, 0x12, 0x0f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x5f,
0x05, 0x42, 0x09, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x0c, 0x5a, 0x0a, 0x72, 0x65, 0x71, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74,
0x2e, 0x2f, 0x3b, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x5f, 0x72, 0x65, 0x70, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e,
0x6f, 0x33, 0x6e, 0x65, 0x63, 0x74, 0x10, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72,
0x64, 0x10, 0x05, 0x42, 0x09, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x0c,
0x5a, 0x0a, 0x2e, 0x2f, 0x3b, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@@ -19,9 +19,11 @@ message msg {
disconnect = 4; disconnect = 4;
forward = 5; forward = 5;
} }
type _type = 1; type _type = 1;
string from = 2; string from = 2;
string to = 3; uint32 from_idx = 3;
string to = 4;
uint32 to_idx = 5;
oneof payload { oneof payload {
handshake_payload hsp = 10; handshake_payload hsp = 10;
connect_request creq = 11; connect_request creq = 11;

View File

@@ -11,23 +11,11 @@ import (
type client struct { type client struct {
sync.RWMutex sync.RWMutex
parent *Handler parent *clients
id string idx uint32
trimID string conn *network.Conn
c *network.Conn
links map[string]struct{} // link id => struct{}
updated time.Time updated time.Time
} links map[string]struct{} // link id => struct{}
func newClient(parent *Handler, id, trimID string, conn *network.Conn) *client {
return &client{
parent: parent,
id: id,
trimID: trimID,
c: conn,
links: make(map[string]struct{}),
updated: time.Now(),
}
} }
func (c *client) run() { func (c *client) run() {
@@ -39,25 +27,25 @@ func (c *client) run() {
links = append(links, id) links = append(links, id)
} }
c.RUnlock() c.RUnlock()
logging.Info("%s is not keepalived, links: %v", c.id, links) logging.Info("%s-%d is not keepalived, links: %v", c.parent.id, c.idx, links)
c.parent.closeAll(c) c.parent.parent.closeClient(c)
return return
} }
msg, err := c.c.ReadMessage(c.parent.cfg.ReadTimeout) msg, err := c.conn.ReadMessage(c.parent.parent.cfg.ReadTimeout)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "i/o timeout") { if strings.Contains(err.Error(), "i/o timeout") {
continue continue
} }
logging.Error("read message from %s: %v", c.id, err) logging.Error("read message from %s-%d: %v", c.parent.id, c.idx, err)
return return
} }
c.updated = time.Now() c.updated = time.Now()
c.parent.onMessage(c, c.c, msg) c.parent.parent.onMessage(c, c.conn, msg)
} }
} }
func (c *client) writeMessage(msg *network.Msg) error { func (c *client) writeMessage(msg *network.Msg) error {
return c.c.WriteMessage(msg, c.parent.cfg.WriteTimeout) return c.conn.WriteMessage(msg, c.parent.parent.cfg.WriteTimeout)
} }
func (c *client) addLink(id string) { func (c *client) addLink(id string) {
@@ -85,14 +73,15 @@ func (c *client) getLinks() []string {
func (c *client) close(id string) { func (c *client) close(id string) {
var msg network.Msg var msg network.Msg
msg.From = "server" msg.From = "server"
msg.To = c.id msg.To = c.parent.id
msg.ToIdx = c.idx
msg.XType = network.Msg_disconnect msg.XType = network.Msg_disconnect
msg.Payload = &network.Msg_XDisconnect{ msg.Payload = &network.Msg_XDisconnect{
XDisconnect: &network.Disconnect{ XDisconnect: &network.Disconnect{
Id: id, Id: id,
}, },
} }
c.c.WriteMessage(&msg, c.parent.cfg.WriteTimeout) c.conn.WriteMessage(&msg, c.parent.parent.cfg.WriteTimeout)
c.Lock() c.Lock()
delete(c.links, id) delete(c.links, id)
c.Unlock() c.Unlock()

View File

@@ -0,0 +1,58 @@
package handler
import (
"natpass/code/network"
"sync"
"time"
)
type clients struct {
sync.RWMutex
parent *Handler
id string
data map[uint32]*client // idx => client
idx int
}
func newClients(parent *Handler, id string) *clients {
return &clients{
parent: parent,
id: id,
data: make(map[uint32]*client),
}
}
func (cs *clients) new(idx uint32, conn *network.Conn) *client {
cli := &client{
parent: cs,
idx: idx,
conn: conn,
updated: time.Now(),
links: make(map[string]struct{}),
}
cs.Lock()
cs.data[idx] = cli
cs.Unlock()
return cli
}
func (cs *clients) next() *client {
list := make([]*client, 0, len(cs.data))
cs.RLock()
for _, cli := range cs.data {
list = append(list, cli)
}
cs.RUnlock()
if len(list) > 0 {
cli := list[cs.idx%len(list)]
cs.idx++
return cli
}
return nil
}
func (cs *clients) close(idx uint32) {
cs.Lock()
delete(cs.data, idx)
cs.Unlock()
}

View File

@@ -5,7 +5,6 @@ import (
"natpass/code/network" "natpass/code/network"
"natpass/code/server/global" "natpass/code/server/global"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@@ -16,7 +15,7 @@ import (
type Handler struct { type Handler struct {
cfg *global.Configure cfg *global.Configure
lockClients sync.RWMutex lockClients sync.RWMutex
clients map[string]*client // client id => client clients map[string]*clients // client id => client
lockLinks sync.RWMutex lockLinks sync.RWMutex
links map[string][2]*client // link id => endpoints links map[string][2]*client // link id => endpoints
idx int idx int
@@ -26,7 +25,7 @@ type Handler struct {
func New(cfg *global.Configure) *Handler { func New(cfg *global.Configure) *Handler {
return &Handler{ return &Handler{
cfg: cfg, cfg: cfg,
clients: make(map[string]*client), clients: make(map[string]*clients),
links: make(map[string][2]*client), links: make(map[string][2]*client),
idx: 0, idx: 0,
} }
@@ -36,6 +35,7 @@ func New(cfg *global.Configure) *Handler {
func (h *Handler) Handle(conn net.Conn) { func (h *Handler) Handle(conn net.Conn) {
c := network.NewConn(conn) c := network.NewConn(conn)
var id string var id string
var idx uint32
defer func() { defer func() {
if len(id) > 0 { if len(id) > 0 {
logging.Info("%s disconnected", id) logging.Info("%s disconnected", id)
@@ -44,7 +44,7 @@ func (h *Handler) Handle(conn net.Conn) {
}() }()
var err error var err error
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
id, err = h.readHandshake(c) id, idx, err = h.readHandshake(c)
if err != nil { if err != nil {
if err == errInvalidHandshake { if err == errInvalidHandshake {
logging.Error("invalid handshake from %s", c.RemoteAddr().String()) logging.Error("invalid handshake from %s", c.RemoteAddr().String())
@@ -60,75 +60,67 @@ func (h *Handler) Handle(conn net.Conn) {
} }
logging.Info("%s connected", id) logging.Info("%s connected", id)
// split id and index clients := h.tryGetClients(id)
trimID := id cli := clients.new(idx, c)
n := strings.LastIndex(id, "-")
if n != -1 {
trimID = id[:n]
}
cli := newClient(h, id, trimID, c) defer h.closeClient(cli)
h.lockClients.Lock()
h.clients[cli.id] = cli
h.lockClients.Unlock()
defer h.closeAll(cli)
cli.run() cli.run()
} }
func (h *Handler) connsByTrimID(id string) []*client { func (h *Handler) tryGetClients(id string) *clients {
ret := make([]*client, 0, 10) h.lockClients.Lock()
h.lockClients.RLock() defer h.lockClients.Unlock()
for _, cli := range h.clients { clients := h.clients[id]
if cli.trimID == id { if clients != nil {
ret = append(ret, cli) return clients
}
} }
h.lockClients.RUnlock() clients = newClients(h, id)
return ret h.clients[id] = clients
return clients
} }
func (h *Handler) getClient(linkID, targetID string) *client { // readHandshake read handshake message and compare secret encoded from md5
func (h *Handler) readHandshake(c *network.Conn) (string, uint32, error) {
msg, err := c.ReadMessage(5 * time.Second)
if err != nil {
return "", 0, err
}
if msg.GetXType() != network.Msg_handshake {
return "", 0, errNotHandshake
}
n := bytes.Compare(msg.GetHsp().GetEnc(), h.cfg.Enc[:])
if n != 0 {
return "", 0, errInvalidHandshake
}
return msg.GetFrom(), msg.GetFromIdx(), nil
}
func (h *Handler) getClient(linkID, to string, toIdx uint32) *client {
h.lockLinks.RLock() h.lockLinks.RLock()
pair := h.links[linkID] pair := h.links[linkID]
h.lockLinks.RUnlock() h.lockLinks.RUnlock()
if pair[0] != nil && pair[0].trimID == targetID { if pair[0] != nil && pair[0].idx == toIdx {
return pair[0] return pair[0]
} }
if pair[1] != nil && pair[1].trimID == targetID { if pair[1] != nil && pair[1].idx == toIdx {
return pair[1] return pair[1]
} }
conns := h.connsByTrimID(targetID) h.lockClients.RLock()
if len(conns) == 0 { clients := h.clients[to]
h.lockClients.RUnlock()
if clients == nil {
return nil return nil
} }
conn := conns[h.idx%len(conns)] return clients.next()
h.idx++
return conn
} }
// readHandshake read handshake message and compare secret encoded from md5
func (h *Handler) readHandshake(c *network.Conn) (string, error) {
msg, err := c.ReadMessage(5 * time.Second)
if err != nil {
return "", err
}
if msg.GetXType() != network.Msg_handshake {
return "", errNotHandshake
}
n := bytes.Compare(msg.GetHsp().GetEnc(), h.cfg.Enc[:])
if n != 0 {
return "", errInvalidHandshake
}
return msg.GetFrom(), nil
}
// onMessage forward message
func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg) { func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg) {
to := msg.GetTo() to := msg.GetTo()
toIdx := msg.GetToIdx()
var linkID string var linkID string
switch msg.GetXType() { switch msg.GetXType() {
case network.Msg_connect_req: case network.Msg_connect_req:
@@ -142,9 +134,9 @@ func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg)
default: default:
return return
} }
cli := h.getClient(linkID, to) cli := h.getClient(linkID, to, toIdx)
if cli == nil { if cli == nil {
logging.Error("client %s not found", to) logging.Error("client %s-%d not found", to, toIdx)
return return
} }
h.msgHook(msg, from, cli) h.msgHook(msg, from, cli)
@@ -180,11 +172,13 @@ func (h *Handler) msgHook(msg *network.Msg, from, to *client) {
delete(h.links, id) delete(h.links, id)
h.lockLinks.Unlock() h.lockLinks.Unlock()
} }
msg.From = from.trimID msg.From = from.parent.id
msg.FromIdx = from.idx
msg.To = to.parent.id
msg.ToIdx = to.idx
} }
// closeAll close all links from client func (h *Handler) closeClient(cli *client) {
func (h *Handler) closeAll(cli *client) {
links := cli.getLinks() links := cli.getLinks()
for _, t := range links { for _, t := range links {
h.lockLinks.RLock() h.lockLinks.RLock()
@@ -200,7 +194,13 @@ func (h *Handler) closeAll(cli *client) {
delete(h.links, t) delete(h.links, t)
h.lockLinks.Unlock() h.lockLinks.Unlock()
} }
h.lockClients.Lock() h.lockClients.RLock()
delete(h.clients, cli.id) clients := h.clients[cli.parent.id]
h.lockClients.Unlock() h.lockClients.RUnlock()
clients.close(cli.idx)
if len(clients.data) == 0 {
h.lockClients.Lock()
delete(h.clients, clients.id)
h.lockClients.Unlock()
}
} }