Update UDPMux to dispatch inbound packets on ufrag

If we get a packet for an address we don't know dispatch it
by ufrag still.

Resolves #357
This commit is contained in:
Sean DuBois
2021-04-20 22:15:01 -07:00
parent af8539d47e
commit e6e49f59b0
3 changed files with 61 additions and 18 deletions

View File

@@ -41,7 +41,9 @@ func TestMuxAgent(t *testing.T) {
muxedA, err := NewAgent(&AgentConfig{ muxedA, err := NewAgent(&AgentConfig{
UDPMux: udpMux, UDPMux: udpMux,
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(), NetworkTypes: []NetworkType{
NetworkTypeUDP4,
},
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@@ -5,9 +5,11 @@ import (
"io" "io"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun"
) )
// UDPMux allows multiple connections to go over a single UDP port // UDPMux allows multiple connections to go over a single UDP port
@@ -27,8 +29,8 @@ type UDPMuxDefault struct {
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType // conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn conns map[string]*udpMuxedConn
// map of udpAddr -> udpMuxedConn addressMapMu sync.RWMutex
addressMap sync.Map addressMap map[string]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes // buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool pool *sync.Pool
@@ -47,6 +49,7 @@ type UDPMuxParams struct {
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{ m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params, params: params,
conns: make(map[string]*udpMuxedConn), conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
@@ -109,10 +112,13 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
// keep lock section small to avoid deadlock with conn lock // keep lock section small to avoid deadlock with conn lock
m.mu.Unlock() m.mu.Unlock()
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { for _, addr := range addresses {
m.addressMap.Delete(addr) delete(m.addressMap, addr)
} }
} }
} }
@@ -154,9 +160,12 @@ func (m *UDPMuxDefault) removeConn(key string) {
return return
} }
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { for _, addr := range addresses {
m.addressMap.Delete(addr) delete(m.addressMap, addr)
} }
} }
@@ -168,11 +177,16 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
if m.IsClosed() { if m.IsClosed() {
return return
} }
existing, ok := m.addressMap.Load(addr)
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if ok { if ok {
existing.(*udpMuxedConn).removeAddress(addr) existing.removeAddress(addr)
} }
m.addressMap.Store(addr, conn)
m.addressMap[addr] = conn
} }
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
@@ -192,6 +206,7 @@ func (m *UDPMuxDefault) connWorker() {
defer func() { defer func() {
_ = m.Close() _ = m.Close()
}() }()
buf := make([]byte, receiveMTU) buf := make([]byte, receiveMTU)
for { for {
n, addr, err := m.params.UDPConn.ReadFrom(buf) n, addr, err := m.params.UDPConn.ReadFrom(buf)
@@ -207,21 +222,46 @@ func (m *UDPMuxDefault) connWorker() {
return return
} }
// look up forward destination
v, ok := m.addressMap.Load(addr.String())
if !ok {
// ignore packets that we don't know where to route to
continue
}
udpAddr, ok := addr.(*net.UDPAddr) udpAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
logger.Errorf("underlying PacketConn did not return a UDPAddr") logger.Errorf("underlying PacketConn did not return a UDPAddr")
return return
} }
c := v.(*udpMuxedConn)
err = c.writePacket(buf[:n], udpAddr) // If we have already seen this address dispatch to the appropriate destination
if err != nil { m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
m.addressMapMu.Unlock()
// If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(buf[:n]) {
msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...),
}
if err = msg.Decode(); err != nil {
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err)
continue
}
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
m.params.Logger.Warnf("No Username attribute in STUN message from %s\n", addr.String())
continue
}
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn = m.conns[ufrag]
m.mu.Unlock()
}
if destinationConn == nil {
continue
}
if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
logger.Errorf("could not write packet: %v", err) logger.Errorf("could not write packet: %v", err)
} }
} }

View File

@@ -147,6 +147,7 @@ func (c *udpMuxedConn) addAddress(addr string) {
func (c *udpMuxedConn) removeAddress(addr string) { func (c *udpMuxedConn) removeAddress(addr string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
newAddresses := make([]string, 0, len(c.addresses)) newAddresses := make([]string, 0, len(c.addresses))
for _, a := range c.addresses { for _, a := range c.addresses {
if a != addr { if a != addr {