mirror of
https://github.com/pion/ice.git
synced 2025-10-30 02:11:50 +08:00
Single port handling via UDPMux
Allows for ICE to handle connections on a single UDP port
This commit is contained in:
@@ -62,6 +62,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
|
|||||||
* [Assad Obaid](https://github.com/assadobaid)
|
* [Assad Obaid](https://github.com/assadobaid)
|
||||||
* [Antoine Baché](https://github.com/Antonito)
|
* [Antoine Baché](https://github.com/Antonito)
|
||||||
* [Will Forcey](https://github.com/wawesomeNOGUI)
|
* [Will Forcey](https://github.com/wawesomeNOGUI)
|
||||||
|
* [David Zhao](https://github.com/davidzhao)
|
||||||
|
|
||||||
### License
|
### License
|
||||||
MIT License - see [LICENSE](LICENSE) for full text
|
MIT License - see [LICENSE](LICENSE) for full text
|
||||||
|
|||||||
5
agent.go
5
agent.go
@@ -123,6 +123,7 @@ type Agent struct {
|
|||||||
|
|
||||||
net *vnet.Net
|
net *vnet.Net
|
||||||
tcpMux TCPMux
|
tcpMux TCPMux
|
||||||
|
udpMux UDPMux
|
||||||
|
|
||||||
interfaceFilter func(string) bool
|
interfaceFilter func(string) bool
|
||||||
|
|
||||||
@@ -314,6 +315,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
|||||||
if a.tcpMux == nil {
|
if a.tcpMux == nil {
|
||||||
a.tcpMux = newInvalidTCPMux()
|
a.tcpMux = newInvalidTCPMux()
|
||||||
}
|
}
|
||||||
|
a.udpMux = config.UDPMux
|
||||||
|
|
||||||
if a.net == nil {
|
if a.net == nil {
|
||||||
a.net = vnet.NewNet(nil)
|
a.net = vnet.NewNet(nil)
|
||||||
@@ -897,6 +899,9 @@ func (a *Agent) Close() error {
|
|||||||
a.err.Store(ErrClosed)
|
a.err.Store(ErrClosed)
|
||||||
|
|
||||||
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
|
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
|
||||||
|
if a.udpMux != nil {
|
||||||
|
a.udpMux.RemoveConnByUfrag(a.localUfrag)
|
||||||
|
}
|
||||||
|
|
||||||
close(a.done)
|
close(a.done)
|
||||||
|
|
||||||
|
|||||||
@@ -145,6 +145,11 @@ type AgentConfig struct {
|
|||||||
// experimental and the API might change in the future.
|
// experimental and the API might change in the future.
|
||||||
TCPMux TCPMux
|
TCPMux TCPMux
|
||||||
|
|
||||||
|
// UDPMux is used for multiplexing multiple incoming UDP connections on a single port
|
||||||
|
// when this is set, the agent ignores PortMin and PortMax configurations and will
|
||||||
|
// defer to UDPMux for incoming connections
|
||||||
|
UDPMux UDPMux
|
||||||
|
|
||||||
// Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy
|
// Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy
|
||||||
// dial interface in order to support corporate proxies
|
// dial interface in order to support corporate proxies
|
||||||
ProxyDialer proxy.Dialer
|
ProxyDialer proxy.Dialer
|
||||||
|
|||||||
67
agent_udpmux_test.go
Normal file
67
agent_udpmux_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// +build !js
|
||||||
|
|
||||||
|
package ice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/logging"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux
|
||||||
|
func TestMuxAgent(t *testing.T) {
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||||
|
|
||||||
|
udpMux := NewUDPMuxDefault(UDPMuxParams{
|
||||||
|
Logger: loggerFactory.NewLogger("ice"),
|
||||||
|
ReadBufferSize: 20,
|
||||||
|
})
|
||||||
|
muxPort := 7686
|
||||||
|
require.NoError(t, udpMux.Start(muxPort))
|
||||||
|
|
||||||
|
muxedA, err := NewAgent(&AgentConfig{
|
||||||
|
UDPMux: udpMux,
|
||||||
|
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||||
|
NetworkTypes: supportedNetworkTypes(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
a, err := NewAgent(&AgentConfig{
|
||||||
|
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||||
|
NetworkTypes: supportedNetworkTypes(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, muxedConn := connect(a, muxedA)
|
||||||
|
|
||||||
|
pair := muxedA.getSelectedPair()
|
||||||
|
require.NotNil(t, pair)
|
||||||
|
require.Equal(t, muxPort, pair.Local.Port())
|
||||||
|
|
||||||
|
// send a packet to Mux
|
||||||
|
data := []byte("hello world")
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buffer := make([]byte, 1024)
|
||||||
|
n, err := muxedConn.Read(buffer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, data, buffer[:n])
|
||||||
|
|
||||||
|
// send a packet from Mux
|
||||||
|
_, err = muxedConn.Write(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err = conn.Read(buffer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, data, buffer[:n])
|
||||||
|
|
||||||
|
// close it down
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
require.NoError(t, muxedConn.Close())
|
||||||
|
require.NoError(t, udpMux.Close())
|
||||||
|
}
|
||||||
@@ -109,6 +109,9 @@ var (
|
|||||||
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
|
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
|
||||||
ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists")
|
ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists")
|
||||||
|
|
||||||
|
// ErrMuxNotStarted indicates the Mux has not been started prior to use
|
||||||
|
ErrMuxNotStarted = errors.New("mux must be started first")
|
||||||
|
|
||||||
errSendPacket = errors.New("failed to send packet")
|
errSendPacket = errors.New("failed to send packet")
|
||||||
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
|
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
|
||||||
errParseComponent = errors.New("could not parse component")
|
errParseComponent = errors.New("could not parse component")
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
|||||||
switch network {
|
switch network {
|
||||||
case tcp:
|
case tcp:
|
||||||
// Handle ICE TCP passive mode
|
// Handle ICE TCP passive mode
|
||||||
|
|
||||||
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
|
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
|
||||||
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
|
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -178,11 +177,19 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
|||||||
// is there a way to verify that the listen address is even
|
// is there a way to verify that the listen address is even
|
||||||
// accessible from the current interface.
|
// accessible from the current interface.
|
||||||
case udp:
|
case udp:
|
||||||
|
if a.udpMux != nil {
|
||||||
|
conn, err = a.udpMux.GetConnByUfrag(a.localUfrag)
|
||||||
|
if err != nil {
|
||||||
|
a.log.Warnf("could not get udp muxed connection: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Warnf("could not listen %s %s\n", network, ip)
|
a.log.Warnf("could not listen %s %s\n", network, ip)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
port = conn.LocalAddr().(*net.UDPAddr).Port
|
port = conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
}
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -4,6 +4,7 @@ go 1.13
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/uuid v1.2.0
|
github.com/google/uuid v1.2.0
|
||||||
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
github.com/pion/dtls/v2 v2.0.9
|
github.com/pion/dtls/v2 v2.0.9
|
||||||
github.com/pion/logging v0.2.2
|
github.com/pion/logging v0.2.2
|
||||||
github.com/pion/mdns v0.0.5
|
github.com/pion/mdns v0.0.5
|
||||||
@@ -13,4 +14,5 @@ require (
|
|||||||
github.com/pion/turn/v2 v2.0.5
|
github.com/pion/turn/v2 v2.0.5
|
||||||
github.com/stretchr/testify v1.7.0
|
github.com/stretchr/testify v1.7.0
|
||||||
golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c
|
golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -2,6 +2,11 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
|
github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
|
||||||
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/pion/dtls/v2 v2.0.9 h1:7Ow+V++YSZQMYzggI0P9vLJz/hUFcffsfGMfT/Qy+u8=
|
github.com/pion/dtls/v2 v2.0.9 h1:7Ow+V++YSZQMYzggI0P9vLJz/hUFcffsfGMfT/Qy+u8=
|
||||||
github.com/pion/dtls/v2 v2.0.9/go.mod h1:O0Wr7si/Zj5/EBFlDzDd6UtVxx25CE1r7XM7BQKYQho=
|
github.com/pion/dtls/v2 v2.0.9/go.mod h1:O0Wr7si/Zj5/EBFlDzDd6UtVxx25CE1r7XM7BQKYQho=
|
||||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||||
@@ -50,7 +55,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
@@ -23,9 +23,8 @@ type TCPMux interface {
|
|||||||
RemoveConnByUfrag(ufrag string)
|
RemoveConnByUfrag(ufrag string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// invalidTCPMux is an implementation of TCPMux that always returns ErroTCPMuxNotInitialized.
|
// invalidTCPMux is an implementation of TCPMux that always returns ErrTCPMuxNotInitialized.
|
||||||
type invalidTCPMux struct {
|
type invalidTCPMux struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func newInvalidTCPMux() *invalidTCPMux {
|
func newInvalidTCPMux() *invalidTCPMux {
|
||||||
return &invalidTCPMux{}
|
return &invalidTCPMux{}
|
||||||
|
|||||||
251
udp_mux.go
Normal file
251
udp_mux.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
package ice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDPMux allows multiple connections to go over a single UDP port
|
||||||
|
type UDPMux interface {
|
||||||
|
io.Closer
|
||||||
|
GetConnByUfrag(ufrag string) (net.PacketConn, error)
|
||||||
|
RemoveConnByUfrag(ufrag string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPMuxDefault is an implementation of the interface
|
||||||
|
type UDPMuxDefault struct {
|
||||||
|
params UDPMuxParams
|
||||||
|
listenAddr *net.UDPAddr
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
|
||||||
|
mappingChan chan connMap
|
||||||
|
closedChan chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
|
||||||
|
// conns is a map of all udpMuxedConn indexed by ufrag
|
||||||
|
conns map[string]*udpMuxedConn
|
||||||
|
|
||||||
|
// buffer pool to recycle buffers for incoming packets
|
||||||
|
pool *sync.Pool
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type connMap struct {
|
||||||
|
address string
|
||||||
|
conn *udpMuxedConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPMuxParams are parameters for UDPMux.
|
||||||
|
type UDPMuxParams struct {
|
||||||
|
Logger logging.LeveledLogger
|
||||||
|
ReadBufferSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||||
|
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||||
|
return &UDPMuxDefault{
|
||||||
|
params: params,
|
||||||
|
conns: make(map[string]*udpMuxedConn),
|
||||||
|
mappingChan: make(chan connMap, 10),
|
||||||
|
closedChan: make(chan struct{}, 1),
|
||||||
|
pool: &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, receiveMTU)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the mux. Before the UDPMux is usable, it must be started
|
||||||
|
func (m *UDPMuxDefault) Start(port int) error {
|
||||||
|
if m.udpConn != nil {
|
||||||
|
return ErrMultipleStart
|
||||||
|
}
|
||||||
|
m.listenAddr = &net.UDPAddr{
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
uc, err := net.ListenUDP(udp, m.listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.udpConn = uc
|
||||||
|
|
||||||
|
go m.connWorker()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||||
|
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||||
|
return m.listenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnByUfrag returns a PacketConn given the connection's ufrag.
|
||||||
|
// creates the connection if an existing one can't be found
|
||||||
|
func (m *UDPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
|
||||||
|
if m.udpConn == nil {
|
||||||
|
return nil, ErrMuxNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.IsClosed() {
|
||||||
|
return nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
if c, ok := m.conns[ufrag]; ok {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c := m.createMuxedConn()
|
||||||
|
go func() {
|
||||||
|
<-c.CloseChannel()
|
||||||
|
m.RemoveConnByUfrag(ufrag)
|
||||||
|
}()
|
||||||
|
m.conns[ufrag] = c
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||||
|
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||||
|
// get addresses to remove
|
||||||
|
m.mu.Lock()
|
||||||
|
c := m.conns[ufrag]
|
||||||
|
delete(m.conns, ufrag)
|
||||||
|
// keep lock section small to avoid deadlock with conn lock
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addresses := c.getAddresses()
|
||||||
|
|
||||||
|
for _, addr := range addresses {
|
||||||
|
m.mappingChan <- connMap{
|
||||||
|
address: addr,
|
||||||
|
conn: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns true if the mux had been closed
|
||||||
|
func (m *UDPMuxDefault) IsClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-m.closedChan:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the mux, no further connections could be created
|
||||||
|
func (m *UDPMuxDefault) Close() error {
|
||||||
|
var err error
|
||||||
|
m.closeOnce.Do(func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// close udp conn and prevent packets coming in
|
||||||
|
err = m.udpConn.Close()
|
||||||
|
|
||||||
|
for _, c := range m.conns {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
m.conns = make(map[string]*udpMuxedConn)
|
||||||
|
close(m.closedChan)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) {
|
||||||
|
return m.udpConn.WriteTo(buf, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) doneWithBuffer(buf []byte) {
|
||||||
|
//nolint
|
||||||
|
m.pool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||||
|
if m.IsClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mappingChan <- connMap{
|
||||||
|
address: addr,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) createMuxedConn() *udpMuxedConn {
|
||||||
|
c := newUDPMuxedConn(&udpMuxedConnParams{
|
||||||
|
Mux: m,
|
||||||
|
ReadBuffer: m.params.ReadBufferSize,
|
||||||
|
LocalAddr: m.LocalAddr(),
|
||||||
|
Logger: m.params.Logger,
|
||||||
|
})
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) connWorker() {
|
||||||
|
// map of remote addresses -> udpMuxedConn
|
||||||
|
// used to look up incoming packets
|
||||||
|
remoteMap := make(map[string]*udpMuxedConn)
|
||||||
|
logger := m.params.Logger
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = m.Close()
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
buffer := m.pool.Get().([]byte)
|
||||||
|
n, addr, err := m.udpConn.ReadFrom(buffer)
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
logger.Errorf("could not read udp packet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// process any mapping changes
|
||||||
|
m.applyMappingChanges(remoteMap)
|
||||||
|
|
||||||
|
// look up forward destination
|
||||||
|
addrStr := addr.String()
|
||||||
|
c := remoteMap[addrStr]
|
||||||
|
|
||||||
|
if c == nil {
|
||||||
|
//nolint
|
||||||
|
m.pool.Put(buffer)
|
||||||
|
// ignore packets that we don't know where to route to
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.writePacket(muxedPacket{
|
||||||
|
Data: buffer,
|
||||||
|
Size: n,
|
||||||
|
RAddr: addr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("could not write packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) applyMappingChanges(remoteMap map[string]*udpMuxedConn) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case cm := <-m.mappingChan:
|
||||||
|
// deregister previous addresses
|
||||||
|
existingConn := remoteMap[cm.address]
|
||||||
|
if existingConn != nil {
|
||||||
|
existingConn.removeAddress(cm.address)
|
||||||
|
}
|
||||||
|
remoteMap[cm.address] = cm.conn
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
107
udp_mux_test.go
Normal file
107
udp_mux_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
// +build !js
|
||||||
|
|
||||||
|
package ice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/logging"
|
||||||
|
"github.com/pion/stun"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUDPMux(t *testing.T) {
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||||
|
udpMux := NewUDPMuxDefault(UDPMuxParams{
|
||||||
|
Logger: loggerFactory.NewLogger("ice"),
|
||||||
|
ReadBufferSize: 20,
|
||||||
|
})
|
||||||
|
err := udpMux.Start(7686)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = udpMux.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.NotNil(t, udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
||||||
|
require.Equal(t, ":7686", udpMux.LocalAddr().String())
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
testMuxConnection(t, udpMux, "ufrag1")
|
||||||
|
}()
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
testMuxConnection(t, udpMux, "ufrag2")
|
||||||
|
}()
|
||||||
|
|
||||||
|
testMuxConnection(t, udpMux, "ufrag3")
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.NoError(t, udpMux.Close())
|
||||||
|
|
||||||
|
// can't create more connections
|
||||||
|
_, err = udpMux.GetConnByUfrag("failufrag")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string) {
|
||||||
|
pktConn, err := udpMux.GetConnByUfrag(ufrag)
|
||||||
|
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||||
|
defer func() {
|
||||||
|
_ = pktConn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
remoteConn, err := net.DialUDP(udp, nil, udpMux.LocalAddr().(*net.UDPAddr))
|
||||||
|
require.NoError(t, err, "error dialing test udp connection")
|
||||||
|
|
||||||
|
// initial messages are dropped
|
||||||
|
_, err = remoteConn.Write([]byte("dropped bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
// wait for packet to be consumed
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
// write out to establish connection
|
||||||
|
msg := stun.New()
|
||||||
|
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
|
||||||
|
msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag"))
|
||||||
|
msg.Encode()
|
||||||
|
_, err = pktConn.WriteTo(msg.Raw, remoteConn.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// ensure received
|
||||||
|
buf := make([]byte, receiveMTU)
|
||||||
|
n, err := remoteConn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg.Raw, buf[:n])
|
||||||
|
|
||||||
|
// write a bunch of packets from remote to ensure proper receipt
|
||||||
|
dataToSend := [][]byte{
|
||||||
|
[]byte("hello world"),
|
||||||
|
[]byte("test text"),
|
||||||
|
msg.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := make([]byte, receiveMTU)
|
||||||
|
for _, data := range dataToSend {
|
||||||
|
_, err := remoteConn.Write(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, _, err := pktConn.ReadFrom(buffer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, data, buffer[:n])
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
165
udp_muxed_conn.go
Normal file
165
udp_muxed_conn.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package ice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpMuxedConnParams struct {
|
||||||
|
Mux *UDPMuxDefault
|
||||||
|
ReadBuffer int
|
||||||
|
LocalAddr net.Addr
|
||||||
|
Logger logging.LeveledLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxedPacket struct {
|
||||||
|
Data []byte
|
||||||
|
RAddr net.Addr
|
||||||
|
Size int
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||||
|
type udpMuxedConn struct {
|
||||||
|
params *udpMuxedConnParams
|
||||||
|
// remote addresses that we have sent to on this conn
|
||||||
|
addresses []string
|
||||||
|
|
||||||
|
// channel holding incoming packets
|
||||||
|
recvChan chan muxedPacket
|
||||||
|
closedChan chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
|
||||||
|
p := &udpMuxedConn{
|
||||||
|
params: params,
|
||||||
|
recvChan: make(chan muxedPacket, params.ReadBuffer),
|
||||||
|
closedChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
|
||||||
|
pkt, ok := <-c.recvChan
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return 0, nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
if cap(b) < pkt.Size {
|
||||||
|
return 0, pkt.RAddr, io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(b, pkt.Data[:pkt.Size])
|
||||||
|
c.params.Mux.doneWithBuffer(pkt.Data)
|
||||||
|
return pkt.Size, pkt.RAddr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
// each time we write to a new address, we'll register it with the mux
|
||||||
|
addr := raddr.String()
|
||||||
|
if !c.containsAddress(addr) {
|
||||||
|
c.addAddress(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.params.Mux.writeTo(buf, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) LocalAddr() net.Addr {
|
||||||
|
return c.params.LocalAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
|
||||||
|
return c.closedChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) Close() error {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
close(c.closedChan)
|
||||||
|
close(c.recvChan)
|
||||||
|
})
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.addresses = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closedChan:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) getAddresses() []string {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
addresses := make([]string, len(c.addresses))
|
||||||
|
copy(addresses, c.addresses)
|
||||||
|
return addresses
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) addAddress(addr string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.addresses = append(c.addresses, addr)
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// map it on mux
|
||||||
|
c.params.Mux.registerConnForAddress(c, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) removeAddress(addr string) {
|
||||||
|
newAddresses := make([]string, 0, len(c.addresses))
|
||||||
|
for _, a := range c.addresses {
|
||||||
|
if a != addr {
|
||||||
|
newAddresses = append(newAddresses, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.addresses = newAddresses
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) containsAddress(addr string) bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
for _, a := range c.addresses {
|
||||||
|
if addr == a {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) writePacket(pkt muxedPacket) error {
|
||||||
|
select {
|
||||||
|
case c.recvChan <- pkt:
|
||||||
|
return nil
|
||||||
|
case <-c.closedChan:
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user