mirror of
https://github.com/pion/webrtc.git
synced 2025-09-26 19:21:12 +08:00
Drop reference to detached datachannels
This allows users of detached datachannels to garbage collect resources associated with the datachannel and the sctp stream. There is no functional change here.
This commit is contained in:
@@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error {
|
||||
// resulting DataChannel object.
|
||||
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if !d.api.settingEngine.detach.DataChannels {
|
||||
return nil, errDetachNotEnabled
|
||||
@@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
|
||||
|
||||
d.detachCalled = true
|
||||
|
||||
return d.dataChannel, nil
|
||||
dataChannel := d.dataChannel
|
||||
d.mu.Unlock()
|
||||
|
||||
// Remove the reference from SCTPTransport so that the datachannel
|
||||
// can be garbage collected on close
|
||||
d.sctpTransport.lock.Lock()
|
||||
n := len(d.sctpTransport.dataChannels)
|
||||
j := 0
|
||||
for i := 0; i < n; i++ {
|
||||
if d == d.sctpTransport.dataChannels[i] {
|
||||
continue
|
||||
}
|
||||
d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i]
|
||||
j++
|
||||
}
|
||||
for i := j; i < n; i++ {
|
||||
d.sctpTransport.dataChannels[i] = nil
|
||||
}
|
||||
d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j]
|
||||
d.sctpTransport.lock.Unlock()
|
||||
|
||||
return dataChannel, nil
|
||||
}
|
||||
|
||||
// Close Closes the DataChannel. It may be called regardless of whether
|
||||
|
@@ -692,3 +692,57 @@ func TestDataChannel_Dial(t *testing.T) {
|
||||
closePair(t, offerPC, answerPC, done)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetachRemovesDatachannelReference(t *testing.T) {
|
||||
// Use Detach data channels mode
|
||||
s := SettingEngine{}
|
||||
s.DetachDataChannels()
|
||||
api := NewAPI(WithSettingEngine(s))
|
||||
|
||||
// Set up two peer connections.
|
||||
config := Configuration{}
|
||||
pca, err := api.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcb, err := api.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer closePairNow(t, pca, pcb)
|
||||
|
||||
dcChan := make(chan *DataChannel, 1)
|
||||
pcb.OnDataChannel(func(d *DataChannel) {
|
||||
d.OnOpen(func() {
|
||||
if _, detachErr := d.Detach(); detachErr != nil {
|
||||
t.Error(detachErr)
|
||||
}
|
||||
|
||||
dcChan <- d
|
||||
})
|
||||
})
|
||||
|
||||
if err = signalPair(pca, pcb); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
attached, err := pca.CreateDataChannel("", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
open := make(chan struct{}, 1)
|
||||
attached.OnOpen(func() {
|
||||
open <- struct{}{}
|
||||
})
|
||||
<-open
|
||||
|
||||
d := <-dcChan
|
||||
d.sctpTransport.lock.RLock()
|
||||
defer d.sctpTransport.lock.RUnlock()
|
||||
for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
|
||||
if dc == d {
|
||||
t.Errorf("expected sctpTransport to drop reference to datachannel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -2018,6 +2018,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn
|
||||
|
||||
pc.sctpTransport.lock.Lock()
|
||||
pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d)
|
||||
if d.ID() != nil {
|
||||
pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{}
|
||||
}
|
||||
pc.sctpTransport.dataChannelsRequested++
|
||||
pc.sctpTransport.lock.Unlock()
|
||||
|
||||
|
@@ -52,6 +52,7 @@ type SCTPTransport struct {
|
||||
|
||||
// DataChannels
|
||||
dataChannels []*DataChannel
|
||||
dataChannelIDsUsed map[uint16]struct{}
|
||||
dataChannelsOpened uint32
|
||||
dataChannelsRequested uint32
|
||||
dataChannelsAccepted uint32
|
||||
@@ -65,10 +66,11 @@ type SCTPTransport struct {
|
||||
// meant to be used together with the basic WebRTC API.
|
||||
func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
|
||||
res := &SCTPTransport{
|
||||
dtlsTransport: dtls,
|
||||
state: SCTPTransportStateConnecting,
|
||||
api: api,
|
||||
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
|
||||
dtlsTransport: dtls,
|
||||
state: SCTPTransportStateConnecting,
|
||||
api: api,
|
||||
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
|
||||
dataChannelIDsUsed: make(map[uint16]struct{}),
|
||||
}
|
||||
|
||||
res.updateMessageSize()
|
||||
@@ -287,6 +289,13 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
|
||||
r.lock.Lock()
|
||||
r.dataChannels = append(r.dataChannels, dc)
|
||||
r.dataChannelsAccepted++
|
||||
if dc.ID() != nil {
|
||||
r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
|
||||
} else {
|
||||
// This cannot happen, the constructor for this datachannel in the caller
|
||||
// takes a pointer to the id.
|
||||
r.log.Errorf("accepted data channel with no ID")
|
||||
}
|
||||
handler := r.onDataChannelHandler
|
||||
r.lock.Unlock()
|
||||
|
||||
@@ -393,21 +402,12 @@ func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **u
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
||||
// Create map of ids so we can compare without double-looping each time.
|
||||
idsMap := make(map[uint16]struct{}, len(r.dataChannels))
|
||||
for _, dc := range r.dataChannels {
|
||||
if dc.ID() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idsMap[*dc.ID()] = struct{}{}
|
||||
}
|
||||
|
||||
for ; id < max-1; id += 2 {
|
||||
if _, ok := idsMap[id]; ok {
|
||||
if _, ok := r.dataChannelIDsUsed[id]; ok {
|
||||
continue
|
||||
}
|
||||
*idOut = &id
|
||||
r.dataChannelIDsUsed[id] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -10,11 +10,15 @@ import "testing"
|
||||
|
||||
func TestGenerateDataChannelID(t *testing.T) {
|
||||
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
|
||||
ret := &SCTPTransport{dataChannels: []*DataChannel{}}
|
||||
ret := &SCTPTransport{
|
||||
dataChannels: []*DataChannel{},
|
||||
dataChannelIDsUsed: make(map[uint16]struct{}),
|
||||
}
|
||||
|
||||
for i := range ids {
|
||||
id := ids[i]
|
||||
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
|
||||
ret.dataChannelIDsUsed[id] = struct{}{}
|
||||
}
|
||||
|
||||
return ret
|
||||
@@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) {
|
||||
if *idPtr != testCase.result {
|
||||
t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
|
||||
}
|
||||
if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
|
||||
t.Errorf("expected new id to be added to the map: %d", *idPtr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user