mirror of
https://github.com/pion/webrtc.git
synced 2025-09-26 19:21:12 +08:00
324 lines
8.8 KiB
Go
324 lines
8.8 KiB
Go
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
//go:build !js
|
|
// +build !js
|
|
|
|
package webrtc
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestGenerateDataChannelID(t *testing.T) {
|
|
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
|
|
ret := &SCTPTransport{
|
|
dataChannels: []*DataChannel{},
|
|
dataChannelIDsUsed: make(map[uint16]struct{}),
|
|
}
|
|
|
|
for i := range ids {
|
|
id := ids[i]
|
|
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
|
|
ret.dataChannelIDsUsed[id] = struct{}{}
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
testCases := []struct {
|
|
role DTLSRole
|
|
s *SCTPTransport
|
|
result uint16
|
|
}{
|
|
{DTLSRoleClient, sctpTransportWithChannels([]uint16{}), 0},
|
|
{DTLSRoleClient, sctpTransportWithChannels([]uint16{1}), 0},
|
|
{DTLSRoleClient, sctpTransportWithChannels([]uint16{0}), 2},
|
|
{DTLSRoleClient, sctpTransportWithChannels([]uint16{0, 2}), 4},
|
|
{DTLSRoleClient, sctpTransportWithChannels([]uint16{0, 4}), 2},
|
|
{DTLSRoleServer, sctpTransportWithChannels([]uint16{}), 1},
|
|
{DTLSRoleServer, sctpTransportWithChannels([]uint16{0}), 1},
|
|
{DTLSRoleServer, sctpTransportWithChannels([]uint16{1}), 3},
|
|
{DTLSRoleServer, sctpTransportWithChannels([]uint16{1, 3}), 5},
|
|
{DTLSRoleServer, sctpTransportWithChannels([]uint16{1, 5}), 3},
|
|
}
|
|
for _, testCase := range testCases {
|
|
idPtr := new(uint16)
|
|
err := testCase.s.generateAndSetDataChannelID(testCase.role, &idPtr)
|
|
assert.NoError(t, err, "failed to generate data channel id")
|
|
assert.Equal(t, testCase.result, *idPtr)
|
|
assert.Contains(
|
|
t, testCase.s.dataChannelIDsUsed, *idPtr,
|
|
"expected new id to be added to the map",
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestSCTPTransportOnClose(t *testing.T) {
|
|
offerPC, answerPC, err := newPair()
|
|
require.NoError(t, err)
|
|
|
|
defer closePairNow(t, offerPC, answerPC)
|
|
|
|
answerPC.OnDataChannel(func(dc *DataChannel) {
|
|
dc.OnMessage(func(_ DataChannelMessage) {
|
|
assert.NoError(t, dc.Send([]byte("hello")), "failed to send message")
|
|
})
|
|
})
|
|
|
|
recvMsg := make(chan struct{}, 1)
|
|
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
|
|
if state == PeerConnectionStateConnected {
|
|
defer func() {
|
|
offerPC.OnConnectionStateChange(nil)
|
|
}()
|
|
|
|
dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
|
|
assert.NoError(t, createErr, "Failed to create a PC pair for testing")
|
|
dc.OnMessage(func(msg DataChannelMessage) {
|
|
assert.Equal(
|
|
t, []byte("hello"), msg.Data,
|
|
"invalid msg received",
|
|
)
|
|
recvMsg <- struct{}{}
|
|
})
|
|
dc.OnOpen(func() {
|
|
assert.NoError(t, dc.Send([]byte("hello")), "failed to send initial msg")
|
|
})
|
|
}
|
|
})
|
|
|
|
err = signalPair(offerPC, answerPC)
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case <-recvMsg:
|
|
case <-time.After(5 * time.Second):
|
|
assert.Fail(t, "timed out")
|
|
}
|
|
|
|
// setup SCTP OnClose callback
|
|
ch := make(chan error, 1)
|
|
answerPC.SCTP().OnClose(func(err error) {
|
|
ch <- err
|
|
})
|
|
|
|
err = offerPC.Close() // This will trigger sctp onclose callback on remote
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case <-ch:
|
|
case <-time.After(5 * time.Second):
|
|
assert.Fail(t, "timed out")
|
|
}
|
|
}
|
|
|
|
func TestSCTPTransportOutOfBandNegotiatedDataChannelDetach(t *testing.T) { //nolint:cyclop
|
|
// nolint:varnamelen
|
|
const N = 10
|
|
done := make(chan struct{}, N)
|
|
for i := 0; i < N; i++ {
|
|
go func() {
|
|
// Use Detach data channels mode
|
|
s := SettingEngine{}
|
|
s.DetachDataChannels()
|
|
api := NewAPI(WithSettingEngine(s))
|
|
|
|
// Set up two peer connections.
|
|
config := Configuration{}
|
|
offerPC, err := api.NewPeerConnection(config)
|
|
assert.NoError(t, err)
|
|
answerPC, err := api.NewPeerConnection(config)
|
|
assert.NoError(t, err)
|
|
|
|
defer closePairNow(t, offerPC, answerPC)
|
|
defer func() { done <- struct{}{} }()
|
|
|
|
negotiated := true
|
|
id := uint16(0)
|
|
readDetach := make(chan struct{})
|
|
dc1, err := offerPC.CreateDataChannel("", &DataChannelInit{
|
|
Negotiated: &negotiated,
|
|
ID: &id,
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
dc1.OnOpen(func() {
|
|
_, _ = dc1.Detach()
|
|
close(readDetach)
|
|
})
|
|
|
|
writeDetach := make(chan struct{})
|
|
dc2, err := answerPC.CreateDataChannel("", &DataChannelInit{
|
|
Negotiated: &negotiated,
|
|
ID: &id,
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
dc2.OnOpen(func() {
|
|
_, _ = dc2.Detach()
|
|
close(writeDetach)
|
|
})
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
connestd := make(chan struct{}, 1)
|
|
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
|
|
if state == PeerConnectionStateConnected {
|
|
connestd <- struct{}{}
|
|
}
|
|
})
|
|
select {
|
|
case <-connestd:
|
|
case <-time.After(10 * time.Second):
|
|
assert.Fail(t, "conn establishment timed out")
|
|
|
|
return
|
|
}
|
|
<-readDetach
|
|
err1 := dc1.dataChannel.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
assert.NoError(t, err1)
|
|
buf := make([]byte, 10)
|
|
n, err1 := dc1.dataChannel.Read(buf)
|
|
assert.NoError(t, err1)
|
|
assert.Equal(t, "hello", string(buf[:n]), "invalid read")
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
connestd := make(chan struct{}, 1)
|
|
answerPC.OnConnectionStateChange(func(state PeerConnectionState) {
|
|
if state == PeerConnectionStateConnected {
|
|
connestd <- struct{}{}
|
|
}
|
|
})
|
|
select {
|
|
case <-connestd:
|
|
case <-time.After(10 * time.Second):
|
|
assert.Fail(t, "connection establishment timed out")
|
|
|
|
return
|
|
}
|
|
<-writeDetach
|
|
n, err1 := dc2.dataChannel.Write([]byte("hello"))
|
|
assert.NoError(t, err1)
|
|
assert.Equal(t, len("hello"), n)
|
|
}()
|
|
err = signalPair(offerPC, answerPC)
|
|
require.NoError(t, err)
|
|
wg.Wait()
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < N; i++ {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(20 * time.Second):
|
|
assert.Fail(t, "timed out")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Assert that max-message-size is signaled properly
|
|
// and able to be configured via SettingEngine.
|
|
func TestMaxMessageSizeSignaling(t *testing.T) {
|
|
t.Run("Local Offer", func(t *testing.T) {
|
|
peerConnection, err := NewPeerConnection(Configuration{})
|
|
require.NoError(t, err)
|
|
|
|
_, err = peerConnection.CreateDataChannel("", nil)
|
|
require.NoError(t, err)
|
|
|
|
offer, err := peerConnection.CreateOffer(nil)
|
|
require.NoError(t, err)
|
|
|
|
require.Contains(t, offer.SDP, "a=max-message-size:1073741823\r\n")
|
|
require.NoError(t, peerConnection.Close())
|
|
})
|
|
|
|
t.Run("Local SettingEngine", func(t *testing.T) {
|
|
settingEngine := SettingEngine{}
|
|
settingEngine.SetSCTPMaxMessageSize(4321)
|
|
|
|
peerConnection, err := NewAPI(WithSettingEngine(settingEngine)).NewPeerConnection(Configuration{})
|
|
require.NoError(t, err)
|
|
|
|
_, err = peerConnection.CreateDataChannel("", nil)
|
|
require.NoError(t, err)
|
|
|
|
offer, err := peerConnection.CreateOffer(nil)
|
|
require.NoError(t, err)
|
|
|
|
require.Contains(t, offer.SDP, "a=max-message-size:4321\r\n")
|
|
require.NoError(t, peerConnection.Close())
|
|
})
|
|
|
|
t.Run("Remote", func(t *testing.T) {
|
|
settingEngine := SettingEngine{}
|
|
settingEngine.SetSCTPMaxMessageSize(4321)
|
|
|
|
offerPeerConnection, err := NewAPI(WithSettingEngine(settingEngine)).NewPeerConnection(Configuration{})
|
|
require.NoError(t, err)
|
|
|
|
answerPeerConnection, err := NewPeerConnection(Configuration{})
|
|
require.NoError(t, err)
|
|
|
|
onDataChannelOpen, onDataChannelOpenCancel := context.WithCancel(context.Background())
|
|
answerPeerConnection.OnDataChannel(func(d *DataChannel) {
|
|
d.OnOpen(func() {
|
|
onDataChannelOpenCancel()
|
|
})
|
|
})
|
|
|
|
require.NoError(t, signalPair(offerPeerConnection, answerPeerConnection))
|
|
|
|
<-onDataChannelOpen.Done()
|
|
require.Equal(t, uint32(defaultMaxSCTPMessageSize), offerPeerConnection.SCTP().GetCapabilities().MaxMessageSize)
|
|
require.Equal(t, uint32(4321), answerPeerConnection.SCTP().GetCapabilities().MaxMessageSize)
|
|
|
|
closePairNow(t, offerPeerConnection, answerPeerConnection)
|
|
})
|
|
|
|
t.Run("Remote Unset", func(t *testing.T) {
|
|
offerPeerConnection, answerPeerConnection, err := newPair()
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, signalPairWithModification(offerPeerConnection, answerPeerConnection, func(sessionDescription string) (filtered string) { // nolint
|
|
scanner := bufio.NewScanner(strings.NewReader(sessionDescription))
|
|
for scanner.Scan() {
|
|
if strings.HasPrefix(scanner.Text(), "a=max-message-size") {
|
|
continue
|
|
}
|
|
|
|
filtered += scanner.Text() + "\r\n"
|
|
}
|
|
|
|
return
|
|
}))
|
|
|
|
onDataChannelOpen, onDataChannelOpenCancel := context.WithCancel(context.Background())
|
|
answerPeerConnection.OnDataChannel(func(d *DataChannel) {
|
|
d.OnOpen(func() {
|
|
onDataChannelOpenCancel()
|
|
})
|
|
})
|
|
|
|
require.NoError(t, signalPair(offerPeerConnection, answerPeerConnection))
|
|
|
|
<-onDataChannelOpen.Done()
|
|
require.Equal(t, uint32(defaultMaxSCTPMessageSize), offerPeerConnection.SCTP().GetCapabilities().MaxMessageSize)
|
|
require.Equal(t, uint32(sctpMaxMessageSizeUnsetValue), answerPeerConnection.SCTP().GetCapabilities().MaxMessageSize)
|
|
|
|
closePairNow(t, offerPeerConnection, answerPeerConnection)
|
|
})
|
|
}
|