Fix race codes

This commit is contained in:
世界
2025-09-12 18:02:37 +08:00
parent e2503223dc
commit e229d7041e

View File

@@ -11,6 +11,7 @@ import (
) )
type TCPNat struct { type TCPNat struct {
timeout time.Duration
portIndex uint16 portIndex uint16
portAccess sync.RWMutex portAccess sync.RWMutex
addrAccess sync.RWMutex addrAccess sync.RWMutex
@@ -19,6 +20,7 @@ type TCPNat struct {
} }
type TCPSession struct { type TCPSession struct {
sync.Mutex
Source netip.AddrPort Source netip.AddrPort
Destination netip.AddrPort Destination netip.AddrPort
LastActive time.Time LastActive time.Time
@@ -26,38 +28,41 @@ type TCPSession struct {
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat { func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
natMap := &TCPNat{ natMap := &TCPNat{
timeout: timeout,
portIndex: 10000, portIndex: 10000,
addrMap: make(map[netip.AddrPort]uint16), addrMap: make(map[netip.AddrPort]uint16),
portMap: make(map[uint16]*TCPSession), portMap: make(map[uint16]*TCPSession),
} }
go natMap.loopCheckTimeout(ctx, timeout) go natMap.loopCheckTimeout(ctx)
return natMap return natMap
} }
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) { func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
ticker := time.NewTicker(timeout) ticker := time.NewTicker(n.timeout)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
n.checkTimeout(timeout) n.checkTimeout()
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
func (n *TCPNat) checkTimeout(timeout time.Duration) { func (n *TCPNat) checkTimeout() {
now := time.Now() now := time.Now()
n.portAccess.Lock() n.portAccess.Lock()
defer n.portAccess.Unlock() defer n.portAccess.Unlock()
n.addrAccess.Lock() n.addrAccess.Lock()
defer n.addrAccess.Unlock() defer n.addrAccess.Unlock()
for natPort, session := range n.portMap { for natPort, session := range n.portMap {
if now.Sub(session.LastActive) > timeout { session.Lock()
if now.Sub(session.LastActive) > n.timeout {
delete(n.addrMap, session.Source) delete(n.addrMap, session.Source)
delete(n.portMap, natPort) delete(n.portMap, natPort)
} }
session.Unlock()
} }
} }
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
session := n.portMap[port] session := n.portMap[port]
n.portAccess.RUnlock() n.portAccess.RUnlock()
if session != nil { if session != nil {
session.LastActive = time.Now() session.Lock()
if time.Since(session.LastActive) > time.Second {
session.LastActive = time.Now()
}
session.Unlock()
} }
return session return session
} }