nat: use a single Go routine to renew NAT mappings

This commit is contained in:
Marten Seemann
2023-04-03 22:27:04 +09:00
parent f5cbaf1721
commit c62be4081c
3 changed files with 107 additions and 95 deletions

View File

@@ -30,6 +30,11 @@ func NewNATManager(net network.Network) NATManager {
return newNatManager(net) return newNatManager(net)
} }
type entry struct {
protocol string
port int
}
// natManager takes care of adding + removing port mappings to the nat. // natManager takes care of adding + removing port mappings to the nat.
// Initialized with the host if it has a NATPortMap option enabled. // Initialized with the host if it has a NATPortMap option enabled.
// natManager receives signals from the network, and check on nat mappings: // natManager receives signals from the network, and check on nat mappings:
@@ -42,7 +47,9 @@ type natManager struct {
nat *inat.NAT nat *inat.NAT
ready chan struct{} // closed once the nat is ready to process port mappings ready chan struct{} // closed once the nat is ready to process port mappings
syncFlag chan struct{} syncFlag chan struct{} // cap: 1
tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function
refCount sync.WaitGroup refCount sync.WaitGroup
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
@@ -55,6 +62,7 @@ func newNatManager(net network.Network) *natManager {
ready: make(chan struct{}), ready: make(chan struct{}),
syncFlag: make(chan struct{}, 1), syncFlag: make(chan struct{}, 1),
ctxCancel: cancel, ctxCancel: cancel,
tracked: make(map[entry]bool),
} }
nmgr.refCount.Add(1) nmgr.refCount.Add(1)
go nmgr.background(ctx) go nmgr.background(ctx)
@@ -127,10 +135,10 @@ func (nmgr *natManager) sync() {
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any // doSync syncs the current NAT mappings, removing any outdated mappings and adding any
// new mappings. // new mappings.
func (nmgr *natManager) doSync() { func (nmgr *natManager) doSync() {
ports := map[string]map[int]bool{ for e := range nmgr.tracked {
"tcp": {}, nmgr.tracked[e] = false
"udp": {},
} }
var newAddresses []entry
for _, maddr := range nmgr.net.ListenAddresses() { for _, maddr := range nmgr.net.ListenAddresses() {
// Strip the IP // Strip the IP
maIP, rest := ma.SplitFirst(maddr) maIP, rest := ma.SplitFirst(maddr)
@@ -166,48 +174,36 @@ func (nmgr *natManager) doSync() {
default: default:
continue continue
} }
port, err := strconv.ParseUint(proto.Value(), 10, 16) port, err := strconv.ParseUint(proto.Value(), 10, 16)
if err != nil { if err != nil {
// bug in multiaddr // bug in multiaddr
panic(err) panic(err)
} }
ports[protocol][int(port)] = false e := entry{protocol: protocol, port: int(port)}
if _, ok := nmgr.tracked[e]; ok {
nmgr.tracked[e] = true
} else {
newAddresses = append(newAddresses, e)
}
} }
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
// Close old mappings // Close old mappings
for _, m := range nmgr.nat.Mappings() { for e, v := range nmgr.tracked {
mappedPort := m.InternalPort() if !v {
if _, ok := ports[m.Protocol()][mappedPort]; !ok { nmgr.nat.RemoveMapping(e.protocol, e.port)
// No longer need this mapping. delete(nmgr.tracked, e)
wg.Add(1)
go func(m inat.Mapping) {
defer wg.Done()
m.Close()
}(m)
} else {
// already mapped
ports[m.Protocol()][mappedPort] = true
} }
} }
// Create new mappings. // Create new mappings.
for proto, pports := range ports { for _, e := range newAddresses {
for port, mapped := range pports { if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil {
if mapped { log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
continue
}
wg.Add(1)
go func(proto string, port int) {
defer wg.Done()
if err := nmgr.nat.AddMapping(proto, port); err != nil {
log.Errorf("failed to port-map %s port %d: %s", proto, port, err)
}
}(proto, port)
} }
nmgr.tracked[e] = false
} }
} }

View File

@@ -24,9 +24,6 @@ type Mapping interface {
// ExternalAddr returns the external facing address. If the mapping is not // ExternalAddr returns the external facing address. If the mapping is not
// established, addr will be nil, and and ErrNoMapping will be returned. // established, addr will be nil, and and ErrNoMapping will be returned.
ExternalAddr() (addr net.Addr, err error) ExternalAddr() (addr net.Addr, err error)
// Close closes the port mapping
Close() error
} }
// keeps republishing // keeps republishing
@@ -103,8 +100,3 @@ func (m *mapping) ExternalAddr() (net.Addr, error) {
panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) panic(fmt.Sprintf("invalid protocol %q", m.Protocol()))
} }
} }
func (m *mapping) Close() error {
m.nat.removeMapping(m)
return nil
}

View File

@@ -24,6 +24,11 @@ const MappingDuration = time.Second * 60
// CacheTime is the time a mapping will cache an external address for // CacheTime is the time a mapping will cache an external address for
const CacheTime = time.Second * 15 const CacheTime = time.Second * 15
type entry struct {
protocol string
port int
}
// DiscoverNAT looks for a NAT device in the network and // DiscoverNAT looks for a NAT device in the network and
// returns an object that can manage port mappings. // returns an object that can manage port mappings.
func DiscoverNAT(ctx context.Context) (*NAT, error) { func DiscoverNAT(ctx context.Context) (*NAT, error) {
@@ -40,7 +45,19 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) {
log.Debug("DiscoverGateway address:", addr) log.Debug("DiscoverGateway address:", addr)
} }
return newNAT(natInstance), nil ctx, cancel := context.WithCancel(context.Background())
nat := &NAT{
nat: natInstance,
mappings: make(map[entry]int),
ctx: ctx,
ctxCancel: cancel,
}
nat.refCount.Add(1)
go func() {
defer nat.refCount.Done()
nat.background()
}()
return nat, nil
} }
// NAT is an object that manages address port mappings in // NAT is an object that manages address port mappings in
@@ -57,17 +74,7 @@ type NAT struct {
mappingmu sync.RWMutex // guards mappings mappingmu sync.RWMutex // guards mappings
closed bool closed bool
mappings map[*mapping]struct{} mappings map[entry]int
}
func newNAT(realNAT nat.NAT) *NAT {
ctx, cancel := context.WithCancel(context.Background())
return &NAT{
nat: realNAT,
mappings: make(map[*mapping]struct{}),
ctx: ctx,
ctxCancel: cancel,
}
} }
// Close shuts down all port mappings. NAT can no longer be used. // Close shuts down all port mappings. NAT can no longer be used.
@@ -84,94 +91,114 @@ func (nat *NAT) Close() error {
// Mappings returns a slice of all NAT mappings // Mappings returns a slice of all NAT mappings
func (nat *NAT) Mappings() []Mapping { func (nat *NAT) Mappings() []Mapping {
nat.mappingmu.Lock() nat.mappingmu.Lock()
defer nat.mappingmu.Unlock()
maps2 := make([]Mapping, 0, len(nat.mappings)) maps2 := make([]Mapping, 0, len(nat.mappings))
for m := range nat.mappings { for e, extPort := range nat.mappings {
maps2 = append(maps2, m) maps2 = append(maps2, &mapping{
nat: nat,
proto: e.protocol,
intport: e.port,
extport: extPort,
})
} }
nat.mappingmu.Unlock()
return maps2 return maps2
} }
// AddMapping attempts to construct a mapping on protocol and internal port // AddMapping attempts to construct a mapping on protocol and internal port
// It will also periodically renew the mapping until the returned Mapping // It will also periodically renew the mapping.
// -- or its parent NAT -- is Closed.
// //
// May not succeed, and mappings may change over time; // May not succeed, and mappings may change over time;
// NAT devices may not respect our port requests, and even lie. // NAT devices may not respect our port requests, and even lie.
func (nat *NAT) AddMapping(protocol string, port int) error { func (nat *NAT) AddMapping(protocol string, port int) error {
if nat == nil {
return fmt.Errorf("no nat available")
}
switch protocol { switch protocol {
case "tcp", "udp": case "tcp", "udp":
default: default:
return fmt.Errorf("invalid protocol: %s", protocol) return fmt.Errorf("invalid protocol: %s", protocol)
} }
m := &mapping{
intport: port,
nat: nat,
proto: protocol,
}
nat.mappingmu.Lock() nat.mappingmu.Lock()
if nat.closed { if nat.closed {
nat.mappingmu.Unlock() nat.mappingmu.Unlock()
return errors.New("closed") return errors.New("closed")
} }
nat.mappings[m] = struct{}{}
nat.refCount.Add(1)
nat.mappingmu.Unlock()
go nat.refreshMappings(m)
// do it once synchronously, so first mapping is done right away, and before exiting, // do it once synchronously, so first mapping is done right away, and before exiting,
// allowing users -- in the optimistic case -- to use results right after. // allowing users -- in the optimistic case -- to use results right after.
nat.establishMapping(m) extPort := nat.establishMapping(protocol, port)
nat.mappings[entry{protocol: protocol, port: port}] = extPort
nat.mappingmu.Unlock()
return nil return nil
} }
func (nat *NAT) removeMapping(m *mapping) { func (nat *NAT) RemoveMapping(protocol string, port int) error {
nat.mappingmu.Lock() nat.mappingmu.Lock()
delete(nat.mappings, m) defer nat.mappingmu.Unlock()
nat.mappingmu.Unlock() switch protocol {
nat.natmu.Lock() case "tcp", "udp":
nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) delete(nat.mappings, entry{protocol: protocol, port: port})
nat.natmu.Unlock() default:
return fmt.Errorf("invalid protocol: %s", protocol)
}
return nil
} }
func (nat *NAT) refreshMappings(m *mapping) { func (nat *NAT) background() {
defer nat.refCount.Done() const tick = MappingDuration / 3
t := time.NewTicker(MappingDuration / 3) t := time.NewTimer(tick) // don't use a ticker here. We don't know how long establishing the mappings takes.
defer t.Stop() defer t.Stop()
var in []entry
var out []int // port numbers
for { for {
select { select {
case <-t.C: case <-t.C:
nat.establishMapping(m) in = in[:0]
out = out[:0]
nat.mappingmu.Lock()
for e := range nat.mappings {
in = append(in, e)
}
nat.mappingmu.Unlock()
// Establishing the mapping involves network requests.
// Don't hold the mutex, just save the ports.
for _, e := range in {
out = append(out, nat.establishMapping(e.protocol, e.port))
}
nat.mappingmu.Lock()
for i, p := range in {
if _, ok := nat.mappings[p]; !ok {
continue // entry might have been deleted
}
nat.mappings[p] = out[i]
}
nat.mappingmu.Unlock()
t.Reset(tick)
case <-nat.ctx.Done(): case <-nat.ctx.Done():
m.Close() nat.mappingmu.Lock()
for e := range nat.mappings {
delete(nat.mappings, e)
}
nat.mappingmu.Unlock()
return return
} }
} }
} }
func (nat *NAT) establishMapping(m *mapping) { func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) {
oldport := m.ExternalPort() log.Debugf("Attempting port map: %s/%d", protocol, internalPort)
log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort())
const comment = "libp2p" const comment = "libp2p"
nat.natmu.Lock() nat.natmu.Lock()
newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration) var err error
externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration)
if err != nil { if err != nil {
// Some hardware does not support mappings with timeout, so try that // Some hardware does not support mappings with timeout, so try that
newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0) externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0)
} }
nat.natmu.Unlock() nat.natmu.Unlock()
if err != nil || newport == 0 { if err != nil || externalPort == 0 {
m.setExternalPort(0) // clear mapping
// TODO: log.Event // TODO: log.Event
if err != nil { if err != nil {
log.Warnf("failed to establish port mapping: %s", err) log.Warnf("failed to establish port mapping: %s", err)
@@ -180,12 +207,9 @@ func (nat *NAT) establishMapping(m *mapping) {
} }
// we do not close if the mapping failed, // we do not close if the mapping failed,
// because it may work again next time. // because it may work again next time.
return return 0
} }
m.setExternalPort(newport) log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol)
log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol()) return externalPort
if oldport != 0 && newport != oldport {
log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport)
}
} }