diff --git a/sdp_test.go b/sdp_test.go index 433579bf..25783bee 100644 --- a/sdp_test.go +++ b/sdp_test.go @@ -711,7 +711,7 @@ func TestPopulateSDP(t *testing.T) { //nolint:cyclop,maintidx // Test contains rid map keys var ridFound int for _, desc := range offerSdp.MediaDescriptions { - if desc.MediaName.Media != "video" { + if desc.MediaName.Media != string(MediaKindVideo) { continue } ridsInSDP := getRids(desc) @@ -770,7 +770,7 @@ func TestPopulateSDP(t *testing.T) { //nolint:cyclop,maintidx // Test codecs foundVP8 := false for _, desc := range offerSdp.MediaDescriptions { - if desc.MediaName.Media != "video" { + if desc.MediaName.Media != string(MediaKindVideo) { continue } for _, a := range desc.Attributes { @@ -868,7 +868,7 @@ func TestPopulateSDP(t *testing.T) { //nolint:cyclop,maintidx // Test codecs foundRejectedTrack := false for _, desc := range offerSdp.MediaDescriptions { - if desc.MediaName.Media != "audio" { + if desc.MediaName.Media != string(MediaKindAudio) { continue } assert.True(t, desc.ConnectionInformation != nil, "connection information must be provided for rejected tracks") diff --git a/settingengine_test.go b/settingengine_test.go index 1724b65c..64fab1e2 100644 --- a/settingengine_test.go +++ b/settingengine_test.go @@ -7,18 +7,22 @@ package webrtc import ( + "bytes" "context" + "crypto/x509" "net" "testing" "time" "github.com/pion/datachannel" + "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/ice/v4" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" + "golang.org/x/net/proxy" ) func TestSetEphemeralUDPPortRange(t *testing.T) { @@ -464,3 +468,161 @@ func TestEnableDataChannelBlockWrite(t *testing.T) { assert.ErrorIs(t, err, context.DeadlineExceeded) closePairNow(t, offer, answer) } + +func TestSettingEngine_getReceiveMTU_Custom(t *testing.T) { + var se SettingEngine + se.SetReceiveMTU(1234) + + got := se.getReceiveMTU() + assert.Equal(t, uint(1234), got) +} + +func TestSettingEngine_ICEAcceptanceAndSTUNSetters(t *testing.T) { + var se SettingEngine + + host := 10 * time.Millisecond + srflx := 20 * time.Millisecond + prflx := 30 * time.Millisecond + relay := 40 * time.Millisecond + stun := 50 * time.Millisecond + + se.SetHostAcceptanceMinWait(host) + se.SetSrflxAcceptanceMinWait(srflx) + se.SetPrflxAcceptanceMinWait(prflx) + se.SetRelayAcceptanceMinWait(relay) + se.SetSTUNGatherTimeout(stun) + + assert.NotNil(t, se.timeout.ICEHostAcceptanceMinWait) + assert.NotNil(t, se.timeout.ICESrflxAcceptanceMinWait) + assert.NotNil(t, se.timeout.ICEPrflxAcceptanceMinWait) + assert.NotNil(t, se.timeout.ICERelayAcceptanceMinWait) + assert.NotNil(t, se.timeout.ICESTUNGatherTimeout) + + assert.Equal(t, host, *se.timeout.ICEHostAcceptanceMinWait) + assert.Equal(t, srflx, *se.timeout.ICESrflxAcceptanceMinWait) + assert.Equal(t, prflx, *se.timeout.ICEPrflxAcceptanceMinWait) + assert.Equal(t, relay, *se.timeout.ICERelayAcceptanceMinWait) + assert.Equal(t, stun, *se.timeout.ICESTUNGatherTimeout) +} + +func TestSettingEngine_CandidateFiltersAndNetworkTypes(t *testing.T) { + var se SettingEngine + + nts := []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6} + se.SetNetworkTypes(nts) + assert.Equal(t, nts, se.candidates.ICENetworkTypes) + + ifFilter := func(name string) bool { return name == "eth0" } + ipFilter := func(ip net.IP) bool { return ip.IsLoopback() } + + se.SetInterfaceFilter(ifFilter) + se.SetIPFilter(ipFilter) + se.SetIncludeLoopbackCandidate(true) + + assert.NotNil(t, se.candidates.InterfaceFilter) + assert.NotNil(t, se.candidates.IPFilter) + assert.True(t, se.candidates.InterfaceFilter("eth0")) + assert.False(t, se.candidates.InterfaceFilter("wlan0")) + assert.True(t, se.candidates.IPFilter(net.IPv4(127, 0, 0, 1))) + assert.True(t, se.candidates.IncludeLoopbackCandidate) +} + +func TestSettingEngine_MDNSAndCredentialsAndFingerprint(t *testing.T) { + var se SettingEngine + + se.SetMulticastDNSHostName("host.local.") + se.SetICECredentials("ufrag123", "pwd456") + se.DisableCertificateFingerprintVerification(true) + + assert.Equal(t, "host.local.", se.candidates.MulticastDNSHostName) + assert.Equal(t, "ufrag123", se.candidates.UsernameFragment) + assert.Equal(t, "pwd456", se.candidates.Password) + assert.True(t, se.disableCertificateFingerprintVerification) +} + +func TestSettingEngine_UDPMuxProxyBindingAndTCPFlags(t *testing.T) { + var se SettingEngine + + var mux ice.UDPMux + se.SetICEUDPMux(mux) + assert.Equal(t, mux, se.iceUDPMux) + + se.SetICEProxyDialer(proxy.Direct) + assert.Equal(t, proxy.Direct, se.iceProxyDialer) + + var maxReq uint16 = 77 + se.SetICEMaxBindingRequests(maxReq) + assert.NotNil(t, se.iceMaxBindingRequests) + assert.Equal(t, maxReq, *se.iceMaxBindingRequests) + + se.DisableActiveTCP(true) + assert.True(t, se.iceDisableActiveTCP) +} + +func TestSettingEngine_MediaEngineAndMTUFlags(t *testing.T) { + var se SettingEngine + + se.DisableMediaEngineMultipleCodecs(true) + assert.True(t, se.disableMediaEngineMultipleCodecs) + + se.SetReceiveMTU(1337) + assert.Equal(t, uint(1337), se.receiveMTU) +} + +func TestSettingEngine_DTLSSetters(t *testing.T) { + var se SettingEngine + + se.SetDTLSInsecureSkipHelloVerify(true) + se.SetDTLSDisableInsecureSkipVerify(true) + se.SetDTLSExtendedMasterSecret(dtls.RequireExtendedMasterSecret) + + auth := dtls.RequireAnyClientCert + se.SetDTLSClientAuth(auth) + + clientCAs := x509.NewCertPool() + rootCAs := x509.NewCertPool() + var keyBuf bytes.Buffer + + se.SetDTLSClientCAs(clientCAs) + se.SetDTLSRootCAs(rootCAs) + se.SetDTLSKeyLogWriter(&keyBuf) + + called := false + se.SetDTLSCustomerCipherSuites(func() []dtls.CipherSuite { + called = true + + return nil + }) + + assert.True(t, se.dtls.insecureSkipHelloVerify) + assert.True(t, se.dtls.disableInsecureSkipVerify) + assert.Equal(t, dtls.RequireExtendedMasterSecret, se.dtls.extendedMasterSecret) + assert.NotNil(t, se.dtls.clientAuth) + assert.Equal(t, auth, *se.dtls.clientAuth) + assert.Equal(t, clientCAs, se.dtls.clientCAs) + assert.Equal(t, rootCAs, se.dtls.rootCAs) + _, _ = se.dtls.keyLogWriter.Write([]byte("test")) + assert.NotZero(t, keyBuf.Len()) + _ = se.dtls.customCipherSuites() + assert.True(t, called) +} + +func TestSettingEngine_SCTPSetters(t *testing.T) { + var se SettingEngine + + se.EnableSCTPZeroChecksum(true) + se.SetSCTPMinCwnd(11) + se.SetSCTPFastRtxWnd(22) + se.SetSCTPCwndCAStep(33) + + assert.True(t, se.sctp.enableZeroChecksum) + assert.Equal(t, uint32(11), se.sctp.minCwnd) + assert.Equal(t, uint32(22), se.sctp.fastRtxWnd) + assert.Equal(t, uint32(33), se.sctp.cwndCAStep) +} + +func TestSettingEngine_HandleUndeclaredSSRCWithoutAnswer(t *testing.T) { + var se SettingEngine + se.SetHandleUndeclaredSSRCWithoutAnswer(true) + assert.True(t, se.handleUndeclaredSSRCWithoutAnswer) +} diff --git a/signalingstate_test.go b/signalingstate_test.go index 1708d316..1ccc27fa 100644 --- a/signalingstate_test.go +++ b/signalingstate_test.go @@ -161,3 +161,12 @@ func TestSignalingState_Transitions(t *testing.T) { } } } + +func TestStateChangeOp_String_SetLocal(t *testing.T) { + assert.Equal(t, "SetLocal", stateChangeOpSetLocal.String()) +} + +func TestStateChangeOp_String_Default(t *testing.T) { + var unknown stateChangeOp = 999 + assert.Equal(t, "Unknown State Change Operation", unknown.String()) +} diff --git a/srtp_writer_future_test.go b/srtp_writer_future_test.go new file mode 100644 index 00000000..0a6f1625 --- /dev/null +++ b/srtp_writer_future_test.go @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package webrtc + +import ( + "io" + "testing" + "time" + + "github.com/pion/rtp" + "github.com/pion/srtp/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newSWFStopClosed() *srtpWriterFuture { + stop := make(chan struct{}) + close(stop) + + tr := &DTLSTransport{ + srtpReady: make(chan struct{}), + } + sender := &RTPSender{ + stopCalled: stop, + transport: tr, + } + + return &srtpWriterFuture{ + ssrc: 1234, + rtpSender: sender, + } +} + +func newSWFReadyButNoSessions() *srtpWriterFuture { + tr := &DTLSTransport{ + srtpReady: make(chan struct{}), + } + close(tr.srtpReady) + + sender := &RTPSender{ + stopCalled: make(chan struct{}), + transport: tr, + } + + return &srtpWriterFuture{ + ssrc: 5678, + rtpSender: sender, + } +} + +func TestSRTPWriterFuture_Errors_WhenStopCalled(t *testing.T) { + swf := newSWFStopClosed() + + n, err := swf.WriteRTP(&rtp.Header{}, []byte("x")) + assert.Zero(t, n) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + n, err = swf.Write([]byte("x")) + assert.Zero(t, n) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + buf := make([]byte, 1) + n, err = swf.Read(buf) + assert.Zero(t, n) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + err = swf.SetReadDeadline(time.Now()) + assert.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestSRTPWriterFuture_Errors_WhenClosedFlagSet(t *testing.T) { + tr := &DTLSTransport{srtpReady: make(chan struct{})} + close(tr.srtpReady) + + sender := &RTPSender{ + stopCalled: make(chan struct{}), + transport: tr, + } + + swf := &srtpWriterFuture{ + ssrc: 42, + rtpSender: sender, + closed: true, + } + + _, err := swf.WriteRTP(&rtp.Header{}, nil) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + _, err = swf.Read(make([]byte, 1)) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + err = swf.SetReadDeadline(time.Now()) + assert.ErrorIs(t, err, io.ErrClosedPipe) + + _, err = swf.Write(nil) + assert.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestSRTPWriterFuture_Errors_WhenSessionsUnavailable(t *testing.T) { + swf := newSWFReadyButNoSessions() + + n, err := swf.WriteRTP(&rtp.Header{}, nil) + assert.Zero(t, n) + require.Error(t, err) + + n, err = swf.Write([]byte("data")) + assert.Zero(t, n) + require.Error(t, err) + + n, err = swf.Read(make([]byte, 1)) + assert.Zero(t, n) + require.Error(t, err) + + err = swf.SetReadDeadline(time.Now()) + require.Error(t, err) +} + +func TestSRTPWriterFuture_Close_AlreadyClosed(t *testing.T) { + s := &srtpWriterFuture{ + closed: true, + } + s.rtcpReadStream.Store(&srtp.ReadStreamSRTCP{}) + + err := s.Close() + assert.NoError(t, err, "Close on an already-closed srtpWriterFuture should return nil") +} diff --git a/stats.go b/stats.go index 6302f1a1..5f85c146 100644 --- a/stats.go +++ b/stats.go @@ -1326,7 +1326,7 @@ type AudioPlayoutStats struct { // SynthesizedSamplesDuration is measured in seconds and is incremented each time an audio sample is synthesized by // this playout path. This metric can be used together with totalSamplesDuration to calculate the percentage of played // out media being synthesized. If the playout path is unable to produce audio samples on time for device playout, - // samples are synthesized to be playout out instead. Synthesization typically only happens if the pipeline is + // samples are synthesized to be played out instead. Synthesization typically only happens if the pipeline is // underperforming. Samples synthesized by the RTCInboundRtpStreamStats are not counted for here, but in // InboundRtpStreamStats.concealedSamples. SynthesizedSamplesDuration float64 `json:"synthesizedSamplesDuration"` @@ -1746,7 +1746,7 @@ type AudioReceiverStats struct { // 0 represents silence, and 0.5 represents approximately 6 dBSPL change in // the sound pressure level from 0 dBov. // - // If the track is sourced from an Receiver, does no audio processing, has a + // If the track is sourced from a Receiver, does no audio processing, has a // constant level, and has a volume setting of 1.0, the audio level is expected // to be the same as the audio level of the source SSRC, while if the volume setting // is 0.5, the AudioLevel is expected to be half that value. @@ -1871,11 +1871,11 @@ type VideoReceiverStats struct { FramesReceived uint32 `json:"framesReceived"` // KeyFramesReceived represents the total number of complete key frames received - // for this MediaStreamTrack, such as Infra-frames in VP8 [RFC6386] or I-frames + // for this MediaStreamTrack, such as Intra-frames in VP8 [RFC6386] or I-frames // in H.264 [RFC6184]. This is a subset of framesReceived. `framesReceived - keyFramesReceived` // gives you the number of delta frames received. This metric is incremented when // the complete key frame is received. It is not incremented if a partial key - // frames is received and sent for decoding, i.e., the frame could not be recovered + // frame is received and sent for decoding, i.e., the frame could not be recovered // via retransmission or FEC. KeyFramesReceived uint32 `json:"keyFramesReceived"` @@ -1982,7 +1982,7 @@ type TransportStats struct { // Present only if DTLS is negotiated. LocalCertificateID string `json:"localCertificateId"` - // LocalCertificateID is the ID of the CertificateStats for the remote certificate. + // RemoteCertificateID is the ID of the CertificateStats for the remote certificate. // Present only if DTLS is negotiated. RemoteCertificateID string `json:"remoteCertificateId"` @@ -2247,7 +2247,7 @@ type ICECandidatePairStats struct { // STUN binding response expired. ConsentExpiredTimestamp StatsTimestamp `json:"consentExpiredTimestamp"` - // PacketsDiscardedOnSend retpresents the total number of packets for this candidate pair + // PacketsDiscardedOnSend represents the total number of packets for this candidate pair // that have been discarded due to socket errors, i.e. a socket error occurred // when handing the packets to the socket. This might happen due to various reasons, // including full buffer or no available memory. @@ -2321,8 +2321,8 @@ type ICECandidateStats struct { // Priority is the "Priority" field of the ICECandidate. Priority int32 `json:"priority"` - // URL is the URL of the TURN or STUN server indicated in the that translated - // this IP address. It is the URL address surfaced in an PeerConnectionICEEvent. + // URL of the TURN or STUN server that produced this candidate + // It is the URL address surfaced in an PeerConnectionICEEvent. URL string `json:"url"` // RelayProtocol is the protocol used by the endpoint to communicate with the diff --git a/stats_go_test.go b/stats_go_test.go index 993a93ba..fd4d221b 100644 --- a/stats_go_test.go +++ b/stats_go_test.go @@ -8,6 +8,7 @@ package webrtc import ( "encoding/json" + "errors" "fmt" "sync" "testing" @@ -1503,3 +1504,701 @@ func TestPeerConnection_GetStats_Closed(t *testing.T) { pc.GetStats() } + +func TestUnmarshalStatsJSON_TypeFieldUnmarshalError(t *testing.T) { + input := []byte(`{"type":123}`) + + _, err := UnmarshalStatsJSON(input) + require.Error(t, err) + assert.Contains(t, err.Error(), "unmarshal json type:") +} + +func TestUnmarshalStatsJSON_SCTPTransport(t *testing.T) { + input := []byte(`{ + "timestamp": 1689668364374.479, + "type": "sctp-transport", + "id": "SCTP1", + "transportId": "T01", + "smoothedRoundTripTime": 0.123, + "congestionWindow": 512, + "receiverWindow": 2048, + "mtu": 1200, + "unackData": 7, + "bytesSent": 12345, + "bytesReceived": 67890 + }`) + + s, err := UnmarshalStatsJSON(input) + require.NoError(t, err) + + st, ok := s.(SCTPTransportStats) + require.True(t, ok, "expected SCTPTransportStats") + assert.Equal(t, StatsTypeSCTPTransport, st.Type) + assert.Equal(t, "SCTP1", st.ID) + assert.Equal(t, "T01", st.TransportID) + assert.InDelta(t, 0.123, st.SmoothedRoundTripTime, 1e-9) + assert.EqualValues(t, 512, st.CongestionWindow) + assert.EqualValues(t, 2048, st.ReceiverWindow) + assert.EqualValues(t, 1200, st.MTU) + assert.EqualValues(t, 7, st.UNACKData) + assert.EqualValues(t, 12345, st.BytesSent) + assert.EqualValues(t, 67890, st.BytesReceived) +} + +func TestUnmarshalStatsJSON_UnknownType(t *testing.T) { + input := []byte(`{"type":"def-not-a-real-type"}`) + + _, err := UnmarshalStatsJSON(input) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnknownType) +} + +func TestUnmarshalCodecStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"payloadType":"not-a-number"}`) + + _, err := unmarshalCodecStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal codec stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalInboundRTPStreamStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"packetsReceived":"not-a-number"}`) + + _, err := unmarshalInboundRTPStreamStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal inbound rtp stream stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalOutboundRTPStreamStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"packetsSent":"oops"}`) + + _, err := unmarshalOutboundRTPStreamStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal outbound rtp stream stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalRemoteInboundRTPStreamStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"packetsReceived":"nope"}`) + + _, err := unmarshalRemoteInboundRTPStreamStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal remote inbound rtp stream stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalRemoteOutboundRTPStreamStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"packetsSent":"nope"}`) + + _, err := unmarshalRemoteOutboundRTPStreamStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal remote outbound rtp stream stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalCSRCStats_ErrorWrap(t *testing.T) { + bad := []byte(`{"packetsContributedTo":"nope"}`) + + _, err := unmarshalCSRCStats(bad) + require.Error(t, err) + + assert.ErrorContains(t, err, "unmarshal csrc stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying error to be *json.UnmarshalTypeError") +} + +func TestUnmarshalMediaSourceStats_ErrorPaths(t *testing.T) { + t.Run("error unmarshalling kind holder", func(t *testing.T) { + bad := []byte(`{"kind":123}`) + _, err := unmarshalMediaSourceStats(bad) + require.Error(t, err) + assert.ErrorContains(t, err, "unmarshal json kind:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying *json.UnmarshalTypeError") + }) + + t.Run("error unmarshalling audio source stats", func(t *testing.T) { + bad := []byte(`{"type":"media-source","kind":"audio","audioLevel":"oops"}`) + _, err := unmarshalMediaSourceStats(bad) + require.Error(t, err) + assert.ErrorContains(t, err, "unmarshal audio source stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying *json.UnmarshalTypeError") + }) + + t.Run("error unmarshalling video source stats", func(t *testing.T) { + bad := []byte(`{"type":"media-source","kind":"video","width":"oops"}`) + _, err := unmarshalMediaSourceStats(bad) + require.Error(t, err) + assert.ErrorContains(t, err, "unmarshal video source stats:") + + var ute *json.UnmarshalTypeError + assert.True(t, errors.As(err, &ute), "expected underlying *json.UnmarshalTypeError") + }) + + t.Run("unknown kind default case", func(t *testing.T) { + bad := []byte(`{"type":"media-source","kind":"banana"}`) + _, err := unmarshalMediaSourceStats(bad) + require.Error(t, err) + assert.ErrorContains(t, err, "kind:") + assert.True(t, errors.Is(err, ErrUnknownType), "expected ErrUnknownType") + }) +} + +func TestUnmarshalMediaPlayoutStats_Error(t *testing.T) { + badJSON := []byte(`{ + "type": "media-playout", + "id": "AP", + "kind": "audio", + "timestamp": "not-a-number" + }`) + + s, err := unmarshalMediaPlayoutStats(badJSON) + require.Error(t, err) + assert.Nil(t, s) + assert.Contains(t, err.Error(), "unmarshal audio playout stats") +} + +func TestUnmarshalPeerConnectionStats_Error(t *testing.T) { + bad := []byte(`{ + "type": "peer-connection", + "id": "P", + "timestamp": "not-a-number" + }`) + + got, err := unmarshalPeerConnectionStats(bad) + require.Error(t, err) + assert.Equal(t, PeerConnectionStats{}, got, "should return zero value on error") + assert.Contains(t, err.Error(), "unmarshal pc stats") +} + +func TestUnmarshalDataChannelStats_Error(t *testing.T) { + bad := []byte(`{ + "type": "data-channel", + "id": "D1", + "timestamp": "not-a-number" + }`) + + got, err := unmarshalDataChannelStats(bad) + require.Error(t, err) + assert.Equal(t, DataChannelStats{}, got, "should return zero value on error") + assert.Contains(t, err.Error(), "unmarshal data channel stats") +} + +func TestUnmarshalStreamStats_Error(t *testing.T) { + bad := []byte(`{ + "type": "stream", + "id": "S1", + "timestamp": "invalid" + }`) + + got, err := unmarshalStreamStats(bad) + require.Error(t, err) + assert.Equal(t, MediaStreamStats{}, got, "expected zero value on error") + assert.Contains(t, err.Error(), "unmarshal stream stats") +} + +func TestUnmarshalSenderStats_SyntaxErrorOnKind(t *testing.T) { + s, err := unmarshalSenderStats([]byte(`{`)) + require.Error(t, err) + assert.Nil(t, s) + + var se *json.SyntaxError + assert.ErrorAs(t, err, &se) +} + +func TestUnmarshalSenderStats_Audio_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"audio","timestamp":"oops"}`) + s, err := unmarshalSenderStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalSenderStats_Video_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"video","timestamp":"oops"}`) + s, err := unmarshalSenderStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalSenderStats_UnknownKind(t *testing.T) { + s, err := unmarshalSenderStats([]byte(`{"kind":"def-not-a-real-kind"}`)) + require.Error(t, err) + assert.Nil(t, s) + assert.ErrorIs(t, err, ErrUnknownType) +} + +func TestUnmarshalTrackStats_SyntaxErrorOnKind(t *testing.T) { + s, err := unmarshalTrackStats([]byte(`{`)) // invalid JSON + require.Error(t, err) + assert.Nil(t, s) + + var se *json.SyntaxError + assert.ErrorAs(t, err, &se) +} + +func TestUnmarshalTrackStats_Audio_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"` + string(MediaKindAudio) + `","timestamp":"oops"}`) + s, err := unmarshalTrackStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalTrackStats_Video_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"` + string(MediaKindVideo) + `","timestamp":"oops"}`) + s, err := unmarshalTrackStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalTrackStats_UnknownKind(t *testing.T) { + s, err := unmarshalTrackStats([]byte(`{"kind":"definitely-not-real"}`)) + require.Error(t, err) + assert.Nil(t, s) + assert.ErrorIs(t, err, ErrUnknownType) +} + +func TestUnmarshalReceiverStats_SyntaxErrorOnKind(t *testing.T) { + s, err := unmarshalReceiverStats([]byte(`{`)) // invalid JSON + require.Error(t, err) + assert.Nil(t, s) + + var se *json.SyntaxError + assert.ErrorAs(t, err, &se) +} + +func TestUnmarshalReceiverStats_Audio_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"` + string(MediaKindAudio) + `","timestamp":"oops"}`) + s, err := unmarshalReceiverStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalReceiverStats_Video_UnmarshalTypeError(t *testing.T) { + payload := []byte(`{"kind":"` + string(MediaKindVideo) + `","timestamp":"oops"}`) + s, err := unmarshalReceiverStats(payload) + require.Error(t, err) + assert.Nil(t, s) + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalReceiverStats_UnknownKind(t *testing.T) { + s, err := unmarshalReceiverStats([]byte(`{"kind":"not-a-real-kind"}`)) + require.Error(t, err) + assert.Nil(t, s) + assert.ErrorIs(t, err, ErrUnknownType) +} + +func TestUnmarshalTransportStats_Error(t *testing.T) { + payload := []byte(`{"timestamp":"oops"}`) + + s, err := unmarshalTransportStats(payload) + require.Error(t, err) + assert.Equal(t, TransportStats{}, s) + assert.Contains(t, err.Error(), "unmarshal transport stats:") + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestToICECandidatePairStats_InvalidState(t *testing.T) { + bogus := ice.CandidatePairState(255) + + in := ice.CandidatePairStats{ + State: bogus, + } + + out, err := toICECandidatePairStats(in) + require.Error(t, err) + assert.Equal(t, ICECandidatePairStats{}, out) + + assert.Contains(t, err.Error(), bogus.String()) +} + +func TestUnmarshalICECandidatePairStats_Error(t *testing.T) { + bad := []byte(`{"timestamp":"not-a-number"}`) + + got, err := unmarshalICECandidatePairStats(bad) + require.Error(t, err) + assert.Equal(t, ICECandidatePairStats{}, got) + + assert.Contains(t, err.Error(), "unmarshal ice candidate pair stats") + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalICECandidateStats_Error(t *testing.T) { + bad := []byte(`{"timestamp":"not-a-number"}`) + + got, err := unmarshalICECandidateStats(bad) + require.Error(t, err) + assert.Equal(t, ICECandidateStats{}, got) + + assert.Contains(t, err.Error(), "unmarshal ice candidate stats") + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalCertificateStats_Error(t *testing.T) { + bad := []byte(`{"timestamp":"not-a-number"}`) + + got, err := unmarshalCertificateStats(bad) + require.Error(t, err) + assert.Equal(t, CertificateStats{}, got) + + assert.Contains(t, err.Error(), "unmarshal certificate stats") + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestUnmarshalSCTPTransportStats_Success(t *testing.T) { + good := []byte(`{ + "timestamp": 1234, + "type": "sctp-transport", + "id": "SCTP1", + "transportId": "T01", + "smoothedRoundTripTime": 0.123, + "congestionWindow": 512, + "receiverWindow": 1024, + "mtu": 1200, + "unackData": 3, + "bytesSent": 1000, + "bytesReceived": 2000 + }`) + + got, err := unmarshalSCTPTransportStats(good) + require.NoError(t, err) + + assert.Equal(t, StatsTimestamp(1234), got.Timestamp) + assert.Equal(t, StatsTypeSCTPTransport, got.Type) + assert.Equal(t, "SCTP1", got.ID) + assert.Equal(t, "T01", got.TransportID) + assert.InDelta(t, 0.123, got.SmoothedRoundTripTime, 1e-9) + assert.Equal(t, uint32(512), got.CongestionWindow) + assert.Equal(t, uint32(1024), got.ReceiverWindow) + assert.Equal(t, uint32(1200), got.MTU) + assert.Equal(t, uint32(3), got.UNACKData) + assert.Equal(t, uint64(1000), got.BytesSent) + assert.Equal(t, uint64(2000), got.BytesReceived) +} + +func TestUnmarshalSCTPTransportStats_Error(t *testing.T) { + bad := []byte(`{"bytesReceived":"oops"}`) + + got, err := unmarshalSCTPTransportStats(bad) + require.Error(t, err) + assert.Equal(t, SCTPTransportStats{}, got) + + assert.Contains(t, err.Error(), "unmarshal sctp transport stats") + + var ute *json.UnmarshalTypeError + assert.ErrorAs(t, err, &ute) +} + +func TestStatsReport_GetConnectionStats_MissingEntry(t *testing.T) { + conn := &PeerConnection{} + conn.getStatsID() + + r := StatsReport{} + got, ok := r.GetConnectionStats(conn) + + assert.False(t, ok) + assert.Equal(t, PeerConnectionStats{}, got) +} + +func TestStatsReport_GetConnectionStats_WrongType(t *testing.T) { + conn := &PeerConnection{} + id := conn.getStatsID() + + r := StatsReport{ + id: DataChannelStats{ID: "not-a-pc-stats"}, + } + + got, ok := r.GetConnectionStats(conn) + + assert.False(t, ok) + assert.Equal(t, PeerConnectionStats{}, got) +} + +func TestStatsReport_GetConnectionStats_Success(t *testing.T) { + conn := &PeerConnection{} + id := conn.getStatsID() + + want := PeerConnectionStats{ + ID: id, + Type: StatsTypePeerConnection, + Timestamp: 1234, + } + + r := StatsReport{ + id: want, + } + + got, ok := r.GetConnectionStats(conn) + + require.True(t, ok) + assert.Equal(t, want, got) +} + +func TestStatsReport_GetDataChannelStats_MissingEntry(t *testing.T) { + dc := &DataChannel{} + dc.getStatsID() + + r := StatsReport{} // empty -> triggers first `if !ok` + got, ok := r.GetDataChannelStats(dc) + + assert.False(t, ok) + assert.Equal(t, DataChannelStats{}, got) +} + +func TestStatsReport_GetDataChannelStats_WrongType(t *testing.T) { + dc := &DataChannel{} + id := dc.getStatsID() + + // Put a different Stats type under the correct key to fail type assertion + r := StatsReport{ + id: PeerConnectionStats{ID: "not-a-dc-stats"}, + } + + got, ok := r.GetDataChannelStats(dc) + + assert.False(t, ok) // triggers second `if !ok` (type assertion fails) + assert.Equal(t, DataChannelStats{}, got) // zero value on failure +} + +func TestStatsReport_GetDataChannelStats_Success(t *testing.T) { + dc := &DataChannel{} + id := dc.getStatsID() + + want := DataChannelStats{ + ID: id, + Type: StatsTypeDataChannel, + Timestamp: 1234, + Label: "chat", + Protocol: "json", + DataChannelIdentifier: 7, + TransportID: "T1", + State: DataChannelStateOpen, + MessagesSent: 10, + BytesSent: 100, + MessagesReceived: 12, + BytesReceived: 120, + } + + r := StatsReport{ + id: want, + } + + got, ok := r.GetDataChannelStats(dc) + + require.True(t, ok) + assert.Equal(t, want, got) +} + +func TestStatsReport_GetICECandidateStats_MissingEntry(t *testing.T) { + c := &ICECandidate{statsID: "C1"} + r := StatsReport{} + + got, ok := r.GetICECandidateStats(c) + + assert.False(t, ok) + assert.Equal(t, ICECandidateStats{}, got) +} + +func TestStatsReport_GetICECandidateStats_WrongType(t *testing.T) { + c := &ICECandidate{statsID: "C2"} + + r := StatsReport{ + "C2": PeerConnectionStats{ID: "not-candidate"}, + } + + got, ok := r.GetICECandidateStats(c) + + assert.False(t, ok) + assert.Equal(t, ICECandidateStats{}, got) +} + +func TestStatsReport_GetICECandidateStats_Success(t *testing.T) { + statsID := "C3" + c := &ICECandidate{statsID: statsID} + + want := ICECandidateStats{ + ID: statsID, + Type: StatsTypeLocalCandidate, + } + + r := StatsReport{ + statsID: want, + } + + got, ok := r.GetICECandidateStats(c) + + require.True(t, ok) + assert.Equal(t, want, got) +} + +func TestStatsReport_GetICECandidatePairStats_MissingEntry(t *testing.T) { + pair := &ICECandidatePair{statsID: "CP1"} + r := StatsReport{} + + got, ok := r.GetICECandidatePairStats(pair) + + assert.False(t, ok) + assert.Equal(t, ICECandidatePairStats{}, got) +} + +func TestStatsReport_GetICECandidatePairStats_WrongType(t *testing.T) { + pair := &ICECandidatePair{statsID: "CP2"} + + r := StatsReport{ + "CP2": PeerConnectionStats{ID: "not-candidate-pair"}, + } + + got, ok := r.GetICECandidatePairStats(pair) + + assert.False(t, ok) + assert.Equal(t, ICECandidatePairStats{}, got) +} + +func TestStatsReport_GetICECandidatePairStats_Success(t *testing.T) { + statsID := "CP3" + pair := &ICECandidatePair{statsID: statsID} + + want := ICECandidatePairStats{ + ID: statsID, + Type: StatsTypeCandidatePair, + } + + r := StatsReport{ + statsID: want, + } + + got, ok := r.GetICECandidatePairStats(pair) + + require.True(t, ok) + assert.Equal(t, want, got) +} + +func TestStatsReport_GetCertificateStats_MissingEntry(t *testing.T) { + cert := &Certificate{statsID: "CERT1"} + r := StatsReport{} + + got, ok := r.GetCertificateStats(cert) + + assert.False(t, ok) + assert.Equal(t, CertificateStats{}, got) +} + +func TestStatsReport_GetCertificateStats_WrongType(t *testing.T) { + cert := &Certificate{statsID: "CERT2"} + + r := StatsReport{ + "CERT2": PeerConnectionStats{ID: "not-certificate"}, + } + + got, ok := r.GetCertificateStats(cert) + + assert.False(t, ok) + assert.Equal(t, CertificateStats{}, got) +} + +func TestStatsReport_GetCertificateStats_Success(t *testing.T) { + statsID := "CERT3" + cert := &Certificate{statsID: statsID} + + want := CertificateStats{ + ID: statsID, + Type: StatsTypeCertificate, + } + + r := StatsReport{ + statsID: want, + } + + got, ok := r.GetCertificateStats(cert) + + require.True(t, ok) + assert.Equal(t, want, got) +} + +func TestStatsReport_GetCodecStats_MissingEntry(t *testing.T) { + codec := &RTPCodecParameters{statsID: "CODEC1"} + r := StatsReport{} + + got, ok := r.GetCodecStats(codec) + + assert.False(t, ok) + assert.Equal(t, CodecStats{}, got) +} + +func TestStatsReport_GetCodecStats_WrongType(t *testing.T) { + codec := &RTPCodecParameters{statsID: "CODEC2"} + + r := StatsReport{ + "CODEC2": PeerConnectionStats{ID: "not-codec"}, + } + + got, ok := r.GetCodecStats(codec) + + assert.False(t, ok) + assert.Equal(t, CodecStats{}, got) +} + +func TestStatsReport_GetCodecStats_Success(t *testing.T) { + statsID := "CODEC3" + codec := &RTPCodecParameters{statsID: statsID} + + want := CodecStats{ + ID: statsID, + Type: StatsTypeCodec, + } + + r := StatsReport{ + statsID: want, + } + + got, ok := r.GetCodecStats(codec) + + require.True(t, ok) + assert.Equal(t, want, got) +} diff --git a/track_local_static_test.go b/track_local_static_test.go index b5c70c51..093d1770 100644 --- a/track_local_static_test.go +++ b/track_local_static_test.go @@ -13,8 +13,10 @@ import ( "testing" "time" + "github.com/pion/interceptor" "github.com/pion/rtp" "github.com/pion/transport/v3/test" + "github.com/pion/webrtc/v4/pkg/media" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -504,3 +506,357 @@ func Test_TrackLocalStatic_Timestamp(t *testing.T) { <-onTrackFired.Done() closePairNow(t, pcOffer, pcAnswer) } + +type dummyWriter struct{} + +func (dummyWriter) WriteRTP(_ *rtp.Header, _ []byte) (int, error) { return 0, nil } +func (dummyWriter) Write(_ []byte) (int, error) { return 0, nil } + +type dummyTrackLocalContext struct { + id string +} + +func (d dummyTrackLocalContext) ID() string { return d.id } +func (d dummyTrackLocalContext) SSRC() SSRC { return 0 } +func (d dummyTrackLocalContext) SSRCRetransmission() SSRC { return 0 } +func (d dummyTrackLocalContext) SSRCForwardErrorCorrection() SSRC { return 0 } +func (d dummyTrackLocalContext) WriteStream() TrackLocalWriter { return dummyWriter{} } +func (d dummyTrackLocalContext) HeaderExtensions() []RTPHeaderExtensionParameter { return nil } +func (d dummyTrackLocalContext) RTCPReader() interceptor.RTCPReader { return nil } +func (d dummyTrackLocalContext) CodecParameters() []RTPCodecParameters { + return []RTPCodecParameters{{ + RTPCodecCapability: RTPCodecCapability{ + MimeType: MimeTypeVP8, + ClockRate: 90000, + }, + PayloadType: 96, + }} +} + +func Test_TrackLocalStaticRTP_Unbind_ErrUnbindFailed(t *testing.T) { + track, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "video", + "pion", + ) + require.NoError(t, err) + + ctx := dummyTrackLocalContext{id: "nonexistent-id"} + + err = track.Unbind(ctx) + require.ErrorIs(t, err, ErrUnbindFailed) +} + +func Test_TrackLocalStaticRTP_Kind_Default(t *testing.T) { + track, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: "application/unknown"}, + "id", + "stream", + ) + require.NoError(t, err) + + require.Equal(t, RTPCodecType(0), track.Kind()) +} + +func Test_TrackLocalStaticRTP_Codec_ReturnsConfiguredCodec(t *testing.T) { + testCapability := RTPCodecCapability{ + MimeType: MimeTypeVP8, + ClockRate: 90000, + Channels: 0, + SDPFmtpLine: "profile-id=0", + RTCPFeedback: []RTCPFeedback{{Type: "nack"}, {Type: "ccm", Parameter: "fir"}}, + } + + track, err := NewTrackLocalStaticRTP(testCapability, "video", "pion") + require.NoError(t, err) + + got := track.Codec() + require.Equal(t, testCapability, got) +} + +var errWriteBoom = errors.New("fake write failure") + +type errWriter struct{} + +func (errWriter) WriteRTP(_ *rtp.Header, _ []byte) (int, error) { return 0, errWriteBoom } +func (errWriter) Write(_ []byte) (int, error) { return 0, nil } + +func Test_TrackLocalStaticRTP_writeRTP_ReturnsError(t *testing.T) { + track, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "id", + "stream", + ) + require.NoError(t, err) + + track.mu.Lock() + track.bindings = []trackBinding{{ + id: "b1", + ssrc: 0x1234, + payloadType: 96, + writeStream: errWriter{}, + }} + track.mu.Unlock() + + pkt := &rtp.Packet{Payload: []byte{0x01, 0x02, 0x03}} + + err = track.writeRTP(pkt) + require.Error(t, err) + require.Contains(t, err.Error(), errWriteBoom.Error()) +} + +func Test_TrackLocalStaticRTP_Write_UnmarshalError(t *testing.T) { + track, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "id", + "stream", + ) + require.NoError(t, err) + + n, werr := track.Write([]byte{0x80}) // < 12-byte RTP header + require.Error(t, werr) + require.Equal(t, 0, n) +} + +func Test_TrackLocalStaticSample_Codec_ReturnsConfiguredCodec(t *testing.T) { + testCapability := RTPCodecCapability{ + MimeType: MimeTypeVP8, + ClockRate: 90000, + Channels: 0, + SDPFmtpLine: "profile-id=0", + RTCPFeedback: []RTCPFeedback{{Type: "nack"}, {Type: "ccm", Parameter: "fir"}}, + } + + sample, err := NewTrackLocalStaticSample(testCapability, "video", "pion") + require.NoError(t, err) + + got := sample.Codec() + require.Equal(t, testCapability, got) +} + +var errPayloaderBoom = errors.New("payloader boom") + +func Test_TrackLocalStaticSample_Bind_PayloaderError(t *testing.T) { + sample, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: 90000}, + "video", + "pion", + ) + require.NoError(t, err) + + sample.rtpTrack.mu.Lock() + sample.rtpTrack.payloader = func(_ RTPCodecCapability) (rtp.Payloader, error) { + return nil, errPayloaderBoom + } + sample.rtpTrack.mu.Unlock() + + _, bindErr := sample.Bind(dummyTrackLocalContext{id: "ctx-1"}) + require.ErrorIs(t, bindErr, errPayloaderBoom) + + sample.rtpTrack.mu.RLock() + defer sample.rtpTrack.mu.RUnlock() + require.Nil(t, sample.packetizer) +} + +type fakePacketizer struct { + skipCalls int + lastSample uint32 + + packetizeCalls int +} + +func (f *fakePacketizer) SkipSamples(n uint32) { f.skipCalls++; f.lastSample = n } +func (f *fakePacketizer) GeneratePadding(samples uint32) []*rtp.Packet { + f.packetizeCalls++ + f.lastSample = samples + + return []*rtp.Packet{{}, {}} +} +func (f *fakePacketizer) EnableAbsSendTime(value int) {} +func (f *fakePacketizer) Packetize(_ []byte, _ uint32) []*rtp.Packet { + f.packetizeCalls++ + + return []*rtp.Packet{ + {Payload: []byte{0x01}}, + {Payload: []byte{0x02}}, + } +} + +func Test_TrackLocalStaticSample_WriteSample_AppendErrors(t *testing.T) { + testSample, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "video", + "pion", + ) + require.NoError(t, err) + + testSample.rtpTrack.mu.Lock() + testSample.rtpTrack.bindings = []trackBinding{{ + id: "b1", + ssrc: 0x1234, + payloadType: 96, + writeStream: errWriter{}, + }} + testSample.rtpTrack.mu.Unlock() + + fp := &fakePacketizer{} + testSample.rtpTrack.mu.Lock() + testSample.packetizer = fp + testSample.sequencer = rtp.NewRandomSequencer() + testSample.clockRate = 48000 + testSample.rtpTrack.mu.Unlock() + + in := media.Sample{ + Data: []byte("hi"), + Duration: 20 * time.Millisecond, + PrevDroppedPackets: 3, + } + + err = testSample.WriteSample(in) + + require.Error(t, err) + require.Contains(t, err.Error(), errWriteBoom.Error()) + + require.Equal(t, 1, fp.skipCalls) + require.Equal(t, uint32(960*3), fp.lastSample) + + require.Equal(t, 1, fp.packetizeCalls) +} + +func Test_TrackLocalStaticSample_GeneratePadding_PacketizerNil_ReturnsNil(t *testing.T) { + s, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "video", + "pion", + ) + require.NoError(t, err) + + err = s.GeneratePadding(10) + require.NoError(t, err) +} + +func Test_TrackLocalStaticSample_GeneratePadding_AppendsAndReturnsError(t *testing.T) { + testSample, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8}, + "video", + "pion", + ) + require.NoError(t, err) + + testSample.rtpTrack.mu.Lock() + testSample.rtpTrack.bindings = []trackBinding{{ + id: "b1", + ssrc: 0x1234, + payloadType: 96, + writeStream: errWriter{}, + }} + + fp := &fakePacketizer{} + testSample.packetizer = fp + testSample.rtpTrack.mu.Unlock() + + err = testSample.GeneratePadding(7) + require.Error(t, err) + require.Contains(t, err.Error(), errWriteBoom.Error()) + + require.Equal(t, 1, fp.packetizeCalls) + require.Equal(t, uint32(7), fp.lastSample) +} + +func Test_TrackRemote_Msid(t *testing.T) { + t.Run("Populated", func(t *testing.T) { + tr := newTrackRemote(RTPCodecTypeVideo, 1234, 0, "", nil) + + tr.mu.Lock() + tr.id = "video" + tr.streamID = "desktop" + tr.mu.Unlock() + + require.Equal(t, "desktop video", tr.Msid()) + }) + + t.Run("Empty", func(t *testing.T) { + tr := newTrackRemote(RTPCodecTypeAudio, 0, 0, "", nil) + require.Equal(t, " ", tr.Msid()) + }) +} + +func Test_TrackRemote_checkAndUpdateTrack_ShortPacket(t *testing.T) { + tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", &RTPReceiver{ + api: &API{mediaEngine: &MediaEngine{}}, + kind: RTPCodecTypeVideo, + }) + + err := tr.checkAndUpdateTrack([]byte{0x80}) + require.ErrorIs(t, err, errRTPTooShort) +} + +func Test_TrackRemote_checkAndUpdateTrack_CodecNotFound(t *testing.T) { + me := &MediaEngine{} // intentionally empty: no codecs registered. + api := &API{mediaEngine: me} + recv := &RTPReceiver{api: api, kind: RTPCodecTypeVideo} + tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv) + + // minimal RTP header-sized buffer with a payload type byte. + b := []byte{0x80, 96} + + err := tr.checkAndUpdateTrack(b) + require.ErrorIs(t, err, ErrCodecNotFound) +} + +func Test_TrackRemote_ReadRTP_UnmarshalError(t *testing.T) { + me := &MediaEngine{} + require.NoError(t, me.RegisterCodec(RTPCodecParameters{ + RTPCodecCapability: RTPCodecCapability{ + MimeType: MimeTypeVP8, + ClockRate: 90000, + }, + PayloadType: 96, + }, RTPCodecTypeVideo)) + + api := &API{ + mediaEngine: me, + settingEngine: &SettingEngine{}, + } + + recv := &RTPReceiver{ + api: api, + kind: RTPCodecTypeVideo, + } + + tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv) + + tr.mu.Lock() + tr.peeked = []byte{0x80, 96} + tr.peekedAttributes = nil + tr.mu.Unlock() + + pkt, attrs, err := tr.ReadRTP() + require.Error(t, err, "expected Unmarshal to fail on too-short RTP data") + require.Nil(t, pkt) + require.Nil(t, attrs) +} + +func TestBaseTrackLocalContext_HeaderExtensions_ReturnsParams(t *testing.T) { + hdrs := []RTPHeaderExtensionParameter{ + {URI: "urn:ietf:params:rtp-hdrext:sdes:mid", ID: 1}, + {URI: "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id", ID: 2}, + } + + ctx := baseTrackLocalContext{ + params: RTPParameters{ + HeaderExtensions: hdrs, + }, + } + + got := ctx.HeaderExtensions() + require.Equal(t, hdrs, got) + + got[0].URI = "changed" + assert.Equal(t, "changed", ctx.params.HeaderExtensions[0].URI) +} + +func TestBaseTrackLocalContext_HeaderExtensions_NilWhenUnset(t *testing.T) { + var ctx baseTrackLocalContext + assert.Nil(t, ctx.HeaderExtensions()) +}