mirror of
https://github.com/pion/webrtc.git
synced 2025-12-24 11:51:03 +08:00
Provide SCTP Association OnClose callback
This commit is contained in:
@@ -45,6 +45,7 @@ type SCTPTransport struct {
|
||||
// OnStateChange func()
|
||||
|
||||
onErrorHandler func(error)
|
||||
onCloseHandler func(error)
|
||||
|
||||
sctpAssociation *sctp.Association
|
||||
onDataChannelHandler func(*DataChannel)
|
||||
@@ -176,6 +177,7 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
|
||||
dataChannels = append(dataChannels, dc.dataChannel)
|
||||
}
|
||||
r.lock.RUnlock()
|
||||
|
||||
ACCEPT:
|
||||
for {
|
||||
dc, err := datachannel.Accept(a, &datachannel.Config{
|
||||
@@ -185,6 +187,9 @@ ACCEPT:
|
||||
if !errors.Is(err, io.EOF) {
|
||||
r.log.Errorf("Failed to accept data channel: %v", err)
|
||||
r.onError(err)
|
||||
r.onClose(err)
|
||||
} else {
|
||||
r.onClose(nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -232,9 +237,14 @@ ACCEPT:
|
||||
MaxRetransmits: maxRetransmits,
|
||||
}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
|
||||
if err != nil {
|
||||
// This data channel is invalid. Close it and log an error.
|
||||
if err1 := dc.Close(); err1 != nil {
|
||||
r.log.Errorf("Failed to close invalid data channel: %v", err1)
|
||||
}
|
||||
r.log.Errorf("Failed to accept data channel: %v", err)
|
||||
r.onError(err)
|
||||
return
|
||||
// We've received a datachannel with invalid configuration. We can still receive other datachannels.
|
||||
continue ACCEPT
|
||||
}
|
||||
|
||||
<-r.onDataChannel(rtcDC)
|
||||
@@ -251,8 +261,7 @@ ACCEPT:
|
||||
}
|
||||
}
|
||||
|
||||
// OnError sets an event handler which is invoked when
|
||||
// the SCTP connection error occurs.
|
||||
// OnError sets an event handler which is invoked when the SCTP Association errors.
|
||||
func (r *SCTPTransport) OnError(f func(err error)) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
@@ -269,6 +278,23 @@ func (r *SCTPTransport) onError(err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnClose sets an event handler which is invoked when the SCTP Association closes.
|
||||
func (r *SCTPTransport) OnClose(f func(err error)) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.onCloseHandler = f
|
||||
}
|
||||
|
||||
func (r *SCTPTransport) onClose(err error) {
|
||||
r.lock.RLock()
|
||||
handler := r.onCloseHandler
|
||||
r.lock.RUnlock()
|
||||
|
||||
if handler != nil {
|
||||
go handler(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDataChannel sets an event handler which is invoked when a data
|
||||
// channel message arrives from a remote peer.
|
||||
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
|
||||
|
||||
@@ -6,7 +6,13 @@
|
||||
|
||||
package webrtc
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateDataChannelID(t *testing.T) {
|
||||
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
|
||||
@@ -55,3 +61,66 @@ func TestGenerateDataChannelID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCTPTransportOnClose(t *testing.T) {
|
||||
offerPC, answerPC, err := newPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
answerPC.OnDataChannel(func(dc *DataChannel) {
|
||||
dc.OnMessage(func(_ DataChannelMessage) {
|
||||
if err1 := dc.Send([]byte("hello")); err1 != nil {
|
||||
t.Error("failed to send message")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
recvMsg := make(chan struct{}, 1)
|
||||
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
|
||||
if state == PeerConnectionStateConnected {
|
||||
defer func() {
|
||||
offerPC.OnConnectionStateChange(nil)
|
||||
}()
|
||||
|
||||
dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
|
||||
if createErr != nil {
|
||||
t.Errorf("Failed to create a PC pair for testing")
|
||||
return
|
||||
}
|
||||
dc.OnMessage(func(msg DataChannelMessage) {
|
||||
if !bytes.Equal(msg.Data, []byte("hello")) {
|
||||
t.Error("invalid msg received")
|
||||
}
|
||||
recvMsg <- struct{}{}
|
||||
})
|
||||
dc.OnOpen(func() {
|
||||
if err1 := dc.Send([]byte("hello")); err1 != nil {
|
||||
t.Error("failed to send initial msg", err1)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
err = signalPair(offerPC, answerPC)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-recvMsg:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
// setup SCTP OnClose callback
|
||||
ch := make(chan error, 1)
|
||||
answerPC.SCTP().OnClose(func(err error) {
|
||||
ch <- err
|
||||
})
|
||||
|
||||
err = offerPC.Close() // This will trigger sctp onclose callback on remote
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user