feat(p2p): add more link to tcp

This commit is contained in:
源文雨
2024-08-06 20:30:33 +08:00
parent ea768f88f9
commit b71a0541bd
8 changed files with 198 additions and 47 deletions

View File

@@ -199,7 +199,7 @@ func (p *Packet) FillHash() {
h := blake2b.New256() h := blake2b.New256()
_, err := h.Write(p.Body()) _, err := h.Write(p.Body())
if err != nil { if err != nil {
logrus.Error("[packet] err when fill hash:", err) logrus.Errorln("[packet] err when fill hash:", err)
return return
} }
hsh := h.Sum(p.Hash[:0]) hsh := h.Sum(p.Hash[:0])
@@ -213,7 +213,7 @@ func (p *Packet) IsVaildHash() bool {
h := blake2b.New256() h := blake2b.New256()
_, err := h.Write(p.Body()) _, err := h.Write(p.Body())
if err != nil { if err != nil {
logrus.Error("[packet] err when check hash:", err) logrus.Errorln("[packet] err when check hash:", err)
return false return false
} }
var sum [32]byte var sum [32]byte

View File

@@ -47,7 +47,9 @@ func (m *Me) wait(data []byte) *head.Packet {
logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc) logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc)
} }
if m.recved.Get(crc) { if m.recved.Get(crc) {
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) if config.ShowDebugLog {
logrus.Debugln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
}
return nil return nil
} }
if config.ShowDebugLog { if config.ShowDebugLog {

View File

@@ -12,6 +12,7 @@ import (
type Config struct { type Config struct {
DialTimeout time.Duration DialTimeout time.Duration
PeersTimeout time.Duration PeersTimeout time.Duration
KeepInterval time.Duration
ReceiveChannelSize int ReceiveChannelSize int
} }
@@ -34,6 +35,7 @@ func newEndpoint(endpoint string, configs ...any) (*EndPoint, error) {
addr: net.TCPAddrFromAddrPort(addr), addr: net.TCPAddrFromAddrPort(addr),
dialtimeout: cfg.DialTimeout, dialtimeout: cfg.DialTimeout,
peerstimeout: cfg.PeersTimeout, peerstimeout: cfg.PeersTimeout,
keepinterval: cfg.KeepInterval,
recvchansize: cfg.ReceiveChannelSize, recvchansize: cfg.ReceiveChannelSize,
}, nil }, nil
} }

View File

@@ -21,6 +21,7 @@ type packetType uint8
const ( const (
packetTypeKeepAlive packetType = iota packetTypeKeepAlive packetType = iota
packetTypeNormal packetTypeNormal
packetTypeSubKeepAlive
packetTypeTop packetTypeTop
) )
@@ -87,7 +88,7 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
return io.Copy(w, &buf) return io.Copy(w, &buf)
} }
func isvalid(tcpconn *net.TCPConn) bool { func isvalid(tcpconn *net.TCPConn) (issub, ok bool) {
pckt := packet{} pckt := packet{}
stopch := make(chan struct{}) stopch := make(chan struct{})
@@ -107,7 +108,7 @@ func isvalid(tcpconn *net.TCPConn) bool {
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout") logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
} }
return false return
case <-copych: case <-copych:
t.Stop() t.Stop()
} }
@@ -116,17 +117,17 @@ func isvalid(tcpconn *net.TCPConn) bool {
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err) logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
} }
return false return
} }
if pckt.typ != packetTypeKeepAlive { if pckt.typ != packetTypeKeepAlive && pckt.typ != packetTypeSubKeepAlive {
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr()) logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
} }
return false return
} }
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr()) logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
} }
return true return pckt.typ == packetTypeSubKeepAlive, true
} }

View File

@@ -21,6 +21,7 @@ type EndPoint struct {
addr *net.TCPAddr addr *net.TCPAddr
dialtimeout time.Duration dialtimeout time.Duration
peerstimeout time.Duration peerstimeout time.Duration
keepinterval time.Duration
recvchansize int recvchansize int
} }
@@ -80,6 +81,7 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) {
}), }),
recv: make(chan *connrecv, chansz), recv: make(chan *connrecv, chansz),
cplk: &sync.Mutex{}, cplk: &sync.Mutex{},
sblk: &sync.RWMutex{},
} }
go conn.accept() go conn.accept()
return conn, nil return conn, nil
@@ -91,6 +93,11 @@ type connrecv struct {
pckt packet pckt packet
} }
type subconn struct {
cplk sync.Mutex
conn *net.TCPConn
}
// Conn 伪装成无状态的有状态连接 // Conn 伪装成无状态的有状态连接
type Conn struct { type Conn struct {
addr *EndPoint addr *EndPoint
@@ -98,6 +105,8 @@ type Conn struct {
peers *ttl.Cache[string, *net.TCPConn] peers *ttl.Cache[string, *net.TCPConn]
recv chan *connrecv recv chan *connrecv
cplk *sync.Mutex cplk *sync.Mutex
sblk *sync.RWMutex
subs []*subconn
} }
func (conn *Conn) accept() { func (conn *Conn) accept() {
@@ -115,32 +124,54 @@ func (conn *Conn) accept() {
_ = conn.Close() _ = conn.Close()
newc, err := conn.addr.Listen() newc, err := conn.addr.Listen()
if err != nil { if err != nil {
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err) logrus.Warnln("[tcp] re-listen on", conn.addr, "err:", err)
return return
} }
*conn = *newc.(*Conn) *conn = *newc.(*Conn)
logrus.Info("[tcp] re-listen on", conn.addr) logrus.Infoln("[tcp] re-listen on", conn.addr)
continue continue
} }
go conn.receive(tcpconn, false) go conn.receive(tcpconn, false)
} }
} }
func delsubs(i int, subs []*subconn) []*subconn {
switch i {
case 0:
subs = subs[1:]
case len(subs) - 1:
subs = subs[:len(subs)-1]
default:
subs = append(subs[:i], subs[i+1:]...)
}
return subs
}
func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) { func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{ ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
DialTimeout: conn.addr.dialtimeout, DialTimeout: conn.addr.dialtimeout,
PeersTimeout: conn.addr.peerstimeout, PeersTimeout: conn.addr.peerstimeout,
KeepInterval: conn.addr.keepinterval,
ReceiveChannelSize: conn.addr.recvchansize, ReceiveChannelSize: conn.addr.recvchansize,
}) })
issub, ok := false, false
if !hasvalidated { if !hasvalidated {
if !isvalid(tcpconn) { issub, ok = isvalid(tcpconn)
if !ok {
return return
} }
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] accept from", ep) logrus.Debugln("[tcp] accept from", ep, "issub:", issub)
}
if issub {
conn.sblk.Lock()
conn.subs = append(conn.subs, &subconn{conn: tcpconn})
conn.sblk.Unlock()
} else {
conn.peers.Set(ep.String(), tcpconn)
} }
conn.peers.Set(ep.String(), tcpconn)
} }
peerstimeout := conn.addr.peerstimeout peerstimeout := conn.addr.peerstimeout
@@ -148,15 +179,33 @@ func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
peerstimeout = time.Second * 30 peerstimeout = time.Second * 30
} }
peerstimeout *= 2 peerstimeout *= 2
defer conn.peers.Delete(ep.String()) if issub {
defer conn.peers.Delete(ep.String())
} else {
defer func() {
conn.sblk.Lock()
for i, sub := range conn.subs {
if sub.conn == tcpconn {
conn.subs = delsubs(i, conn.subs)
break
}
}
conn.sblk.Unlock()
}()
}
go conn.keep(ep)
for { for {
r := &connrecv{addr: ep} r := &connrecv{addr: ep}
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil { if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
return return
} }
tcpconn := conn.peers.Get(ep.String()) if !issub {
if tcpconn == nil { tcpconn = conn.peers.Get(ep.String())
return if tcpconn == nil {
return
}
} }
r.conn = tcpconn r.conn = tcpconn
@@ -204,6 +253,46 @@ func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
} }
} }
func (conn *Conn) keep(ep *EndPoint) {
keepinterval := ep.keepinterval
if keepinterval < time.Second*4 {
keepinterval = time.Second * 4
}
t := time.NewTicker(keepinterval)
defer t.Stop()
for range t.C {
if conn.addr == nil {
return
}
tcpconn := conn.peers.Get(ep.String())
if tcpconn != nil {
_, err := io.Copy(tcpconn, &packet{typ: packetTypeKeepAlive})
if conn.addr == nil {
return
}
if err != nil {
logrus.Warnln("[tcp] keep main conn alive to", ep, "err:", err)
conn.peers.Delete(ep.String())
} else if config.ShowDebugLog {
logrus.Debugln("[tcp] keep main conn alive to", ep)
}
}
conn.sblk.RLock()
for i, sub := range conn.subs {
_, err := io.Copy(sub.conn, &packet{typ: packetTypeSubKeepAlive})
if conn.addr == nil {
return
}
if err != nil {
logrus.Warnln("[tcp] keep sub conn alive to", sub.conn.RemoteAddr(), "err:", err)
conn.subs = delsubs(i, conn.subs) // del 1 link at once
break
}
}
conn.sblk.RUnlock()
}
}
func (conn *Conn) Close() error { func (conn *Conn) Close() error {
if conn.lstn != nil { if conn.lstn != nil {
_ = conn.lstn.Close() _ = conn.lstn.Close()
@@ -246,20 +335,28 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
return n, p.addr, nil return n, p.addr, nil
} }
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) { // writeToPeer after acquiring lock
tcpep, ok := ep.(*EndPoint) func (conn *Conn) writeToPeer(b []byte, tcpep *EndPoint, issub bool) (n int, err error) {
if !ok {
return 0, p2p.ErrEndpointTypeMistatch
}
blen := len(b)
if blen >= 65536 {
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
}
retried := false retried := false
conn.cplk.Lock() ok := false
defer conn.cplk.Unlock() var (
tcpconn := conn.peers.Get(tcpep.String()) tcpconn *net.TCPConn
subc *subconn
)
RECONNECT: RECONNECT:
if issub {
conn.sblk.RLock()
for _, sub := range conn.subs {
if sub.cplk.TryLock() {
tcpconn = sub.conn
subc = sub
break
}
}
conn.sblk.RUnlock()
} else {
tcpconn = conn.peers.Get(tcpep.String())
}
if tcpconn == nil { if tcpconn == nil {
dialtimeout := tcpep.dialtimeout dialtimeout := tcpep.dialtimeout
if dialtimeout < time.Second { if dialtimeout < time.Second {
@@ -278,9 +375,13 @@ RECONNECT:
if !ok { if !ok {
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String()) return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
} }
_, err = io.Copy(tcpconn, &packet{ pkt := &packet{}
typ: packetTypeKeepAlive, if issub {
}) pkt.typ = packetTypeSubKeepAlive
} else {
pkt.typ = packetTypeKeepAlive
}
_, err = io.Copy(tcpconn, pkt)
if err != nil { if err != nil {
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err) logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err)
@@ -290,23 +391,58 @@ RECONNECT:
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr()) logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
} }
conn.peers.Set(tcpep.String(), tcpconn) if !issub {
go conn.receive(tcpconn, true) conn.peers.Set(tcpep.String(), tcpconn)
} else {
conn.sblk.Lock()
conn.subs = append(conn.subs, &subconn{conn: tcpconn})
conn.sblk.Unlock()
go conn.receive(tcpconn, true)
}
} else if config.ShowDebugLog { } else if config.ShowDebugLog {
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr()) logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
} }
cnt, err := io.Copy(tcpconn, &packet{ cnt, err := io.Copy(tcpconn, &packet{
typ: packetTypeNormal, typ: packetTypeNormal,
len: uint16(blen), len: uint16(len(b)),
dat: b, dat: b,
}) })
if err != nil { if err != nil {
conn.peers.Delete(tcpep.String()) if subc == nil {
conn.peers.Delete(tcpep.String())
} else {
conn.sblk.Lock()
for i, sub := range conn.subs {
if sub == subc {
conn.subs = delsubs(i, conn.subs)
break
}
}
conn.sblk.Unlock()
}
if !retried { if !retried {
retried = true retried = true
tcpconn = nil tcpconn = nil
goto RECONNECT goto RECONNECT
} }
} }
if subc != nil {
subc.cplk.Unlock()
}
return int(cnt) - 3, err return int(cnt) - 3, err
} }
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
tcpep, ok := ep.(*EndPoint)
if !ok {
return 0, p2p.ErrEndpointTypeMistatch
}
if len(b) >= 65536 {
return 0, errors.New("data size " + strconv.Itoa(len(b)) + " is too large")
}
if !conn.cplk.TryLock() {
return conn.writeToPeer(b, tcpep, true)
}
defer conn.cplk.Unlock()
return conn.writeToPeer(b, tcpep, false)
}

View File

@@ -27,7 +27,7 @@ type NICIO struct {
func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) *NICIO { func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) *NICIO {
ifce, err := water.New(water.Config{DeviceType: water.TUN}) ifce, err := water.New(water.Config{DeviceType: water.TUN})
if err != nil { if err != nil {
logrus.Error(err) logrus.Errorln(err)
os.Exit(1) os.Exit(1)
} }
subn, bitsn := subnet.Mask.Size() subn, bitsn := subnet.Mask.Size()

View File

@@ -107,7 +107,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1
time.Sleep(time.Second) // wait link up time.Sleep(time.Second) // wait link up
sendb := ([]byte)("1234") sendb := ([]byte)("1234")
tunnme.Write(sendb) go tunnme.Write(sendb)
buf := make([]byte, 4) buf := make([]byte, 4)
tunnpeer.Read(buf) tunnpeer.Read(buf)
if string(sendb) != string(buf) { if string(sendb) != string(buf) {
@@ -117,7 +117,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1
sendb = make([]byte, 4096) sendb = make([]byte, 4096)
rand.Read(sendb) rand.Read(sendb)
tunnme.Write(sendb) go tunnme.Write(sendb)
buf = make([]byte, 4096) buf = make([]byte, 4096)
_, err = io.ReadFull(&tunnpeer, buf) _, err = io.ReadFull(&tunnpeer, buf)
if err != nil { if err != nil {
@@ -127,13 +127,22 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1
t.Fatal("error: recv 4096 bytes data") t.Fatal("error: recv 4096 bytes data")
} }
sendb = make([]byte, 65535) sendbufs := make(chan []byte, 32)
go func() {
for i := 0; i < 32; i++ {
sendb := make([]byte, 65535)
rand.Read(sendb)
n, _ := tunnme.Write(sendb)
sendbufs <- sendb
t.Log("loop", i, "write", n, "bytes")
}
close(sendbufs)
}()
buf = make([]byte, 65535) buf = make([]byte, 65535)
for i := 0; i < 32; i++ { i := 0
rand.Read(sendb) for sendb := range sendbufs {
n, _ := tunnme.Write(sendb) n, err := io.ReadFull(&tunnpeer, buf)
t.Log("loop", i, "write", n, "bytes")
n, err = io.ReadFull(&tunnpeer, buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -141,6 +150,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1
if string(sendb) != string(buf) { if string(sendb) != string(buf) {
t.Fatal("loop", i, "error: recv 65535 bytes data") t.Fatal("loop", i, "error: recv 65535 bytes data")
} }
i++
} }
rand.Read(sendb) rand.Read(sendb)

View File

@@ -59,7 +59,7 @@ func (wg *WG) Start(srcport, destport uint16) {
func (wg *WG) Run(srcport, destport uint16) { func (wg *WG) Run(srcport, destport uint16) {
wg.init(srcport, destport) wg.init(srcport, destport)
_, _ = wg.me.ListenNIC() _, _ = wg.me.ListenNIC()
logrus.Info("[wg] stopped") logrus.Infoln("[wg] stopped")
} }
func (wg *WG) Stop() { func (wg *WG) Stop() {