From 2b85befb35dfaa924783a317b2f008cd7b7a97b1 Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Sun, 14 Dec 2025 15:20:38 +0200 Subject: [PATCH] Wire and test renomination --- icegatherer.go | 27 ++ icegatherer_test.go | 751 ++++++++++++++++++++++++++++++++++++++++++ settingengine.go | 66 +++- settingengine_test.go | 51 +++ 4 files changed, 894 insertions(+), 1 deletion(-) diff --git a/icegatherer.go b/icegatherer.go index b459c714..236d9c7f 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "sync/atomic" + "time" "github.com/pion/ice/v4" "github.com/pion/logging" @@ -105,6 +106,7 @@ func (g *ICEGatherer) buildAgentOptions() []ice.AgentOption { options = append(options, g.natRewriteOptions(nat1To1CandiTyp)...) options = append(options, g.timeoutOptions()...) options = append(options, g.miscOptions()...) + options = append(options, g.renominationOptions()...) requestedNetworkTypes := g.api.settingEngine.candidates.ICENetworkTypes if len(requestedNetworkTypes) == 0 { @@ -247,6 +249,31 @@ func (g *ICEGatherer) miscOptions() []ice.AgentOption { return opts } +func (g *ICEGatherer) renominationOptions() []ice.AgentOption { + renom := g.api.settingEngine.renomination + if !renom.enabled && !renom.automatic { + return nil + } + + generator := renom.generator + opts := []ice.AgentOption{ + ice.WithRenomination(func() uint32 { + return generator() + }), + } + + if renom.automatic { + interval := time.Duration(0) + if renom.automaticInterval != nil { + interval = *renom.automaticInterval + } + + opts = append(opts, ice.WithAutomaticRenomination(interval)) + } + + return opts +} + func legacyNAT1To1AddressRewriteRules(ips []string, candidateType ice.CandidateType) []ice.AddressRewriteRule { catchAll := make([]string, 0, len(ips)) rules := make([]ice.AddressRewriteRule, 0, len(ips)+1) diff --git a/icegatherer_test.go b/icegatherer_test.go index 2d7b7b29..2d8f987d 100644 --- a/icegatherer_test.go +++ b/icegatherer_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1061,3 +1062,753 @@ func TestNewICEGathererSetMediaStreamIdentification(t *testing.T) { //nolint:cyc assert.NoError(t, gatherer.Close()) } + +func TestICEGatherer_RenominationOptions(t *testing.T) { + se := SettingEngine{} + assert.NoError(t, se.SetICERenomination()) + assert.True(t, se.renomination.enabled) + assert.True(t, se.renomination.automatic) + assert.Nil(t, se.renomination.automaticInterval) + assert.NotNil(t, se.renomination.generator) +} + +func TestICEGatherer_RenominationOptionsDisabled(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + offerPC, answerPC, cleanup := buildRenominationVNetPair(t, false, false, nil) + defer cleanup() + + connectAndWaitForICE(t, offerPC, answerPC) + + agent := getAgent(t, offerPC) + + selectedPair, err := agent.GetSelectedCandidatePair() + assert.NoError(t, err) + assert.NotNil(t, selectedPair) + + err = agent.RenominateCandidate(selectedPair.Local, selectedPair.Remote) + assert.Error(t, err) + assert.ErrorIs(t, err, ice.ErrRenominationNotEnabled) +} + +func TestICEGatherer_RenominationSendsNomination(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 35) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + nominationCh := make(chan uint32, 2) + handler := func(m *stun.Message, _, _ ice.Candidate, _ *ice.CandidatePair) bool { + var attr ice.NominationAttribute + if err := attr.GetFrom(m); err == nil { + select { + case nominationCh <- attr.Value: + default: + } + } + + return false + } + + offerPC, answerPC, offerSender, answerSender, cleanup := buildStagedRenominationPair(t, handler) + defer cleanup() + + recvCh := make(chan string, 4) + negotiated := true + id := uint16(0) + offerDC, err := offerPC.CreateDataChannel("renomination-dc", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + assert.NoError(t, err) + answerDC, err := answerPC.CreateDataChannel("renomination-dc", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + assert.NoError(t, err) + answerDC.OnMessage(func(msg DataChannelMessage) { + select { + case recvCh <- string(msg.Data): + default: + } + }) + + connected := make(chan struct{}) + var once sync.Once + offerPC.OnICEConnectionStateChange(func(state ICEConnectionState) { + if state == ICEConnectionStateConnected { + once.Do(func() { + close(connected) + }) + } + }) + + startTrickleRenomination(t, offerPC, answerPC, offerSender, answerSender) + assert.NoError(t, offerSender.errValue()) + assert.NoError(t, answerSender.errValue()) + + select { + case <-connected: + case <-time.After(15 * time.Second): + assert.Fail(t, "timed out waiting for ICE to connect") + } + + pair := selectedCandidatePair(t, offerPC) + assert.NotNil(t, pair) + if pair.Remote.Type() != ice.CandidateTypeServerReflexive { + t.Logf("initial remote candidate type %s (expected srflx), continuing", pair.Remote.Type()) + } + initialStat, initialStatOK := getAgent(t, offerPC).GetSelectedCandidatePairStats() + assert.True(t, initialStatOK) + assert.NoError(t, offerSender.flushHost()) + assert.NoError(t, answerSender.flushHost()) + + waitDataChannelOpen(t, offerDC) + waitDataChannelOpen(t, answerDC) + sendAndExpect(t, offerDC, recvCh, "before-renom") + + waitForTwoRemoteCandidates(t, offerPC) + waitForTwoRemoteCandidates(t, answerPC) + + var switchLocal ice.Candidate + var switchRemote ice.Candidate + agent := getAgent(t, offerPC) + assert.Eventuallyf(t, func() bool { + switchLocal, switchRemote = findSwitchTarget(t, offerPC, initialStat.RemoteCandidateID) + + return switchLocal != nil && switchRemote != nil + }, 10*time.Second, 50*time.Millisecond, "no alternate succeeded pair found; pairs: %s", candidatePairSummary(t, agent)) + assert.NoError(t, agent.RenominateCandidate(switchLocal, switchRemote)) + + sendAndExpect(t, offerDC, recvCh, "after-renom") + + select { + case v := <-nominationCh: + assert.Greater(t, v, uint32(0)) + case <-time.After(20 * time.Second): + assert.Fail(t, "did not observe nomination attribute on binding request") + } +} + +func TestICEGatherer_RenominationSwitchesPair(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 45) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + offerPC, answerPC, offerSender, answerSender, cleanup := buildStagedRenominationPair(t, nil) + defer cleanup() + + recvCh := make(chan string, 4) + negotiated := true + id := uint16(0) + offerDC, err := offerPC.CreateDataChannel("renomination-dc", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + assert.NoError(t, err) + answerDC, err := answerPC.CreateDataChannel("renomination-dc", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + assert.NoError(t, err) + answerDC.OnMessage(func(msg DataChannelMessage) { + select { + case recvCh <- string(msg.Data): + default: + } + }) + + connected := make(chan struct{}) + offerPC.OnICEConnectionStateChange(func(state ICEConnectionState) { + if state == ICEConnectionStateConnected { + select { + case <-connected: + default: + close(connected) + } + } + }) + + var flushHostOnce sync.Once + flushHosts := func() { + flushHostOnce.Do(func() { + assert.NoError(t, offerSender.flushHost()) + assert.NoError(t, answerSender.flushHost()) + }) + } + + startTrickleRenomination(t, offerPC, answerPC, offerSender, answerSender) + assert.NoError(t, offerSender.errValue()) + assert.NoError(t, answerSender.errValue()) + + // Fallback: release host candidates even if the initial selection check stalls. + go func() { + time.Sleep(time.Second) + flushHosts() + }() + + select { + case <-connected: + case <-time.After(15 * time.Second): + agent := getAgent(t, offerPC) + assert.Fail(t, "timed out waiting for initial connection; pairs: %s", candidatePairSummary(t, agent)) + } + + var initialRemoteType ice.CandidateType + if !assert.Eventuallyf( + t, func() bool { + if pair := selectedCandidatePair(t, offerPC); pair == nil { + return false + } else { + initialRemoteType = pair.Remote.Type() + + return initialRemoteType == ice.CandidateTypeServerReflexive || + initialRemoteType == ice.CandidateTypePeerReflexive + } + }, + 12*time.Second, 30*time.Millisecond, + "expected to start on a srflx/prflx remote candidate (got %s)", initialRemoteType, + ) { + flushHosts() + assert.Fail(t, "expected to start on a srflx/prflx remote candidate") + } + + flushHosts() + + waitDataChannelOpen(t, offerDC) + waitDataChannelOpen(t, answerDC) + sendAndExpect(t, offerDC, recvCh, "before-switch") + + initialPair := selectedCandidatePair(t, offerPC) + initialStat, initialStatOK := getAgent(t, offerPC).GetSelectedCandidatePairStats() + t.Logf("initial selected pair: %s<->%s (%s/%s)", + initialPair.Local.Address(), initialPair.Remote.Address(), initialPair.Local.Type(), initialPair.Remote.Type()) + + waitForTwoRemoteCandidates(t, offerPC) + waitForTwoRemoteCandidates(t, answerPC) + + assert.True(t, initialStatOK, "missing initial selected pair stats") + + switchLocal, switchRemote := findSwitchTarget(t, offerPC, initialStat.RemoteCandidateID) + assert.NotNil(t, switchLocal) + assert.NotNil(t, switchRemote) + assert.NotNil(t, switchLocal.Type()) + assert.NotNil(t, switchRemote.Type()) + assert.False(t, switchLocal.Equal(switchRemote), "switch local and remote candidates should be different") + + t.Logf( + "renomination target: %s/%s -> %s/%s", + switchLocal.Address(), switchLocal.Type(), switchRemote.Address(), switchRemote.Type(), + ) + + agent := getAgent(t, offerPC) + if !assert.Eventually(t, func() bool { + pair := selectedCandidatePair(t, offerPC) + if pair != nil && pair.Local.Equal(switchLocal) && pair.Remote.Equal(switchRemote) { + return true + } + + if err := agent.RenominateCandidate(switchLocal, switchRemote); err != nil { + t.Logf("renomination attempt: %v", err) + } + + return false + }, 10*time.Second, 50*time.Millisecond, "selected pair should change after renomination") { + assert.Fail(t, "selected pair did not switch; pairs: %s", candidatePairSummary(t, agent)) + } + + finalStat, ok := agent.GetSelectedCandidatePairStats() + assert.True(t, ok) + assert.NotEqual( + t, initialStat.RemoteCandidateID, finalStat.RemoteCandidateID, "selected pair should change after renomination", + ) + + finalLocal := findCandidateByID(t, agent, finalStat.LocalCandidateID, true) + finalRemote := findCandidateByID(t, agent, finalStat.RemoteCandidateID, false) + assert.NotNil(t, finalLocal) + assert.NotNil(t, finalRemote) + assert.Equal(t, ice.CandidateTypeHost, finalLocal.Type()) + assert.NotEqual(t, ice.CandidateTypeServerReflexive, finalRemote.Type()) + + finalPair := selectedCandidatePair(t, offerPC) + assert.NotNil(t, finalPair) + sendAndExpect(t, offerDC, recvCh, "after-switch") + assert.False(t, initialPair.Remote.Equal(finalPair.Remote), "expected remote candidate to change after renomination") +} + +func buildRenominationVNetPair( + t *testing.T, + enableRenomination bool, + automatic bool, + bindingHandler func(*stun.Message, ice.Candidate, ice.Candidate, *ice.CandidatePair) bool, +) (*PeerConnection, *PeerConnection, func()) { + t.Helper() + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "1.2.3.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + netStack, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"1.2.3.4"}, + }) + assert.NoError(t, err) + assert.NoError(t, router.AddNet(netStack)) + + answerNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"1.2.3.5"}, + }) + assert.NoError(t, err) + assert.NoError(t, router.AddNet(answerNet)) + + assert.NoError(t, router.Start()) + + offerSE := SettingEngine{} + offerSE.SetNet(netStack) + offerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + offerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + if enableRenomination { + assert.NoError(t, offerSE.SetICERenomination()) + if automatic { + assert.NoError(t, offerSE.SetICERenomination()) + } + } + + answerSE := SettingEngine{} + answerSE.SetNet(answerNet) + answerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + answerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + if enableRenomination { + assert.NoError(t, answerSE.SetICERenomination()) + if automatic { + assert.NoError(t, answerSE.SetICERenomination()) + } + } + if bindingHandler != nil { + answerSE.SetICEBindingRequestHandler(bindingHandler) + } + + offerPC, err := NewAPI(WithSettingEngine(offerSE)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answerPC, err := NewAPI(WithSettingEngine(answerSE)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + cleanup := func() { + closePairNow(t, offerPC, answerPC) + assert.NoError(t, router.Stop()) + } + + return offerPC, answerPC, cleanup +} + +func connectAndWaitForICE(t *testing.T, offerPC, answerPC *PeerConnection) { + t.Helper() + + connected := make(chan struct{}) + var once sync.Once + offerPC.OnICEConnectionStateChange(func(state ICEConnectionState) { + if state == ICEConnectionStateConnected { + once.Do(func() { + close(connected) + }) + } + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + select { + case <-connected: + case <-time.After(5 * time.Second): + assert.Fail(t, "timed out waiting for ICE to connect") + } +} + +func selectedCandidatePair(t *testing.T, pc *PeerConnection) *ice.CandidatePair { + t.Helper() + + agent := getAgent(t, pc) + + pair, err := agent.GetSelectedCandidatePair() + assert.NoError(t, err) + + return pair +} + +func waitForTwoRemoteCandidates(t *testing.T, pc *PeerConnection) { + t.Helper() + + assert.Eventually(t, func() bool { + agent := getAgent(t, pc) + + remotes, err := agent.GetRemoteCandidates() + assert.NoError(t, err) + + return len(remotes) >= 2 + }, 5*time.Second, 20*time.Millisecond) +} + +func findCandidateByID(t *testing.T, agent *ice.Agent, id string, local bool) ice.Candidate { + t.Helper() + + var cands []ice.Candidate + var err error + if local { + cands, err = agent.GetLocalCandidates() + } else { + cands, err = agent.GetRemoteCandidates() + } + assert.NoError(t, err) + + for _, cand := range cands { + if cand.ID() == id { + return cand + } + } + + return nil +} + +//nolint:cyclop +func findSwitchTarget( + t *testing.T, pc *PeerConnection, excludeRemoteID string, +) (ice.Candidate, ice.Candidate) { + t.Helper() + + agent := getAgent(t, pc) + var targetLocal ice.Candidate + var targetRemote ice.Candidate + + for _, stat := range agent.GetCandidatePairsStats() { + if stat.State != ice.CandidatePairStateSucceeded || + stat.LocalCandidateID == "" || stat.RemoteCandidateID == "" || + stat.RemoteCandidateID == excludeRemoteID { + continue + } + + local := findCandidateByID(t, agent, stat.LocalCandidateID, true) + remote := findCandidateByID(t, agent, stat.RemoteCandidateID, false) + if local == nil || remote == nil { + continue + } + + if local.Type() != ice.CandidateTypeHost { + continue + } + + if remote.Type() == ice.CandidateTypeHost { + return local, remote + } + + if remote.Type() == ice.CandidateTypePeerReflexive { + targetLocal = local + targetRemote = remote + } + } + + return targetLocal, targetRemote +} + +func getAgent(t *testing.T, pc *PeerConnection) *ice.Agent { + t.Helper() + + pc.iceTransport.lock.RLock() + agent := pc.iceTransport.gatherer.getAgent() + pc.iceTransport.lock.RUnlock() + assert.NotNil(t, agent) + + return agent +} + +func candidatePairSummary(t *testing.T, agent *ice.Agent) string { + t.Helper() + + locals, err := agent.GetLocalCandidates() + assert.NoError(t, err) + remotes, err := agent.GetRemoteCandidates() + assert.NoError(t, err) + + localMap := map[string]string{} + for _, cand := range locals { + localMap[cand.ID()] = fmt.Sprintf("%s/%s", cand.Address(), cand.Type()) + } + + remoteMap := map[string]string{} + for _, cand := range remotes { + remoteMap[cand.ID()] = fmt.Sprintf("%s/%s", cand.Address(), cand.Type()) + } + + stats := agent.GetCandidatePairsStats() + summary := make([]string, 0, len(stats)) + for _, stat := range stats { + summary = append(summary, fmt.Sprintf( + "%s<->%s state=%s nominated=%v rtt=%.2fms", + localMap[stat.LocalCandidateID], + remoteMap[stat.RemoteCandidateID], + stat.State, + stat.Nominated, + stat.CurrentRoundTripTime*1000, + )) + } + + return strings.Join(summary, "; ") +} + +func waitDataChannelOpen(t *testing.T, dc *DataChannel) { + t.Helper() + + if dc.ReadyState() == DataChannelStateOpen { + return + } + + done := make(chan struct{}) + dc.OnOpen(func() { + close(done) + }) + + select { + case <-done: + case <-time.After(5 * time.Second): + assert.Fail(t, "data channel did not open") + } +} + +func sendAndExpect(t *testing.T, sender *DataChannel, recvCh chan string, msg string) { + t.Helper() + + err := sender.SendText(msg) + assert.NoError(t, err) + + select { + case got := <-recvCh: + assert.Equal(t, msg, got) + case <-time.After(5 * time.Second): + assert.Fail(t, "did not receive data channel message") + } +} + +type stagedCandidateSender struct { + remote *PeerConnection + mu sync.Mutex + srflx []ICECandidateInit + host []ICECandidateInit + err error +} + +func (s *stagedCandidateSender) addCandidate(cand ICECandidateInit, srflx bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return + } + + if srflx && s.remote.RemoteDescription() != nil { + if err := s.remote.AddICECandidate(cand); err != nil { + s.err = err + } + + return + } + + if srflx { + s.srflx = append(s.srflx, cand) + } else { + s.host = append(s.host, cand) + } +} + +func (s *stagedCandidateSender) flushSrflx() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return s.err + } + + for _, cand := range s.srflx { + if err := s.remote.AddICECandidate(cand); err != nil { + s.err = err + + return err + } + } + + s.srflx = nil + + return s.err +} + +func (s *stagedCandidateSender) flushHost() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return s.err + } + + for _, cand := range s.host { + if err := s.remote.AddICECandidate(cand); err != nil { + s.err = err + + return err + } + } + + s.host = nil + + return s.err +} + +func (s *stagedCandidateSender) errValue() error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.err +} + +func makeSrflxCandidateInit(c ICECandidate) ICECandidateInit { + init := c.ToJSON() + replacement := fmt.Sprintf("typ srflx raddr %s rport %d", c.Address, c.Port) + init.Candidate = strings.Replace(init.Candidate, "typ host", replacement, 1) + + return init +} + +func buildStagedRenominationPair( + t *testing.T, + bindingHandler func(*stun.Message, ice.Candidate, ice.Candidate, *ice.CandidatePair) bool, +) (*PeerConnection, *PeerConnection, *stagedCandidateSender, *stagedCandidateSender, func()) { + t.Helper() + + const ( + primaryOfferIP = "10.0.0.2" + secondaryOfferIP = "10.0.0.4" + primaryAnswerIP = "10.0.0.3" + secondaryAnswerIP = "10.0.0.5" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + offerNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{primaryOfferIP, secondaryOfferIP}, + }) + assert.NoError(t, err) + assert.NoError(t, router.AddNet(offerNet)) + + answerNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{primaryAnswerIP, secondaryAnswerIP}, + }) + assert.NoError(t, err) + assert.NoError(t, router.AddNet(answerNet)) + + assert.NoError(t, router.Start()) + + offerSE := SettingEngine{} + offerSE.SetNet(offerNet) + offerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + offerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + offerSE.SetICETimeouts(5*time.Second, 15*time.Second, 200*time.Millisecond) + // prefer srflx/prflx nomination first so the test reliably observes the switch to host via renomination. + offerSE.SetSrflxAcceptanceMinWait(0) + offerSE.SetHostAcceptanceMinWait(3 * time.Second) + assert.NoError(t, offerSE.SetICERenomination(WithRenominationInterval(200*time.Millisecond))) + + answerSE := SettingEngine{} + answerSE.SetNet(answerNet) + answerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + answerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + answerSE.SetICETimeouts(5*time.Second, 15*time.Second, 200*time.Millisecond) + answerSE.SetSrflxAcceptanceMinWait(0) + answerSE.SetHostAcceptanceMinWait(3 * time.Second) + assert.NoError(t, answerSE.SetICERenomination(WithRenominationInterval(200*time.Millisecond))) + if bindingHandler != nil { + answerSE.SetICEBindingRequestHandler(bindingHandler) + } + + offerPC, err := NewAPI(WithSettingEngine(offerSE)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answerPC, err := NewAPI(WithSettingEngine(answerSE)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + offerSender := &stagedCandidateSender{remote: answerPC} + answerSender := &stagedCandidateSender{remote: offerPC} + + offerPC.OnICECandidate(func(c *ICECandidate) { + if c == nil { + return + } + + switch c.Address { + case primaryOfferIP: + offerSender.addCandidate(makeSrflxCandidateInit(*c), true) + host := *c + host.Priority = 1 + offerSender.addCandidate(host.ToJSON(), false) + case secondaryOfferIP: + host := *c + host.Priority = 1 + offerSender.addCandidate(host.ToJSON(), false) + } + }) + + answerPC.OnICECandidate(func(c *ICECandidate) { + if c == nil { + return + } + + switch c.Address { + case primaryAnswerIP: + answerSender.addCandidate(makeSrflxCandidateInit(*c), true) + host := *c + host.Priority = 1 + answerSender.addCandidate(host.ToJSON(), false) + case secondaryAnswerIP: + host := *c + host.Priority = 1 + answerSender.addCandidate(host.ToJSON(), false) + } + }) + + cleanup := func() { + closePairNow(t, offerPC, answerPC) + assert.NoError(t, router.Stop()) + } + + return offerPC, answerPC, offerSender, answerSender, cleanup +} + +func startTrickleRenomination( + t *testing.T, + offerPC, answerPC *PeerConnection, + offerSender, answerSender *stagedCandidateSender, +) { + t.Helper() + + _, err := offerPC.CreateDataChannel("renomination-data", nil) + assert.NoError(t, err) + + offer, err := offerPC.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, offerPC.SetLocalDescription(offer)) + assert.NoError(t, answerPC.SetRemoteDescription(offer)) + + answer, err := answerPC.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, answerPC.SetLocalDescription(answer)) + assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription())) + + assert.NoError(t, offerSender.flushSrflx()) + assert.NoError(t, answerSender.flushSrflx()) +} diff --git a/settingengine.go b/settingengine.go index 63aa59d1..3111f6d9 100644 --- a/settingengine.go +++ b/settingengine.go @@ -9,6 +9,7 @@ package webrtc import ( "context" "crypto/x509" + "errors" "io" "net" "time" @@ -45,7 +46,8 @@ type SettingEngine struct { ICERelayAcceptanceMinWait *time.Duration ICESTUNGatherTimeout *time.Duration } - candidates struct { + renomination renominationSettings + candidates struct { ICELite bool ICENetworkTypes []NetworkType InterfaceFilter func(string) (keep bool) @@ -114,6 +116,68 @@ type SettingEngine struct { ignoreRidPauseForRecv bool } +type renominationSettings struct { + enabled bool + generator ice.NominationValueGenerator + automatic bool + automaticInterval *time.Duration +} + +// NominationValueGenerator generates nomination values for ICE renomination. +type NominationValueGenerator func() uint32 + +func (f NominationValueGenerator) toIce() ice.NominationValueGenerator { + return ice.NominationValueGenerator(f) +} + +// RenominationOption allows configuring ICE renomination behavior. +type RenominationOption func(*renominationSettings) + +// WithRenominationGenerator overrides the default nomination value generator. +func WithRenominationGenerator(generator NominationValueGenerator) RenominationOption { + return func(cfg *renominationSettings) { + cfg.generator = generator.toIce() + } +} + +// WithRenominationInterval sets the interval for automatic renomination checks. +// Passing zero or a negative duration returns an error from SetICERenomination. +func WithRenominationInterval(interval time.Duration) RenominationOption { + return func(cfg *renominationSettings) { + i := interval + cfg.automaticInterval = &i + } +} + +var errInvalidRenominationInterval = errors.New("renomination interval must be greater than zero") + +// SetICERenomination configures ICE renomination using options for generator and scheduling. +// Manual control is not exposed yet. This always enables automatic renomination with the default +// generator unless a custom one is provided. +func (e *SettingEngine) SetICERenomination(options ...RenominationOption) error { + cfg := e.renomination + for _, opt := range options { + if opt != nil { + opt(&cfg) + } + } + + if cfg.automaticInterval != nil && *cfg.automaticInterval <= 0 { + return errInvalidRenominationInterval + } + + if cfg.generator == nil { + cfg.generator = ice.DefaultNominationValueGenerator() + } + + e.renomination.enabled = true + e.renomination.generator = cfg.generator + e.renomination.automatic = true + e.renomination.automaticInterval = cfg.automaticInterval + + return nil +} + func (e *SettingEngine) getSCTPMaxMessageSize() uint32 { if e.sctp.maxMessageSize != 0 { return e.sctp.maxMessageSize diff --git a/settingengine_test.go b/settingengine_test.go index 8a74f5fe..c8cefc66 100644 --- a/settingengine_test.go +++ b/settingengine_test.go @@ -55,6 +55,57 @@ func TestSetConnectionTimeout(t *testing.T) { assert.Equal(t, *s.timeout.ICEKeepaliveInterval, 3*time.Second) } +func TestICERenomination(t *testing.T) { + t.Run("EnableWithDefaultGenerator", func(t *testing.T) { + s := SettingEngine{} + assert.NoError(t, s.SetICERenomination()) + + assert.True(t, s.renomination.enabled) + assert.NotNil(t, s.renomination.generator) + assert.Equal(t, uint32(1), s.renomination.generator()) + assert.Equal(t, uint32(2), s.renomination.generator()) + }) + + t.Run("AutomaticRenominationUsesExistingGenerator", func(t *testing.T) { + var calls uint32 + settings := SettingEngine{} + customGen := func() uint32 { + calls++ + + return 100 + calls + } + + interval := 2 * time.Second + assert.NoError(t, settings.SetICERenomination( + WithRenominationGenerator(customGen), + WithRenominationInterval(interval), + )) + + assert.True(t, settings.renomination.enabled) + assert.True(t, settings.renomination.automatic) + if assert.NotNil(t, settings.renomination.automaticInterval) { + assert.Equal(t, interval, *settings.renomination.automaticInterval) + } + assert.Equal(t, uint32(101), settings.renomination.generator()) + }) + + t.Run("AutomaticRenominationEnablesGenerator", func(t *testing.T) { + s := SettingEngine{} + assert.NoError(t, s.SetICERenomination()) + + assert.True(t, s.renomination.enabled) + assert.True(t, s.renomination.automatic) + assert.Nil(t, s.renomination.automaticInterval) + assert.NotNil(t, s.renomination.generator) + }) + + t.Run("InvalidInterval", func(t *testing.T) { + s := SettingEngine{} + assert.ErrorIs(t, s.SetICERenomination(WithRenominationInterval(0)), errInvalidRenominationInterval) + assert.ErrorIs(t, s.SetICERenomination(WithRenominationInterval(-1*time.Second)), errInvalidRenominationInterval) + }) +} + func TestDetachDataChannels(t *testing.T) { s := SettingEngine{} assert.False(t, s.detach.DataChannels)