nat: use a single Go routine to renew NAT mappings

This commit is contained in:
Marten Seemann
2023-04-03 22:27:04 +09:00
parent f5cbaf1721
commit c62be4081c
3 changed files with 107 additions and 95 deletions

View File

@@ -30,6 +30,11 @@ func NewNATManager(net network.Network) NATManager {
return newNatManager(net)
}
type entry struct {
protocol string
port int
}
// natManager takes care of adding + removing port mappings to the nat.
// Initialized with the host if it has a NATPortMap option enabled.
// natManager receives signals from the network, and check on nat mappings:
@@ -42,7 +47,9 @@ type natManager struct {
nat *inat.NAT
ready chan struct{} // closed once the nat is ready to process port mappings
syncFlag chan struct{}
syncFlag chan struct{} // cap: 1
tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function
refCount sync.WaitGroup
ctxCancel context.CancelFunc
@@ -55,6 +62,7 @@ func newNatManager(net network.Network) *natManager {
ready: make(chan struct{}),
syncFlag: make(chan struct{}, 1),
ctxCancel: cancel,
tracked: make(map[entry]bool),
}
nmgr.refCount.Add(1)
go nmgr.background(ctx)
@@ -127,10 +135,10 @@ func (nmgr *natManager) sync() {
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
// new mappings.
func (nmgr *natManager) doSync() {
ports := map[string]map[int]bool{
"tcp": {},
"udp": {},
for e := range nmgr.tracked {
nmgr.tracked[e] = false
}
var newAddresses []entry
for _, maddr := range nmgr.net.ListenAddresses() {
// Strip the IP
maIP, rest := ma.SplitFirst(maddr)
@@ -166,48 +174,36 @@ func (nmgr *natManager) doSync() {
default:
continue
}
port, err := strconv.ParseUint(proto.Value(), 10, 16)
if err != nil {
// bug in multiaddr
panic(err)
}
ports[protocol][int(port)] = false
e := entry{protocol: protocol, port: int(port)}
if _, ok := nmgr.tracked[e]; ok {
nmgr.tracked[e] = true
} else {
newAddresses = append(newAddresses, e)
}
}
var wg sync.WaitGroup
defer wg.Wait()
// Close old mappings
for _, m := range nmgr.nat.Mappings() {
mappedPort := m.InternalPort()
if _, ok := ports[m.Protocol()][mappedPort]; !ok {
// No longer need this mapping.
wg.Add(1)
go func(m inat.Mapping) {
defer wg.Done()
m.Close()
}(m)
} else {
// already mapped
ports[m.Protocol()][mappedPort] = true
for e, v := range nmgr.tracked {
if !v {
nmgr.nat.RemoveMapping(e.protocol, e.port)
delete(nmgr.tracked, e)
}
}
// Create new mappings.
for proto, pports := range ports {
for port, mapped := range pports {
if mapped {
continue
}
wg.Add(1)
go func(proto string, port int) {
defer wg.Done()
if err := nmgr.nat.AddMapping(proto, port); err != nil {
log.Errorf("failed to port-map %s port %d: %s", proto, port, err)
}
}(proto, port)
for _, e := range newAddresses {
if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil {
log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
}
nmgr.tracked[e] = false
}
}