mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-10-21 23:30:28 +08:00
nat: use a single Go routine to renew NAT mappings
This commit is contained in:
@@ -30,6 +30,11 @@ func NewNATManager(net network.Network) NATManager {
|
|||||||
return newNatManager(net)
|
return newNatManager(net)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
protocol string
|
||||||
|
port int
|
||||||
|
}
|
||||||
|
|
||||||
// natManager takes care of adding + removing port mappings to the nat.
|
// natManager takes care of adding + removing port mappings to the nat.
|
||||||
// Initialized with the host if it has a NATPortMap option enabled.
|
// Initialized with the host if it has a NATPortMap option enabled.
|
||||||
// natManager receives signals from the network, and check on nat mappings:
|
// natManager receives signals from the network, and check on nat mappings:
|
||||||
@@ -42,7 +47,9 @@ type natManager struct {
|
|||||||
nat *inat.NAT
|
nat *inat.NAT
|
||||||
|
|
||||||
ready chan struct{} // closed once the nat is ready to process port mappings
|
ready chan struct{} // closed once the nat is ready to process port mappings
|
||||||
syncFlag chan struct{}
|
syncFlag chan struct{} // cap: 1
|
||||||
|
|
||||||
|
tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function
|
||||||
|
|
||||||
refCount sync.WaitGroup
|
refCount sync.WaitGroup
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
@@ -55,6 +62,7 @@ func newNatManager(net network.Network) *natManager {
|
|||||||
ready: make(chan struct{}),
|
ready: make(chan struct{}),
|
||||||
syncFlag: make(chan struct{}, 1),
|
syncFlag: make(chan struct{}, 1),
|
||||||
ctxCancel: cancel,
|
ctxCancel: cancel,
|
||||||
|
tracked: make(map[entry]bool),
|
||||||
}
|
}
|
||||||
nmgr.refCount.Add(1)
|
nmgr.refCount.Add(1)
|
||||||
go nmgr.background(ctx)
|
go nmgr.background(ctx)
|
||||||
@@ -127,10 +135,10 @@ func (nmgr *natManager) sync() {
|
|||||||
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
|
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
|
||||||
// new mappings.
|
// new mappings.
|
||||||
func (nmgr *natManager) doSync() {
|
func (nmgr *natManager) doSync() {
|
||||||
ports := map[string]map[int]bool{
|
for e := range nmgr.tracked {
|
||||||
"tcp": {},
|
nmgr.tracked[e] = false
|
||||||
"udp": {},
|
|
||||||
}
|
}
|
||||||
|
var newAddresses []entry
|
||||||
for _, maddr := range nmgr.net.ListenAddresses() {
|
for _, maddr := range nmgr.net.ListenAddresses() {
|
||||||
// Strip the IP
|
// Strip the IP
|
||||||
maIP, rest := ma.SplitFirst(maddr)
|
maIP, rest := ma.SplitFirst(maddr)
|
||||||
@@ -166,48 +174,36 @@ func (nmgr *natManager) doSync() {
|
|||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := strconv.ParseUint(proto.Value(), 10, 16)
|
port, err := strconv.ParseUint(proto.Value(), 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// bug in multiaddr
|
// bug in multiaddr
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
ports[protocol][int(port)] = false
|
e := entry{protocol: protocol, port: int(port)}
|
||||||
|
if _, ok := nmgr.tracked[e]; ok {
|
||||||
|
nmgr.tracked[e] = true
|
||||||
|
} else {
|
||||||
|
newAddresses = append(newAddresses, e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer wg.Wait()
|
defer wg.Wait()
|
||||||
|
|
||||||
// Close old mappings
|
// Close old mappings
|
||||||
for _, m := range nmgr.nat.Mappings() {
|
for e, v := range nmgr.tracked {
|
||||||
mappedPort := m.InternalPort()
|
if !v {
|
||||||
if _, ok := ports[m.Protocol()][mappedPort]; !ok {
|
nmgr.nat.RemoveMapping(e.protocol, e.port)
|
||||||
// No longer need this mapping.
|
delete(nmgr.tracked, e)
|
||||||
wg.Add(1)
|
|
||||||
go func(m inat.Mapping) {
|
|
||||||
defer wg.Done()
|
|
||||||
m.Close()
|
|
||||||
}(m)
|
|
||||||
} else {
|
|
||||||
// already mapped
|
|
||||||
ports[m.Protocol()][mappedPort] = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new mappings.
|
// Create new mappings.
|
||||||
for proto, pports := range ports {
|
for _, e := range newAddresses {
|
||||||
for port, mapped := range pports {
|
if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil {
|
||||||
if mapped {
|
log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
|
||||||
continue
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go func(proto string, port int) {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := nmgr.nat.AddMapping(proto, port); err != nil {
|
|
||||||
log.Errorf("failed to port-map %s port %d: %s", proto, port, err)
|
|
||||||
}
|
|
||||||
}(proto, port)
|
|
||||||
}
|
}
|
||||||
|
nmgr.tracked[e] = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -24,9 +24,6 @@ type Mapping interface {
|
|||||||
// ExternalAddr returns the external facing address. If the mapping is not
|
// ExternalAddr returns the external facing address. If the mapping is not
|
||||||
// established, addr will be nil, and and ErrNoMapping will be returned.
|
// established, addr will be nil, and and ErrNoMapping will be returned.
|
||||||
ExternalAddr() (addr net.Addr, err error)
|
ExternalAddr() (addr net.Addr, err error)
|
||||||
|
|
||||||
// Close closes the port mapping
|
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// keeps republishing
|
// keeps republishing
|
||||||
@@ -103,8 +100,3 @@ func (m *mapping) ExternalAddr() (net.Addr, error) {
|
|||||||
panic(fmt.Sprintf("invalid protocol %q", m.Protocol()))
|
panic(fmt.Sprintf("invalid protocol %q", m.Protocol()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mapping) Close() error {
|
|
||||||
m.nat.removeMapping(m)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@@ -24,6 +24,11 @@ const MappingDuration = time.Second * 60
|
|||||||
// CacheTime is the time a mapping will cache an external address for
|
// CacheTime is the time a mapping will cache an external address for
|
||||||
const CacheTime = time.Second * 15
|
const CacheTime = time.Second * 15
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
protocol string
|
||||||
|
port int
|
||||||
|
}
|
||||||
|
|
||||||
// DiscoverNAT looks for a NAT device in the network and
|
// DiscoverNAT looks for a NAT device in the network and
|
||||||
// returns an object that can manage port mappings.
|
// returns an object that can manage port mappings.
|
||||||
func DiscoverNAT(ctx context.Context) (*NAT, error) {
|
func DiscoverNAT(ctx context.Context) (*NAT, error) {
|
||||||
@@ -40,7 +45,19 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) {
|
|||||||
log.Debug("DiscoverGateway address:", addr)
|
log.Debug("DiscoverGateway address:", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newNAT(natInstance), nil
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
nat := &NAT{
|
||||||
|
nat: natInstance,
|
||||||
|
mappings: make(map[entry]int),
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: cancel,
|
||||||
|
}
|
||||||
|
nat.refCount.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer nat.refCount.Done()
|
||||||
|
nat.background()
|
||||||
|
}()
|
||||||
|
return nat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NAT is an object that manages address port mappings in
|
// NAT is an object that manages address port mappings in
|
||||||
@@ -57,17 +74,7 @@ type NAT struct {
|
|||||||
|
|
||||||
mappingmu sync.RWMutex // guards mappings
|
mappingmu sync.RWMutex // guards mappings
|
||||||
closed bool
|
closed bool
|
||||||
mappings map[*mapping]struct{}
|
mappings map[entry]int
|
||||||
}
|
|
||||||
|
|
||||||
func newNAT(realNAT nat.NAT) *NAT {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
return &NAT{
|
|
||||||
nat: realNAT,
|
|
||||||
mappings: make(map[*mapping]struct{}),
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close shuts down all port mappings. NAT can no longer be used.
|
// Close shuts down all port mappings. NAT can no longer be used.
|
||||||
@@ -84,94 +91,114 @@ func (nat *NAT) Close() error {
|
|||||||
// Mappings returns a slice of all NAT mappings
|
// Mappings returns a slice of all NAT mappings
|
||||||
func (nat *NAT) Mappings() []Mapping {
|
func (nat *NAT) Mappings() []Mapping {
|
||||||
nat.mappingmu.Lock()
|
nat.mappingmu.Lock()
|
||||||
|
defer nat.mappingmu.Unlock()
|
||||||
maps2 := make([]Mapping, 0, len(nat.mappings))
|
maps2 := make([]Mapping, 0, len(nat.mappings))
|
||||||
for m := range nat.mappings {
|
for e, extPort := range nat.mappings {
|
||||||
maps2 = append(maps2, m)
|
maps2 = append(maps2, &mapping{
|
||||||
|
nat: nat,
|
||||||
|
proto: e.protocol,
|
||||||
|
intport: e.port,
|
||||||
|
extport: extPort,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
nat.mappingmu.Unlock()
|
|
||||||
return maps2
|
return maps2
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMapping attempts to construct a mapping on protocol and internal port
|
// AddMapping attempts to construct a mapping on protocol and internal port
|
||||||
// It will also periodically renew the mapping until the returned Mapping
|
// It will also periodically renew the mapping.
|
||||||
// -- or its parent NAT -- is Closed.
|
|
||||||
//
|
//
|
||||||
// May not succeed, and mappings may change over time;
|
// May not succeed, and mappings may change over time;
|
||||||
// NAT devices may not respect our port requests, and even lie.
|
// NAT devices may not respect our port requests, and even lie.
|
||||||
func (nat *NAT) AddMapping(protocol string, port int) error {
|
func (nat *NAT) AddMapping(protocol string, port int) error {
|
||||||
if nat == nil {
|
|
||||||
return fmt.Errorf("no nat available")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case "tcp", "udp":
|
case "tcp", "udp":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid protocol: %s", protocol)
|
return fmt.Errorf("invalid protocol: %s", protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &mapping{
|
|
||||||
intport: port,
|
|
||||||
nat: nat,
|
|
||||||
proto: protocol,
|
|
||||||
}
|
|
||||||
|
|
||||||
nat.mappingmu.Lock()
|
nat.mappingmu.Lock()
|
||||||
if nat.closed {
|
if nat.closed {
|
||||||
nat.mappingmu.Unlock()
|
nat.mappingmu.Unlock()
|
||||||
return errors.New("closed")
|
return errors.New("closed")
|
||||||
}
|
}
|
||||||
nat.mappings[m] = struct{}{}
|
|
||||||
nat.refCount.Add(1)
|
|
||||||
nat.mappingmu.Unlock()
|
|
||||||
go nat.refreshMappings(m)
|
|
||||||
|
|
||||||
// do it once synchronously, so first mapping is done right away, and before exiting,
|
// do it once synchronously, so first mapping is done right away, and before exiting,
|
||||||
// allowing users -- in the optimistic case -- to use results right after.
|
// allowing users -- in the optimistic case -- to use results right after.
|
||||||
nat.establishMapping(m)
|
extPort := nat.establishMapping(protocol, port)
|
||||||
|
nat.mappings[entry{protocol: protocol, port: port}] = extPort
|
||||||
|
nat.mappingmu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nat *NAT) removeMapping(m *mapping) {
|
func (nat *NAT) RemoveMapping(protocol string, port int) error {
|
||||||
nat.mappingmu.Lock()
|
nat.mappingmu.Lock()
|
||||||
delete(nat.mappings, m)
|
defer nat.mappingmu.Unlock()
|
||||||
nat.mappingmu.Unlock()
|
switch protocol {
|
||||||
nat.natmu.Lock()
|
case "tcp", "udp":
|
||||||
nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort())
|
delete(nat.mappings, entry{protocol: protocol, port: port})
|
||||||
nat.natmu.Unlock()
|
default:
|
||||||
|
return fmt.Errorf("invalid protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nat *NAT) refreshMappings(m *mapping) {
|
func (nat *NAT) background() {
|
||||||
defer nat.refCount.Done()
|
const tick = MappingDuration / 3
|
||||||
t := time.NewTicker(MappingDuration / 3)
|
t := time.NewTimer(tick) // don't use a ticker here. We don't know how long establishing the mappings takes.
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
|
|
||||||
|
var in []entry
|
||||||
|
var out []int // port numbers
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
nat.establishMapping(m)
|
in = in[:0]
|
||||||
|
out = out[:0]
|
||||||
|
nat.mappingmu.Lock()
|
||||||
|
for e := range nat.mappings {
|
||||||
|
in = append(in, e)
|
||||||
|
}
|
||||||
|
nat.mappingmu.Unlock()
|
||||||
|
// Establishing the mapping involves network requests.
|
||||||
|
// Don't hold the mutex, just save the ports.
|
||||||
|
for _, e := range in {
|
||||||
|
out = append(out, nat.establishMapping(e.protocol, e.port))
|
||||||
|
}
|
||||||
|
nat.mappingmu.Lock()
|
||||||
|
for i, p := range in {
|
||||||
|
if _, ok := nat.mappings[p]; !ok {
|
||||||
|
continue // entry might have been deleted
|
||||||
|
}
|
||||||
|
nat.mappings[p] = out[i]
|
||||||
|
}
|
||||||
|
nat.mappingmu.Unlock()
|
||||||
|
t.Reset(tick)
|
||||||
case <-nat.ctx.Done():
|
case <-nat.ctx.Done():
|
||||||
m.Close()
|
nat.mappingmu.Lock()
|
||||||
|
for e := range nat.mappings {
|
||||||
|
delete(nat.mappings, e)
|
||||||
|
}
|
||||||
|
nat.mappingmu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nat *NAT) establishMapping(m *mapping) {
|
func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) {
|
||||||
oldport := m.ExternalPort()
|
log.Debugf("Attempting port map: %s/%d", protocol, internalPort)
|
||||||
|
|
||||||
log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort())
|
|
||||||
const comment = "libp2p"
|
const comment = "libp2p"
|
||||||
|
|
||||||
nat.natmu.Lock()
|
nat.natmu.Lock()
|
||||||
newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration)
|
var err error
|
||||||
|
externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Some hardware does not support mappings with timeout, so try that
|
// Some hardware does not support mappings with timeout, so try that
|
||||||
newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0)
|
externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0)
|
||||||
}
|
}
|
||||||
nat.natmu.Unlock()
|
nat.natmu.Unlock()
|
||||||
|
|
||||||
if err != nil || newport == 0 {
|
if err != nil || externalPort == 0 {
|
||||||
m.setExternalPort(0) // clear mapping
|
|
||||||
// TODO: log.Event
|
// TODO: log.Event
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to establish port mapping: %s", err)
|
log.Warnf("failed to establish port mapping: %s", err)
|
||||||
@@ -180,12 +207,9 @@ func (nat *NAT) establishMapping(m *mapping) {
|
|||||||
}
|
}
|
||||||
// we do not close if the mapping failed,
|
// we do not close if the mapping failed,
|
||||||
// because it may work again next time.
|
// because it may work again next time.
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
m.setExternalPort(newport)
|
log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol)
|
||||||
log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol())
|
return externalPort
|
||||||
if oldport != 0 && newport != oldport {
|
|
||||||
log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user