diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 4bbec10a4..522d3e827 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -30,6 +30,11 @@ func NewNATManager(net network.Network) NATManager { return newNatManager(net) } +type entry struct { + protocol string + port int +} + // natManager takes care of adding + removing port mappings to the nat. // Initialized with the host if it has a NATPortMap option enabled. // natManager receives signals from the network, and check on nat mappings: @@ -42,7 +47,9 @@ type natManager struct { nat *inat.NAT ready chan struct{} // closed once the nat is ready to process port mappings - syncFlag chan struct{} + syncFlag chan struct{} // cap: 1 + + tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function refCount sync.WaitGroup ctxCancel context.CancelFunc @@ -55,6 +62,7 @@ func newNatManager(net network.Network) *natManager { ready: make(chan struct{}), syncFlag: make(chan struct{}, 1), ctxCancel: cancel, + tracked: make(map[entry]bool), } nmgr.refCount.Add(1) go nmgr.background(ctx) @@ -127,10 +135,10 @@ func (nmgr *natManager) sync() { // doSync syncs the current NAT mappings, removing any outdated mappings and adding any // new mappings. func (nmgr *natManager) doSync() { - ports := map[string]map[int]bool{ - "tcp": {}, - "udp": {}, + for e := range nmgr.tracked { + nmgr.tracked[e] = false } + var newAddresses []entry for _, maddr := range nmgr.net.ListenAddresses() { // Strip the IP maIP, rest := ma.SplitFirst(maddr) @@ -166,48 +174,36 @@ func (nmgr *natManager) doSync() { default: continue } - port, err := strconv.ParseUint(proto.Value(), 10, 16) if err != nil { // bug in multiaddr panic(err) } - ports[protocol][int(port)] = false + e := entry{protocol: protocol, port: int(port)} + if _, ok := nmgr.tracked[e]; ok { + nmgr.tracked[e] = true + } else { + newAddresses = append(newAddresses, e) + } } var wg sync.WaitGroup defer wg.Wait() // Close old mappings - for _, m := range nmgr.nat.Mappings() { - mappedPort := m.InternalPort() - if _, ok := ports[m.Protocol()][mappedPort]; !ok { - // No longer need this mapping. - wg.Add(1) - go func(m inat.Mapping) { - defer wg.Done() - m.Close() - }(m) - } else { - // already mapped - ports[m.Protocol()][mappedPort] = true + for e, v := range nmgr.tracked { + if !v { + nmgr.nat.RemoveMapping(e.protocol, e.port) + delete(nmgr.tracked, e) } } // Create new mappings. - for proto, pports := range ports { - for port, mapped := range pports { - if mapped { - continue - } - wg.Add(1) - go func(proto string, port int) { - defer wg.Done() - if err := nmgr.nat.AddMapping(proto, port); err != nil { - log.Errorf("failed to port-map %s port %d: %s", proto, port, err) - } - }(proto, port) + for _, e := range newAddresses { + if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil { + log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err) } + nmgr.tracked[e] = false } } diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go index 0641e8fc4..4ba507e04 100644 --- a/p2p/net/nat/mapping.go +++ b/p2p/net/nat/mapping.go @@ -24,9 +24,6 @@ type Mapping interface { // ExternalAddr returns the external facing address. If the mapping is not // established, addr will be nil, and and ErrNoMapping will be returned. ExternalAddr() (addr net.Addr, err error) - - // Close closes the port mapping - Close() error } // keeps republishing @@ -103,8 +100,3 @@ func (m *mapping) ExternalAddr() (net.Addr, error) { panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) } } - -func (m *mapping) Close() error { - m.nat.removeMapping(m) - return nil -} diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index 201c887f8..ca07082bb 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -24,6 +24,11 @@ const MappingDuration = time.Second * 60 // CacheTime is the time a mapping will cache an external address for const CacheTime = time.Second * 15 +type entry struct { + protocol string + port int +} + // DiscoverNAT looks for a NAT device in the network and // returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { @@ -40,7 +45,19 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { log.Debug("DiscoverGateway address:", addr) } - return newNAT(natInstance), nil + ctx, cancel := context.WithCancel(context.Background()) + nat := &NAT{ + nat: natInstance, + mappings: make(map[entry]int), + ctx: ctx, + ctxCancel: cancel, + } + nat.refCount.Add(1) + go func() { + defer nat.refCount.Done() + nat.background() + }() + return nat, nil } // NAT is an object that manages address port mappings in @@ -57,17 +74,7 @@ type NAT struct { mappingmu sync.RWMutex // guards mappings closed bool - mappings map[*mapping]struct{} -} - -func newNAT(realNAT nat.NAT) *NAT { - ctx, cancel := context.WithCancel(context.Background()) - return &NAT{ - nat: realNAT, - mappings: make(map[*mapping]struct{}), - ctx: ctx, - ctxCancel: cancel, - } + mappings map[entry]int } // Close shuts down all port mappings. NAT can no longer be used. @@ -84,94 +91,114 @@ func (nat *NAT) Close() error { // Mappings returns a slice of all NAT mappings func (nat *NAT) Mappings() []Mapping { nat.mappingmu.Lock() + defer nat.mappingmu.Unlock() maps2 := make([]Mapping, 0, len(nat.mappings)) - for m := range nat.mappings { - maps2 = append(maps2, m) + for e, extPort := range nat.mappings { + maps2 = append(maps2, &mapping{ + nat: nat, + proto: e.protocol, + intport: e.port, + extport: extPort, + }) } - nat.mappingmu.Unlock() return maps2 } // AddMapping attempts to construct a mapping on protocol and internal port -// It will also periodically renew the mapping until the returned Mapping -// -- or its parent NAT -- is Closed. +// It will also periodically renew the mapping. // // May not succeed, and mappings may change over time; // NAT devices may not respect our port requests, and even lie. func (nat *NAT) AddMapping(protocol string, port int) error { - if nat == nil { - return fmt.Errorf("no nat available") - } - switch protocol { case "tcp", "udp": default: return fmt.Errorf("invalid protocol: %s", protocol) } - m := &mapping{ - intport: port, - nat: nat, - proto: protocol, - } - nat.mappingmu.Lock() if nat.closed { nat.mappingmu.Unlock() return errors.New("closed") } - nat.mappings[m] = struct{}{} - nat.refCount.Add(1) - nat.mappingmu.Unlock() - go nat.refreshMappings(m) // do it once synchronously, so first mapping is done right away, and before exiting, // allowing users -- in the optimistic case -- to use results right after. - nat.establishMapping(m) + extPort := nat.establishMapping(protocol, port) + nat.mappings[entry{protocol: protocol, port: port}] = extPort + nat.mappingmu.Unlock() + return nil } -func (nat *NAT) removeMapping(m *mapping) { +func (nat *NAT) RemoveMapping(protocol string, port int) error { nat.mappingmu.Lock() - delete(nat.mappings, m) - nat.mappingmu.Unlock() - nat.natmu.Lock() - nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) - nat.natmu.Unlock() + defer nat.mappingmu.Unlock() + switch protocol { + case "tcp", "udp": + delete(nat.mappings, entry{protocol: protocol, port: port}) + default: + return fmt.Errorf("invalid protocol: %s", protocol) + } + return nil } -func (nat *NAT) refreshMappings(m *mapping) { - defer nat.refCount.Done() - t := time.NewTicker(MappingDuration / 3) +func (nat *NAT) background() { + const tick = MappingDuration / 3 + t := time.NewTimer(tick) // don't use a ticker here. We don't know how long establishing the mappings takes. defer t.Stop() + var in []entry + var out []int // port numbers for { select { case <-t.C: - nat.establishMapping(m) + in = in[:0] + out = out[:0] + nat.mappingmu.Lock() + for e := range nat.mappings { + in = append(in, e) + } + nat.mappingmu.Unlock() + // Establishing the mapping involves network requests. + // Don't hold the mutex, just save the ports. + for _, e := range in { + out = append(out, nat.establishMapping(e.protocol, e.port)) + } + nat.mappingmu.Lock() + for i, p := range in { + if _, ok := nat.mappings[p]; !ok { + continue // entry might have been deleted + } + nat.mappings[p] = out[i] + } + nat.mappingmu.Unlock() + t.Reset(tick) case <-nat.ctx.Done(): - m.Close() + nat.mappingmu.Lock() + for e := range nat.mappings { + delete(nat.mappings, e) + } + nat.mappingmu.Unlock() return } } } -func (nat *NAT) establishMapping(m *mapping) { - oldport := m.ExternalPort() - - log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort()) +func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) { + log.Debugf("Attempting port map: %s/%d", protocol, internalPort) const comment = "libp2p" nat.natmu.Lock() - newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration) + var err error + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration) if err != nil { // Some hardware does not support mappings with timeout, so try that - newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0) + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0) } nat.natmu.Unlock() - if err != nil || newport == 0 { - m.setExternalPort(0) // clear mapping + if err != nil || externalPort == 0 { // TODO: log.Event if err != nil { log.Warnf("failed to establish port mapping: %s", err) @@ -180,12 +207,9 @@ func (nat *NAT) establishMapping(m *mapping) { } // we do not close if the mapping failed, // because it may work again next time. - return + return 0 } - m.setExternalPort(newport) - log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol()) - if oldport != 0 && newport != oldport { - log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport) - } + log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol) + return externalPort }