Fix detached datachannels handling

https://github.com/pion/webrtc/pull/2696 introduced removing
datachannels from the sctptransport for better garbage collection.

That PR introduced a race condition for data channels created before
connection establishment. When an out of band negotiated data channel,
created before peerconnection establishment is detached, there's a race
between the data channel being removed from `r.dataChannels` and it
being copied in to the existing data channel slice in the
acceptDataChannels goroutine.

This PR fixes this race by copying the slice before any datachannels
could be detached.
This commit is contained in:
sukun
2024-12-23 22:37:51 +00:00
parent fbf79c12f0
commit b82306ab62
2 changed files with 128 additions and 7 deletions

View File

@@ -142,7 +142,7 @@ func (r *SCTPTransport) Start(_ SCTPCapabilities) error {
r.dataChannelsOpened += openedDCCount
r.lock.Unlock()
go r.acceptDataChannels(sctpAssociation)
go r.acceptDataChannels(sctpAssociation, dataChannels)
return nil
}
@@ -163,10 +163,9 @@ func (r *SCTPTransport) Stop() error {
return nil
}
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
r.lock.RLock()
dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
for _, dc := range r.dataChannels {
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association, existingDataChannels []*DataChannel) {
dataChannels := make([]*datachannel.DataChannel, 0, len(existingDataChannels))
for _, dc := range existingDataChannels {
dc.mu.Lock()
isNil := dc.dataChannel == nil
dc.mu.Unlock()
@@ -175,8 +174,6 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
}
dataChannels = append(dataChannels, dc.dataChannel)
}
r.lock.RUnlock()
ACCEPT:
for {
dc, err := datachannel.Accept(a, &datachannel.Config{

View File

@@ -8,6 +8,7 @@ package webrtc
import (
"bytes"
"sync"
"testing"
"time"
@@ -126,3 +127,126 @@ func TestSCTPTransportOnClose(t *testing.T) {
t.Fatal("timed out")
}
}
func TestSCTPTransportOutOfBandNegotiatedDataChannelDetach(t *testing.T) {
const N = 10
done := make(chan struct{}, N)
for i := 0; i < N; i++ {
go func() {
// Use Detach data channels mode
s := SettingEngine{}
s.DetachDataChannels()
api := NewAPI(WithSettingEngine(s))
// Set up two peer connections.
config := Configuration{}
offerPC, err := api.NewPeerConnection(config)
if err != nil {
t.Error(err)
return
}
answerPC, err := api.NewPeerConnection(config)
if err != nil {
t.Error(err)
return
}
defer closePairNow(t, offerPC, answerPC)
defer func() { done <- struct{}{} }()
negotiated := true
id := uint16(0)
readDetach := make(chan struct{})
dc1, err := offerPC.CreateDataChannel("", &DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
t.Error(err)
return
}
dc1.OnOpen(func() {
_, _ = dc1.Detach()
close(readDetach)
})
writeDetach := make(chan struct{})
dc2, err := answerPC.CreateDataChannel("", &DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
t.Error(err)
return
}
dc2.OnOpen(func() {
_, _ = dc2.Detach()
close(writeDetach)
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
connestd := make(chan struct{}, 1)
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
connestd <- struct{}{}
}
})
select {
case <-connestd:
case <-time.After(10 * time.Second):
t.Error("conn establishment timed out")
return
}
<-readDetach
err1 := dc1.dataChannel.SetReadDeadline(time.Now().Add(10 * time.Second))
if err1 != nil {
t.Error(err)
return
}
buf := make([]byte, 10)
n, err1 := dc1.dataChannel.Read(buf)
if err1 != nil {
t.Error(err)
return
}
if string(buf[:n]) != "hello" {
t.Error("invalid read")
}
}()
go func() {
defer wg.Done()
connestd := make(chan struct{}, 1)
answerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
connestd <- struct{}{}
}
})
select {
case <-connestd:
case <-time.After(10 * time.Second):
t.Error("connection establishment timed out")
return
}
<-writeDetach
n, err1 := dc2.dataChannel.Write([]byte("hello"))
if err1 != nil || n != len("hello") {
t.Error(err)
}
}()
err = signalPair(offerPC, answerPC)
require.NoError(t, err)
wg.Wait()
}()
}
for i := 0; i < N; i++ {
select {
case <-done:
case <-time.After(20 * time.Second):
t.Fatal("timed out")
}
}
}