Fixes DataChannel panic on sctptransport race

DataChannel.open would panic during which if PeerConnection is closed,
stopping underlying sctpTransport which sets association to nil;

And ensureSCTP() method doesn't guarantee sctpTransport's availability
out of it's own scope.
This commit is contained in:
Markus Tzoe
2021-04-06 11:29:16 +08:00
committed by Sean DuBois
parent 38d35c6ff0
commit 2f77a28dca
2 changed files with 16 additions and 35 deletions

View File

@@ -104,19 +104,22 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele
// open opens the datachannel over the sctp transport
func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
if sctpTransport == nil {
return errSCTPNotEstablished
}
sctpTransport.lock.RLock()
association := sctpTransport.association
sctpTransport.lock.RUnlock()
if association == nil {
return errSCTPNotEstablished
}
d.mu.Lock()
if d.sctpTransport != nil {
// already open
if d.sctpTransport != nil { // already open
d.mu.Unlock()
return nil
}
d.sctpTransport = sctpTransport
if err := d.ensureSCTP(); err != nil {
d.mu.Unlock()
return err
}
var channelType datachannel.ChannelType
var reliabilityParameter uint32
@@ -160,8 +163,7 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
return err
}
}
dc, err := datachannel.Dial(d.sctpTransport.association, *d.id, cfg)
dc, err := datachannel.Dial(association, *d.id, cfg)
if err != nil {
d.mu.Unlock()
return err
@@ -176,19 +178,6 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
return nil
}
func (d *DataChannel) ensureSCTP() error {
if d.sctpTransport == nil {
return errSCTPNotEstablished
}
d.sctpTransport.lock.RLock()
defer d.sctpTransport.lock.RUnlock()
if d.sctpTransport.association == nil {
return errSCTPNotEstablished
}
return nil
}
// Transport returns the SCTPTransport instance the DataChannel is sending over.
func (d *DataChannel) Transport() *SCTPTransport {
d.mu.RLock()

View File

@@ -96,12 +96,13 @@ func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error {
}
r.isStarted = true
if err := r.ensureDTLS(); err != nil {
return err
dtlsTransport := r.Transport()
if dtlsTransport == nil || dtlsTransport.conn == nil {
return errSCTPTransportDTLS
}
sctpAssociation, err := sctp.Client(sctp.Config{
NetConn: r.Transport().conn,
NetConn: dtlsTransport.conn,
LoggerFactory: r.api.settingEngine.LoggerFactory,
})
if err != nil {
@@ -137,15 +138,6 @@ func (r *SCTPTransport) Stop() error {
return nil
}
func (r *SCTPTransport) ensureDTLS() error {
dtlsTransport := r.Transport()
if dtlsTransport == nil || dtlsTransport.conn == nil {
return errSCTPTransportDTLS
}
return nil
}
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
for {
dc, err := datachannel.Accept(a, &datachannel.Config{