Update UDPMux to dispatch inbound packets on ufrag

If we get a packet for an address we don't know dispatch it
by ufrag still.

Resolves #357
This commit is contained in:
Sean DuBois
2021-04-20 22:15:01 -07:00
parent af8539d47e
commit e6e49f59b0
3 changed files with 61 additions and 18 deletions

View File

@@ -5,9 +5,11 @@ import (
"io"
"net"
"os"
"strings"
"sync"
"github.com/pion/logging"
"github.com/pion/stun"
)
// UDPMux allows multiple connections to go over a single UDP port
@@ -27,8 +29,8 @@ type UDPMuxDefault struct {
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
// map of udpAddr -> udpMuxedConn
addressMap sync.Map
addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
@@ -47,6 +49,7 @@ type UDPMuxParams struct {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
@@ -109,10 +112,13 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
m.addressMap.Delete(addr)
delete(m.addressMap, addr)
}
}
}
@@ -154,9 +160,12 @@ func (m *UDPMuxDefault) removeConn(key string) {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
addresses := c.getAddresses()
for _, addr := range addresses {
m.addressMap.Delete(addr)
delete(m.addressMap, addr)
}
}
@@ -168,11 +177,16 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
if m.IsClosed() {
return
}
existing, ok := m.addressMap.Load(addr)
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if ok {
existing.(*udpMuxedConn).removeAddress(addr)
existing.removeAddress(addr)
}
m.addressMap.Store(addr, conn)
m.addressMap[addr] = conn
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
@@ -192,6 +206,7 @@ func (m *UDPMuxDefault) connWorker() {
defer func() {
_ = m.Close()
}()
buf := make([]byte, receiveMTU)
for {
n, addr, err := m.params.UDPConn.ReadFrom(buf)
@@ -207,21 +222,46 @@ func (m *UDPMuxDefault) connWorker() {
return
}
// look up forward destination
v, ok := m.addressMap.Load(addr.String())
if !ok {
// ignore packets that we don't know where to route to
continue
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
logger.Errorf("underlying PacketConn did not return a UDPAddr")
return
}
c := v.(*udpMuxedConn)
err = c.writePacket(buf[:n], udpAddr)
if err != nil {
// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
m.addressMapMu.Unlock()
// If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(buf[:n]) {
msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...),
}
if err = msg.Decode(); err != nil {
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err)
continue
}
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
m.params.Logger.Warnf("No Username attribute in STUN message from %s\n", addr.String())
continue
}
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn = m.conns[ufrag]
m.mu.Unlock()
}
if destinationConn == nil {
continue
}
if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
logger.Errorf("could not write packet: %v", err)
}
}