diff --git a/dtlstransport.go b/dtlstransport.go index ddaac7d9..bc0c7345 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -44,7 +44,8 @@ type DTLSTransport struct { state DTLSTransportState srtpProtectionProfile srtp.ProtectionProfile - onStateChangeHandler func(DTLSTransportState) + onStateChangeHandler func(DTLSTransportState) + internalOnCloseHandler func() conn *dtls.Conn @@ -322,6 +323,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { var dtlsConn *dtls.Conn dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS) + dtlsEndpoint.SetOnClose(t.internalOnCloseHandler) role, dtlsConfig, err := prepareTransport() if err != nil { return err diff --git a/dtlstransport_test.go b/dtlstransport_test.go index 93d21095..485d1534 100644 --- a/dtlstransport_test.go +++ b/dtlstransport_test.go @@ -17,7 +17,7 @@ import ( // An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed func TestInvalidFingerprintCausesFailed(t *testing.T) { - lim := test.TimeOut(time.Second * 40) + lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) @@ -46,8 +46,8 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) { } }) - offerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcOffer) - answerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcAnswer) + offerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcOffer) + answerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcAnswer) if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil { t.Fatal(err) @@ -89,13 +89,17 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) { t.Fatal("timed out waiting to receive offer") } - offerConnectionHasFailed.Wait() - answerConnectionHasFailed.Wait() + offerConnectionHasClosed.Wait() + answerConnectionHasClosed.Wait() - assert.Equal(t, pcOffer.SCTP().Transport().State(), DTLSTransportStateFailed) + if pcOffer.SCTP().Transport().State() != DTLSTransportStateClosed && pcOffer.SCTP().Transport().State() != DTLSTransportStateFailed { + t.Fail() + } assert.Nil(t, pcOffer.SCTP().Transport().conn) - assert.Equal(t, pcAnswer.SCTP().Transport().State(), DTLSTransportStateFailed) + if pcAnswer.SCTP().Transport().State() != DTLSTransportStateClosed && pcAnswer.SCTP().Transport().State() != DTLSTransportStateFailed { + t.Fail() + } assert.Nil(t, pcAnswer.SCTP().Transport().conn) } diff --git a/internal/mux/endpoint.go b/internal/mux/endpoint.go index 227c2bc2..41dd61d5 100644 --- a/internal/mux/endpoint.go +++ b/internal/mux/endpoint.go @@ -15,14 +15,18 @@ import ( // Endpoint implements net.Conn. It is used to read muxed packets. type Endpoint struct { - mux *Mux - buffer *packetio.Buffer + mux *Mux + buffer *packetio.Buffer + onClose func() } // Close unregisters the endpoint from the Mux func (e *Endpoint) Close() (err error) { - err = e.close() - if err != nil { + if e.onClose != nil { + e.onClose() + } + + if err = e.close(); err != nil { return err } @@ -76,3 +80,9 @@ func (e *Endpoint) SetReadDeadline(time.Time) error { func (e *Endpoint) SetWriteDeadline(time.Time) error { return nil } + +// SetOnClose is a user set callback that +// will be executed when `Close` is called +func (e *Endpoint) SetOnClose(onClose func()) { + e.onClose = onClose +} diff --git a/peerconnection.go b/peerconnection.go index 6e33bff0..ff7b4fff 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -2261,6 +2261,16 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re return } + pc.dtlsTransport.internalOnCloseHandler = func() { + pc.log.Info("Closing PeerConnection from DTLS CloseNotify") + + go func() { + if pcClosErr := pc.Close(); pcClosErr != nil { + pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr) + } + }() + } + // Start the dtls transport err = pc.dtlsTransport.Start(DTLSParameters{ Role: dtlsRole, diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index b000f538..89d104d1 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -24,6 +24,7 @@ import ( "github.com/pion/rtp" "github.com/pion/sdp/v3" "github.com/pion/transport/v3/test" + "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v3/pkg/media" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -329,10 +330,15 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { m := &MediaEngine{} assert.NoError(t, m.RegisterDefaultCodecs()) - pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(s), WithMediaEngine(m)).newPair(Configuration{}) - if err != nil { - t.Fatal(err) - } + pcOffer, pcAnswer, wan := createVNetPair(t) + + keepPackets := &atomicBool{} + keepPackets.set(true) + + // Add a filter that monitors the traffic on the router + wan.AddChunkFilter(func(c vnet.Chunk) bool { + return keepPackets.get() + }) vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2") if err != nil { @@ -360,14 +366,11 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { time.Sleep(time.Second) } - if pcCloseErr := pcAnswer.Close(); pcCloseErr != nil { - haveDisconnected <- pcCloseErr - } + keepPackets.set(false) } }) - err = signalPair(pcOffer, pcAnswer) - if err != nil { + if err = signalPair(pcOffer, pcAnswer); err != nil { t.Fatal(err) } @@ -383,7 +386,8 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { } } - assert.NoError(t, pcOffer.Close()) + assert.NoError(t, wan.Stop()) + closePairNow(t, pcOffer, pcAnswer) } type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} } diff --git a/peerconnection_test.go b/peerconnection_test.go index c821b69a..bd5c84fa 100644 --- a/peerconnection_test.go +++ b/peerconnection_test.go @@ -754,3 +754,41 @@ func TestTransportChain(t *testing.T) { closePairNow(t, offer, answer) } + +// Assert that the PeerConnection closes via DTLS (and not ICE) +func TestDTLSClose(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + pcOffer, pcAnswer, err := newPair() + assert.NoError(t, err) + + _, err = pcOffer.AddTransceiverFromKind(RTPCodecTypeVideo) + assert.NoError(t, err) + + peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer) + + offer, err := pcOffer.CreateOffer(nil) + assert.NoError(t, err) + + offerGatheringComplete := GatheringCompletePromise(pcOffer) + assert.NoError(t, pcOffer.SetLocalDescription(offer)) + <-offerGatheringComplete + + assert.NoError(t, pcAnswer.SetRemoteDescription(*pcOffer.LocalDescription())) + + answer, err := pcAnswer.CreateAnswer(nil) + assert.NoError(t, err) + + answerGatheringComplete := GatheringCompletePromise(pcAnswer) + assert.NoError(t, pcAnswer.SetLocalDescription(answer)) + <-answerGatheringComplete + + assert.NoError(t, pcOffer.SetRemoteDescription(*pcAnswer.LocalDescription())) + + peerConnectionsConnected.Wait() + assert.NoError(t, pcOffer.Close()) +}