diff --git a/gather_test.go b/gather_test.go index 4deadb3..af075b6 100644 --- a/gather_test.go +++ b/gather_test.go @@ -1100,6 +1100,63 @@ func TestGatherCandidatesRelayCallsAddRelayCandidates(t *testing.T) { assert.True(t, locConn.closed) } +func TestGatherCandidatesRelayUsesTurnNet(t *testing.T) { + defer test.CheckRoutines(t)() + + stubClient := &stubTurnClient{} + turnNet := newRelayGatherNet(&net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 50000}) + + agent, err := NewAgentWithOptions( + WithNet(turnNet), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithCandidateTypes([]CandidateType{CandidateTypeRelay}), + WithUrls([]*stun.URI{ + { + Scheme: stun.SchemeTypeTURN, + Host: "example.com", + Port: 3478, + Username: "username", + Password: "password", + Proto: stun.ProtoTypeUDP, + }, + }), + WithMulticastDNSMode(MulticastDNSModeDisabled), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + + stubClient.relayConn = newStubPacketConn(&net.UDPAddr{IP: net.IP{203, 0, 113, 9}, Port: 6000}) + agent.turnClientFactory = func(cfg *turn.ClientConfig) (turnClient, error) { + stubClient.cfgConn = cfg.Conn + + return stubClient, nil + } + + candCh := make(chan Candidate, 1) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c != nil && c.Type() == CandidateTypeRelay { + candCh <- c + } + })) + + agent.gatherCandidatesRelay(context.Background(), agent.urls) + + select { + case cand := <-candCh: + relay, ok := cand.(*CandidateRelay) + require.True(t, ok) + require.Equal(t, turnNet.addr.IP.String(), relay.RelatedAddress().Address) + + addr, ok := stubClient.cfgConn.LocalAddr().(*net.UDPAddr) + require.True(t, ok) + require.Equal(t, turnNet.addr.IP.String(), addr.IP.String()) + case <-time.After(time.Second): + assert.Fail(t, "expected relay candidate using turn network") + } +} + func TestGatherCandidatesRelayDefaultClientError(t *testing.T) { defer test.CheckRoutines(t)() diff --git a/gather_vnet_test.go b/gather_vnet_test.go index 3b201fd..383ec8b 100644 --- a/gather_vnet_test.go +++ b/gather_vnet_test.go @@ -11,11 +11,13 @@ import ( "fmt" "net" "testing" + "time" "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" + "github.com/pion/turn/v4" "github.com/stretchr/testify/require" ) @@ -395,6 +397,126 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) } +func TestGatherRelayWithVNet(t *testing.T) { + defer test.CheckRoutines(t)() + + loggerFactory := logging.NewDefaultLoggerFactory() + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + + clientNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.0.2"}, + }) + require.NoError(t, err) + + serverNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.0.3"}, + }) + require.NoError(t, err) + + require.NoError(t, router.AddNet(clientNet)) + require.NoError(t, router.AddNet(serverNet)) + require.NoError(t, router.Start()) + defer func() { + require.NoError(t, router.Stop()) + }() + + turnAddr := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 3), + Port: 3478, + } + serverConn, err := serverNet.ListenPacket("udp4", turnAddr.String()) + require.NoError(t, err) + + relayGenerator := &turn.RelayAddressGeneratorStatic{ + RelayAddress: turnAddr.IP, + Address: turnAddr.IP.String(), + Net: serverNet, + } + + const ( + turnRealm = "pion.ly" + turnUser = "user" + turnPass = "pass" + ) + + server, err := turn.NewServer(turn.ServerConfig{ + LoggerFactory: loggerFactory, + Realm: turnRealm, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: serverConn, + RelayAddressGenerator: relayGenerator, + }, + }, + AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) { + if username != turnUser { + return nil, false + } + + return turn.GenerateAuthKey(username, realm, turnPass), true + }, + }) + require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() + + agent, err := NewAgentWithOptions( + WithNet(clientNet), + WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + WithCandidateTypes([]CandidateType{CandidateTypeRelay}), + WithUrls([]*stun.URI{ + { + Scheme: stun.SchemeTypeTURN, + Host: turnAddr.IP.String(), + Port: turnAddr.Port, + Username: turnUser, + Password: turnPass, + Proto: stun.ProtoTypeUDP, + }, + }), + WithMulticastDNSMode(MulticastDNSModeDisabled), + ) + require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() + + relayCandidates := make(chan Candidate, 1) + done := make(chan struct{}) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c == nil { + close(done) + + return + } + + if c.Type() == CandidateTypeRelay { + select { + case relayCandidates <- c: + default: + } + } + })) + + require.NoError(t, agent.GatherCandidates()) + + select { + case cand := <-relayCandidates: + require.Equal(t, CandidateTypeRelay, cand.Type()) + require.Equal(t, "10.0.0.3", cand.Address()) + case <-done: + require.Fail(t, "gathering finished without relay candidate") + case <-time.After(5 * time.Second): + require.Fail(t, "timeout waiting for relay candidate") + } +} + func TestVNetGather_TURNConnectionLeak(t *testing.T) { defer test.CheckRoutines(t)()