diff --git a/stack_system_nat.go b/stack_system_nat.go index 1d0216e..66240a6 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -11,6 +11,7 @@ import ( ) type TCPNat struct { + timeout time.Duration portIndex uint16 portAccess sync.RWMutex addrAccess sync.RWMutex @@ -19,6 +20,7 @@ type TCPNat struct { } type TCPSession struct { + sync.Mutex Source netip.AddrPort Destination netip.AddrPort LastActive time.Time @@ -26,38 +28,41 @@ type TCPSession struct { func NewNat(ctx context.Context, timeout time.Duration) *TCPNat { natMap := &TCPNat{ + timeout: timeout, portIndex: 10000, addrMap: make(map[netip.AddrPort]uint16), portMap: make(map[uint16]*TCPSession), } - go natMap.loopCheckTimeout(ctx, timeout) + go natMap.loopCheckTimeout(ctx) return natMap } -func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) { - ticker := time.NewTicker(timeout) +func (n *TCPNat) loopCheckTimeout(ctx context.Context) { + ticker := time.NewTicker(n.timeout) defer ticker.Stop() for { select { case <-ticker.C: - n.checkTimeout(timeout) + n.checkTimeout() case <-ctx.Done(): return } } } -func (n *TCPNat) checkTimeout(timeout time.Duration) { +func (n *TCPNat) checkTimeout() { now := time.Now() n.portAccess.Lock() defer n.portAccess.Unlock() n.addrAccess.Lock() defer n.addrAccess.Unlock() for natPort, session := range n.portMap { - if now.Sub(session.LastActive) > timeout { + session.Lock() + if now.Sub(session.LastActive) > n.timeout { delete(n.addrMap, session.Source) delete(n.portMap, natPort) } + session.Unlock() } } @@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession { session := n.portMap[port] n.portAccess.RUnlock() if session != nil { - session.LastActive = time.Now() + session.Lock() + if time.Since(session.LastActive) > time.Second { + session.LastActive = time.Now() + } + session.Unlock() } return session }