Drop reference to detached datachannels

This allows users of detached datachannels to garbage collect
resources associated with the datachannel and the sctp stream.
There is no functional change here.
This commit is contained in:
sukun
2024-03-04 14:13:03 +05:30
committed by Sean DuBois
parent a8c02b0879
commit 835ac3b08e
5 changed files with 102 additions and 18 deletions

View File

@@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error {
// resulting DataChannel object.
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
d.mu.Lock()
defer d.mu.Unlock()
if !d.api.settingEngine.detach.DataChannels {
return nil, errDetachNotEnabled
@@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
d.detachCalled = true
return d.dataChannel, nil
dataChannel := d.dataChannel
d.mu.Unlock()
// Remove the reference from SCTPTransport so that the datachannel
// can be garbage collected on close
d.sctpTransport.lock.Lock()
n := len(d.sctpTransport.dataChannels)
j := 0
for i := 0; i < n; i++ {
if d == d.sctpTransport.dataChannels[i] {
continue
}
d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i]
j++
}
for i := j; i < n; i++ {
d.sctpTransport.dataChannels[i] = nil
}
d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j]
d.sctpTransport.lock.Unlock()
return dataChannel, nil
}
// Close Closes the DataChannel. It may be called regardless of whether

View File

@@ -692,3 +692,57 @@ func TestDataChannel_Dial(t *testing.T) {
closePair(t, offerPC, answerPC, done)
})
}
func TestDetachRemovesDatachannelReference(t *testing.T) {
// Use Detach data channels mode
s := SettingEngine{}
s.DetachDataChannels()
api := NewAPI(WithSettingEngine(s))
// Set up two peer connections.
config := Configuration{}
pca, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}
pcb, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}
defer closePairNow(t, pca, pcb)
dcChan := make(chan *DataChannel, 1)
pcb.OnDataChannel(func(d *DataChannel) {
d.OnOpen(func() {
if _, detachErr := d.Detach(); detachErr != nil {
t.Error(detachErr)
}
dcChan <- d
})
})
if err = signalPair(pca, pcb); err != nil {
t.Fatal(err)
}
attached, err := pca.CreateDataChannel("", nil)
if err != nil {
t.Fatal(err)
}
open := make(chan struct{}, 1)
attached.OnOpen(func() {
open <- struct{}{}
})
<-open
d := <-dcChan
d.sctpTransport.lock.RLock()
defer d.sctpTransport.lock.RUnlock()
for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
if dc == d {
t.Errorf("expected sctpTransport to drop reference to datachannel")
}
}
}

View File

@@ -2018,6 +2018,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn
pc.sctpTransport.lock.Lock()
pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d)
if d.ID() != nil {
pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{}
}
pc.sctpTransport.dataChannelsRequested++
pc.sctpTransport.lock.Unlock()

View File

@@ -52,6 +52,7 @@ type SCTPTransport struct {
// DataChannels
dataChannels []*DataChannel
dataChannelIDsUsed map[uint16]struct{}
dataChannelsOpened uint32
dataChannelsRequested uint32
dataChannelsAccepted uint32
@@ -69,6 +70,7 @@ func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
state: SCTPTransportStateConnecting,
api: api,
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
dataChannelIDsUsed: make(map[uint16]struct{}),
}
res.updateMessageSize()
@@ -287,6 +289,13 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
r.lock.Lock()
r.dataChannels = append(r.dataChannels, dc)
r.dataChannelsAccepted++
if dc.ID() != nil {
r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
} else {
// This cannot happen, the constructor for this datachannel in the caller
// takes a pointer to the id.
r.log.Errorf("accepted data channel with no ID")
}
handler := r.onDataChannelHandler
r.lock.Unlock()
@@ -393,21 +402,12 @@ func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **u
r.lock.Lock()
defer r.lock.Unlock()
// Create map of ids so we can compare without double-looping each time.
idsMap := make(map[uint16]struct{}, len(r.dataChannels))
for _, dc := range r.dataChannels {
if dc.ID() == nil {
continue
}
idsMap[*dc.ID()] = struct{}{}
}
for ; id < max-1; id += 2 {
if _, ok := idsMap[id]; ok {
if _, ok := r.dataChannelIDsUsed[id]; ok {
continue
}
*idOut = &id
r.dataChannelIDsUsed[id] = struct{}{}
return nil
}

View File

@@ -10,11 +10,15 @@ import "testing"
func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
ret := &SCTPTransport{dataChannels: []*DataChannel{}}
ret := &SCTPTransport{
dataChannels: []*DataChannel{},
dataChannelIDsUsed: make(map[uint16]struct{}),
}
for i := range ids {
id := ids[i]
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
ret.dataChannelIDsUsed[id] = struct{}{}
}
return ret
@@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) {
if *idPtr != testCase.result {
t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
}
if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
t.Errorf("expected new id to be added to the map: %d", *idPtr)
}
}
}