Close PeerConnection on DTLS CloseNotify

Resolves #1767
Resolves pion/dtls#151
This commit is contained in:
Sean DuBois
2023-09-04 14:37:19 -04:00
committed by Sean DuBois
parent ea23dec2b9
commit 60eea430ac
6 changed files with 90 additions and 22 deletions

View File

@@ -44,7 +44,8 @@ type DTLSTransport struct {
state DTLSTransportState state DTLSTransportState
srtpProtectionProfile srtp.ProtectionProfile srtpProtectionProfile srtp.ProtectionProfile
onStateChangeHandler func(DTLSTransportState) onStateChangeHandler func(DTLSTransportState)
internalOnCloseHandler func()
conn *dtls.Conn conn *dtls.Conn
@@ -322,6 +323,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
var dtlsConn *dtls.Conn var dtlsConn *dtls.Conn
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS) dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
role, dtlsConfig, err := prepareTransport() role, dtlsConfig, err := prepareTransport()
if err != nil { if err != nil {
return err return err

View File

@@ -17,7 +17,7 @@ import (
// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed // An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed
func TestInvalidFingerprintCausesFailed(t *testing.T) { func TestInvalidFingerprintCausesFailed(t *testing.T) {
lim := test.TimeOut(time.Second * 40) lim := test.TimeOut(time.Second * 5)
defer lim.Stop() defer lim.Stop()
report := test.CheckRoutines(t) report := test.CheckRoutines(t)
@@ -46,8 +46,8 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
} }
}) })
offerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcOffer) offerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcOffer)
answerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcAnswer) answerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcAnswer)
if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil { if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -89,13 +89,17 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
t.Fatal("timed out waiting to receive offer") t.Fatal("timed out waiting to receive offer")
} }
offerConnectionHasFailed.Wait() offerConnectionHasClosed.Wait()
answerConnectionHasFailed.Wait() answerConnectionHasClosed.Wait()
assert.Equal(t, pcOffer.SCTP().Transport().State(), DTLSTransportStateFailed) if pcOffer.SCTP().Transport().State() != DTLSTransportStateClosed && pcOffer.SCTP().Transport().State() != DTLSTransportStateFailed {
t.Fail()
}
assert.Nil(t, pcOffer.SCTP().Transport().conn) assert.Nil(t, pcOffer.SCTP().Transport().conn)
assert.Equal(t, pcAnswer.SCTP().Transport().State(), DTLSTransportStateFailed) if pcAnswer.SCTP().Transport().State() != DTLSTransportStateClosed && pcAnswer.SCTP().Transport().State() != DTLSTransportStateFailed {
t.Fail()
}
assert.Nil(t, pcAnswer.SCTP().Transport().conn) assert.Nil(t, pcAnswer.SCTP().Transport().conn)
} }

View File

@@ -15,14 +15,18 @@ import (
// Endpoint implements net.Conn. It is used to read muxed packets. // Endpoint implements net.Conn. It is used to read muxed packets.
type Endpoint struct { type Endpoint struct {
mux *Mux mux *Mux
buffer *packetio.Buffer buffer *packetio.Buffer
onClose func()
} }
// Close unregisters the endpoint from the Mux // Close unregisters the endpoint from the Mux
func (e *Endpoint) Close() (err error) { func (e *Endpoint) Close() (err error) {
err = e.close() if e.onClose != nil {
if err != nil { e.onClose()
}
if err = e.close(); err != nil {
return err return err
} }
@@ -76,3 +80,9 @@ func (e *Endpoint) SetReadDeadline(time.Time) error {
func (e *Endpoint) SetWriteDeadline(time.Time) error { func (e *Endpoint) SetWriteDeadline(time.Time) error {
return nil return nil
} }
// SetOnClose is a user set callback that
// will be executed when `Close` is called
func (e *Endpoint) SetOnClose(onClose func()) {
e.onClose = onClose
}

View File

@@ -2261,6 +2261,16 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
return return
} }
pc.dtlsTransport.internalOnCloseHandler = func() {
pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
go func() {
if pcClosErr := pc.Close(); pcClosErr != nil {
pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
}
}()
}
// Start the dtls transport // Start the dtls transport
err = pc.dtlsTransport.Start(DTLSParameters{ err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRole, Role: dtlsRole,

View File

@@ -24,6 +24,7 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/pion/sdp/v3" "github.com/pion/sdp/v3"
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
"github.com/pion/transport/v3/vnet"
"github.com/pion/webrtc/v3/pkg/media" "github.com/pion/webrtc/v3/pkg/media"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -329,10 +330,15 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
m := &MediaEngine{} m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs()) assert.NoError(t, m.RegisterDefaultCodecs())
pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(s), WithMediaEngine(m)).newPair(Configuration{}) pcOffer, pcAnswer, wan := createVNetPair(t)
if err != nil {
t.Fatal(err) keepPackets := &atomicBool{}
} keepPackets.set(true)
// Add a filter that monitors the traffic on the router
wan.AddChunkFilter(func(c vnet.Chunk) bool {
return keepPackets.get()
})
vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2") vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2")
if err != nil { if err != nil {
@@ -360,14 +366,11 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
if pcCloseErr := pcAnswer.Close(); pcCloseErr != nil { keepPackets.set(false)
haveDisconnected <- pcCloseErr
}
} }
}) })
err = signalPair(pcOffer, pcAnswer) if err = signalPair(pcOffer, pcAnswer); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -383,7 +386,8 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
} }
} }
assert.NoError(t, pcOffer.Close()) assert.NoError(t, wan.Stop())
closePairNow(t, pcOffer, pcAnswer)
} }
type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} } type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} }

View File

@@ -754,3 +754,41 @@ func TestTransportChain(t *testing.T) {
closePairNow(t, offer, answer) closePairNow(t, offer, answer)
} }
// Assert that the PeerConnection closes via DTLS (and not ICE)
func TestDTLSClose(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
report := test.CheckRoutines(t)
defer report()
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)
_, err = pcOffer.AddTransceiverFromKind(RTPCodecTypeVideo)
assert.NoError(t, err)
peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
offer, err := pcOffer.CreateOffer(nil)
assert.NoError(t, err)
offerGatheringComplete := GatheringCompletePromise(pcOffer)
assert.NoError(t, pcOffer.SetLocalDescription(offer))
<-offerGatheringComplete
assert.NoError(t, pcAnswer.SetRemoteDescription(*pcOffer.LocalDescription()))
answer, err := pcAnswer.CreateAnswer(nil)
assert.NoError(t, err)
answerGatheringComplete := GatheringCompletePromise(pcAnswer)
assert.NoError(t, pcAnswer.SetLocalDescription(answer))
<-answerGatheringComplete
assert.NoError(t, pcOffer.SetRemoteDescription(*pcAnswer.LocalDescription()))
peerConnectionsConnected.Wait()
assert.NoError(t, pcOffer.Close())
}