mirror of
https://github.com/pion/webrtc.git
synced 2025-09-27 03:25:58 +08:00
Close PeerConnection on DTLS CloseNotify
Resolves #1767 Resolves pion/dtls#151
This commit is contained in:
@@ -44,7 +44,8 @@ type DTLSTransport struct {
|
|||||||
state DTLSTransportState
|
state DTLSTransportState
|
||||||
srtpProtectionProfile srtp.ProtectionProfile
|
srtpProtectionProfile srtp.ProtectionProfile
|
||||||
|
|
||||||
onStateChangeHandler func(DTLSTransportState)
|
onStateChangeHandler func(DTLSTransportState)
|
||||||
|
internalOnCloseHandler func()
|
||||||
|
|
||||||
conn *dtls.Conn
|
conn *dtls.Conn
|
||||||
|
|
||||||
@@ -322,6 +323,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
|
|||||||
|
|
||||||
var dtlsConn *dtls.Conn
|
var dtlsConn *dtls.Conn
|
||||||
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
|
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
|
||||||
|
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
|
||||||
role, dtlsConfig, err := prepareTransport()
|
role, dtlsConfig, err := prepareTransport()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed
|
// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed
|
||||||
func TestInvalidFingerprintCausesFailed(t *testing.T) {
|
func TestInvalidFingerprintCausesFailed(t *testing.T) {
|
||||||
lim := test.TimeOut(time.Second * 40)
|
lim := test.TimeOut(time.Second * 5)
|
||||||
defer lim.Stop()
|
defer lim.Stop()
|
||||||
|
|
||||||
report := test.CheckRoutines(t)
|
report := test.CheckRoutines(t)
|
||||||
@@ -46,8 +46,8 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
offerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcOffer)
|
offerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcOffer)
|
||||||
answerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcAnswer)
|
answerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcAnswer)
|
||||||
|
|
||||||
if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil {
|
if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -89,13 +89,17 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
|
|||||||
t.Fatal("timed out waiting to receive offer")
|
t.Fatal("timed out waiting to receive offer")
|
||||||
}
|
}
|
||||||
|
|
||||||
offerConnectionHasFailed.Wait()
|
offerConnectionHasClosed.Wait()
|
||||||
answerConnectionHasFailed.Wait()
|
answerConnectionHasClosed.Wait()
|
||||||
|
|
||||||
assert.Equal(t, pcOffer.SCTP().Transport().State(), DTLSTransportStateFailed)
|
if pcOffer.SCTP().Transport().State() != DTLSTransportStateClosed && pcOffer.SCTP().Transport().State() != DTLSTransportStateFailed {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
assert.Nil(t, pcOffer.SCTP().Transport().conn)
|
assert.Nil(t, pcOffer.SCTP().Transport().conn)
|
||||||
|
|
||||||
assert.Equal(t, pcAnswer.SCTP().Transport().State(), DTLSTransportStateFailed)
|
if pcAnswer.SCTP().Transport().State() != DTLSTransportStateClosed && pcAnswer.SCTP().Transport().State() != DTLSTransportStateFailed {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
assert.Nil(t, pcAnswer.SCTP().Transport().conn)
|
assert.Nil(t, pcAnswer.SCTP().Transport().conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -15,14 +15,18 @@ import (
|
|||||||
|
|
||||||
// Endpoint implements net.Conn. It is used to read muxed packets.
|
// Endpoint implements net.Conn. It is used to read muxed packets.
|
||||||
type Endpoint struct {
|
type Endpoint struct {
|
||||||
mux *Mux
|
mux *Mux
|
||||||
buffer *packetio.Buffer
|
buffer *packetio.Buffer
|
||||||
|
onClose func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close unregisters the endpoint from the Mux
|
// Close unregisters the endpoint from the Mux
|
||||||
func (e *Endpoint) Close() (err error) {
|
func (e *Endpoint) Close() (err error) {
|
||||||
err = e.close()
|
if e.onClose != nil {
|
||||||
if err != nil {
|
e.onClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = e.close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,3 +80,9 @@ func (e *Endpoint) SetReadDeadline(time.Time) error {
|
|||||||
func (e *Endpoint) SetWriteDeadline(time.Time) error {
|
func (e *Endpoint) SetWriteDeadline(time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOnClose is a user set callback that
|
||||||
|
// will be executed when `Close` is called
|
||||||
|
func (e *Endpoint) SetOnClose(onClose func()) {
|
||||||
|
e.onClose = onClose
|
||||||
|
}
|
||||||
|
@@ -2261,6 +2261,16 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pc.dtlsTransport.internalOnCloseHandler = func() {
|
||||||
|
pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if pcClosErr := pc.Close(); pcClosErr != nil {
|
||||||
|
pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// Start the dtls transport
|
// Start the dtls transport
|
||||||
err = pc.dtlsTransport.Start(DTLSParameters{
|
err = pc.dtlsTransport.Start(DTLSParameters{
|
||||||
Role: dtlsRole,
|
Role: dtlsRole,
|
||||||
|
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/pion/rtp"
|
"github.com/pion/rtp"
|
||||||
"github.com/pion/sdp/v3"
|
"github.com/pion/sdp/v3"
|
||||||
"github.com/pion/transport/v3/test"
|
"github.com/pion/transport/v3/test"
|
||||||
|
"github.com/pion/transport/v3/vnet"
|
||||||
"github.com/pion/webrtc/v3/pkg/media"
|
"github.com/pion/webrtc/v3/pkg/media"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -329,10 +330,15 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
|
|||||||
m := &MediaEngine{}
|
m := &MediaEngine{}
|
||||||
assert.NoError(t, m.RegisterDefaultCodecs())
|
assert.NoError(t, m.RegisterDefaultCodecs())
|
||||||
|
|
||||||
pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(s), WithMediaEngine(m)).newPair(Configuration{})
|
pcOffer, pcAnswer, wan := createVNetPair(t)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
keepPackets := &atomicBool{}
|
||||||
}
|
keepPackets.set(true)
|
||||||
|
|
||||||
|
// Add a filter that monitors the traffic on the router
|
||||||
|
wan.AddChunkFilter(func(c vnet.Chunk) bool {
|
||||||
|
return keepPackets.get()
|
||||||
|
})
|
||||||
|
|
||||||
vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2")
|
vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -360,14 +366,11 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pcCloseErr := pcAnswer.Close(); pcCloseErr != nil {
|
keepPackets.set(false)
|
||||||
haveDisconnected <- pcCloseErr
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = signalPair(pcOffer, pcAnswer)
|
if err = signalPair(pcOffer, pcAnswer); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,7 +386,8 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, pcOffer.Close())
|
assert.NoError(t, wan.Stop())
|
||||||
|
closePairNow(t, pcOffer, pcAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} }
|
type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} }
|
||||||
|
@@ -754,3 +754,41 @@ func TestTransportChain(t *testing.T) {
|
|||||||
|
|
||||||
closePairNow(t, offer, answer)
|
closePairNow(t, offer, answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assert that the PeerConnection closes via DTLS (and not ICE)
|
||||||
|
func TestDTLSClose(t *testing.T) {
|
||||||
|
lim := test.TimeOut(time.Second * 10)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
pcOffer, pcAnswer, err := newPair()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = pcOffer.AddTransceiverFromKind(RTPCodecTypeVideo)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
|
||||||
|
|
||||||
|
offer, err := pcOffer.CreateOffer(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
offerGatheringComplete := GatheringCompletePromise(pcOffer)
|
||||||
|
assert.NoError(t, pcOffer.SetLocalDescription(offer))
|
||||||
|
<-offerGatheringComplete
|
||||||
|
|
||||||
|
assert.NoError(t, pcAnswer.SetRemoteDescription(*pcOffer.LocalDescription()))
|
||||||
|
|
||||||
|
answer, err := pcAnswer.CreateAnswer(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
answerGatheringComplete := GatheringCompletePromise(pcAnswer)
|
||||||
|
assert.NoError(t, pcAnswer.SetLocalDescription(answer))
|
||||||
|
<-answerGatheringComplete
|
||||||
|
|
||||||
|
assert.NoError(t, pcOffer.SetRemoteDescription(*pcAnswer.LocalDescription()))
|
||||||
|
|
||||||
|
peerConnectionsConnected.Wait()
|
||||||
|
assert.NoError(t, pcOffer.Close())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user