From e6e7f11ee1255ad5c346e77b55d6b154f4abd3c1 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 9 Dec 2022 12:31:18 +0100 Subject: [PATCH] improve coverage --- client_publish_test.go | 178 +++++++++++++++++++++++++++++++++++++++++ client_read_test.go | 74 ++++++++++------- server_publish_test.go | 56 +++++++------ server_read_test.go | 151 ++++++++++++++++++++++++++++++++++ 4 files changed, 405 insertions(+), 54 deletions(-) diff --git a/client_publish_test.go b/client_publish_test.go index 80faae6d..70941713 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bytes" "crypto/tls" "net" "strings" @@ -838,6 +839,183 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) } +func TestClientPublishDecodeErrors(t *testing.T) { + for _, ca := range []struct { + proto string + name string + }{ + {"udp", "rtcp invalid"}, + {"udp", "rtcp too big"}, + {"tcp", "rtcp too big"}, + } { + t.Run(ca.proto+" "+ca.name, func(t *testing.T) { + errorRecv := make(chan struct{}) + + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err := conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Announce), + string(base.Setup), + string(base.Record), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Announce, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) + + th := headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + } + + if ca.proto == "udp" { + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{34556, 34557} + } else { + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = inTH.InterleavedIDs + } + + var l1 net.PacketConn + var l2 net.PacketConn + + if ca.proto == "udp" { + l1, err = net.ListenPacket("udp", "127.0.0.1:34556") + require.NoError(t, err) + defer l1.Close() + + l2, err = net.ListenPacket("udp", "127.0.0.1:34557") + require.NoError(t, err) + defer l2.Close() + } + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Record, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + switch { + case ca.proto == "udp" && ca.name == "rtcp invalid": + l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[1], + }) + + case ca.proto == "udp" && ca.name == "rtcp too big": + l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[1], + }) + + case ca.proto == "tcp" && ca.name == "rtcp too big": + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), + }, make([]byte, 2048)) + require.NoError(t, err) + } + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + }() + + c := Client{ + Transport: func() *Transport { + if ca.proto == "udp" { + v := TransportUDP + return &v + } + v := TransportTCP + return &v + }(), + OnDecodeError: func(err error) { + switch { + case ca.proto == "udp" && ca.name == "rtcp invalid": + require.EqualError(t, err, "rtcp: packet too short") + + case ca.proto == "udp" && ca.name == "rtcp too big": + require.EqualError(t, err, "RTCP packet is too big to be read with UDP") + + case ca.proto == "tcp" && ca.name == "rtcp too big": + require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)") + } + close(errorRecv) + }, + } + + track := &TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + PacketizationMode: 1, + } + + err = c.StartPublishing("rtsp://localhost:8554/stream", + Tracks{track}) + require.NoError(t, err) + defer c.Close() + + <-errorRecv + }) + } +} + func TestClientPublishRTCPReport(t *testing.T) { for _, ca := range []string{"udp", "tcp"} { t.Run(ca, func(t *testing.T) { diff --git a/client_read_test.go b/client_read_test.go index a994d728..b2cfa747 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -2730,15 +2730,19 @@ func TestClientReadDifferentSource(t *testing.T) { } func TestClientReadDecodeErrors(t *testing.T) { - for _, ca := range []string{ - "rtp invalid", - "rtcp invalid", - "rtp packets lost", - "rtp too big", - "rtcp too big", - "rtcp too big tcp", + for _, ca := range []struct { + proto string + name string + }{ + {"udp", "rtp invalid"}, + {"udp", "rtcp invalid"}, + {"udp", "rtp packets lost"}, + {"udp", "rtp too big"}, + {"udp", "rtcp too big"}, + {"tcp", "rtcp invalid"}, + {"tcp", "rtcp too big"}, } { - t.Run(ca, func(t *testing.T) { + t.Run(ca.proto+" "+ca.name, func(t *testing.T) { errorRecv := make(chan struct{}) l, err := net.Listen("tcp", "localhost:8554") @@ -2758,7 +2762,6 @@ func TestClientReadDecodeErrors(t *testing.T) { req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL) err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -2775,7 +2778,6 @@ func TestClientReadDecodeErrors(t *testing.T) { req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL) tracks := Tracks{&TrackGeneric{ Media: "application", @@ -2799,7 +2801,6 @@ func TestClientReadDecodeErrors(t *testing.T) { req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/trackID=0"), req.URL) var inTH headers.Transport err = inTH.Unmarshal(req.Header["Transport"]) @@ -2812,7 +2813,7 @@ func TestClientReadDecodeErrors(t *testing.T) { }(), } - if ca != "rtcp too big tcp" { + if ca.proto == "udp" { th.Protocol = headers.TransportProtocolUDP th.ClientPorts = inTH.ClientPorts th.ServerPorts = &[2]int{34556, 34557} @@ -2824,7 +2825,7 @@ func TestClientReadDecodeErrors(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - if ca != "rtcp too big tcp" { + if ca.proto == "udp" { l1, err = net.ListenPacket("udp", "127.0.0.1:34556") require.NoError(t, err) defer l1.Close() @@ -2845,27 +2846,26 @@ func TestClientReadDecodeErrors(t *testing.T) { req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL) err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, }) require.NoError(t, err) - switch ca { //nolint:dupl - case "rtp invalid": + switch { //nolint:dupl + case ca.proto == "udp" && ca.name == "rtp invalid": l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) - case "rtcp invalid": + case ca.proto == "udp" && ca.name == "rtcp invalid": l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[1], }) - case "rtp packets lost": + case ca.proto == "udp" && ca.name == "rtp packets lost": byts, _ := rtp.Packet{ Header: rtp.Header{ SequenceNumber: 30, @@ -2886,19 +2886,26 @@ func TestClientReadDecodeErrors(t *testing.T) { Port: th.ClientPorts[0], }) - case "rtp too big": + case ca.proto == "udp" && ca.name == "rtp too big": l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) - case "rtcp too big": + case ca.proto == "udp" && ca.name == "rtcp too big": l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[1], }) - case "rtcp too big tcp": + case ca.proto == "tcp" && ca.name == "rtcp invalid": + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: []byte{0x01, 0x02}, + }, make([]byte, 2048)) + require.NoError(t, err) + + case ca.proto == "tcp" && ca.name == "rtcp too big": err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), @@ -2909,7 +2916,6 @@ func TestClientReadDecodeErrors(t *testing.T) { req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL) err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -2919,7 +2925,7 @@ func TestClientReadDecodeErrors(t *testing.T) { c := Client{ Transport: func() *Transport { - if ca != "rtcp too big tcp" { + if ca.proto == "udp" { v := TransportUDP return &v } @@ -2927,18 +2933,26 @@ func TestClientReadDecodeErrors(t *testing.T) { return &v }(), OnDecodeError: func(err error) { - switch ca { - case "rtp invalid": + switch { + case ca.proto == "udp" && ca.name == "rtp invalid": require.EqualError(t, err, "RTP header size insufficient: 2 < 4") - case "rtcp invalid": + + case ca.proto == "udp" && ca.name == "rtcp invalid": require.EqualError(t, err, "rtcp: packet too short") - case "rtp packets lost": + + case ca.proto == "udp" && ca.name == "rtp packets lost": require.EqualError(t, err, "69 RTP packet(s) lost") - case "rtp too big": + + case ca.proto == "udp" && ca.name == "rtp too big": require.EqualError(t, err, "RTP packet is too big to be read with UDP") - case "rtcp too big": + + case ca.proto == "udp" && ca.name == "rtcp too big": require.EqualError(t, err, "RTCP packet is too big to be read with UDP") - case "rtcp too big tcp": + + case ca.proto == "tcp" && ca.name == "rtcp invalid": + require.EqualError(t, err, "rtcp: packet too short") + + case ca.proto == "tcp" && ca.name == "rtcp too big": require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)") } close(errorRecv) diff --git a/server_publish_test.go b/server_publish_test.go index dac348f1..8da1ddb1 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -1485,15 +1485,18 @@ func TestServerPublishUDPChangeConn(t *testing.T) { } func TestServerPublishDecodeErrors(t *testing.T) { - for _, ca := range []string{ - "rtp invalid", - "rtcp invalid", - "rtp packets lost", - "rtp too big", - "rtcp too big", - "rtcp too big tcp", + for _, ca := range []struct { + proto string + name string + }{ + {"udp", "rtp invalid"}, + {"udp", "rtcp invalid"}, + {"udp", "rtp packets lost"}, + {"udp", "rtp too big"}, + {"udp", "rtcp too big"}, + {"tcp", "rtcp too big"}, } { - t.Run(ca, func(t *testing.T) { + t.Run(ca.proto+" "+ca.name, func(t *testing.T) { errorRecv := make(chan struct{}) s := &Server{ @@ -1514,18 +1517,23 @@ func TestServerPublishDecodeErrors(t *testing.T) { }, nil }, onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { - switch ca { - case "rtp invalid": + switch { + case ca.proto == "udp" && ca.name == "rtp invalid": require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") - case "rtcp invalid": + + case ca.proto == "udp" && ca.name == "rtcp invalid": require.EqualError(t, ctx.Error, "rtcp: packet too short") - case "rtp packets lost": + + case ca.proto == "udp" && ca.name == "rtp packets lost": require.EqualError(t, ctx.Error, "69 RTP packet(s) lost") - case "rtp too big": + + case ca.proto == "udp" && ca.name == "rtp too big": require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") - case "rtcp too big": + + case ca.proto == "udp" && ca.name == "rtcp too big": require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") - case "rtcp too big tcp": + + case ca.proto == "tcp" && ca.name == "rtcp too big": require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)") } close(errorRecv) @@ -1577,7 +1585,7 @@ func TestServerPublishDecodeErrors(t *testing.T) { }(), } - if ca != "rtcp too big tcp" { + if ca.proto == "udp" { inTH.Protocol = headers.TransportProtocolUDP inTH.ClientPorts = &[2]int{35466, 35467} } else { @@ -1588,7 +1596,7 @@ func TestServerPublishDecodeErrors(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - if ca != "rtcp too big tcp" { + if ca.proto == "udp" { l1, err = net.ListenPacket("udp", "127.0.0.1:35466") require.NoError(t, err) defer l1.Close() @@ -1628,20 +1636,20 @@ func TestServerPublishDecodeErrors(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - switch ca { //nolint:dupl - case "rtp invalid": + switch { //nolint:dupl + case ca.proto == "udp" && ca.name == "rtp invalid": l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[0], }) - case "rtcp invalid": + case ca.proto == "udp" && ca.name == "rtcp invalid": l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[1], }) - case "rtp packets lost": + case ca.proto == "udp" && ca.name == "rtp packets lost": byts, _ := rtp.Packet{ Header: rtp.Header{ SequenceNumber: 30, @@ -1662,19 +1670,19 @@ func TestServerPublishDecodeErrors(t *testing.T) { Port: resTH.ServerPorts[0], }) - case "rtp too big": + case ca.proto == "udp" && ca.name == "rtp too big": l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[0], }) - case "rtcp too big": + case ca.proto == "udp" && ca.name == "rtcp too big": l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[1], }) - case "rtcp too big tcp": + case ca.proto == "tcp" && ca.name == "rtcp too big": err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), diff --git a/server_read_test.go b/server_read_test.go index dfce96bd..624dee2c 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bytes" "crypto/tls" "net" "strconv" @@ -716,6 +717,156 @@ func TestServerRead(t *testing.T) { } } +func TestServerReadDecodeErrors(t *testing.T) { + for _, ca := range []struct { + proto string + name string + }{ + {"udp", "rtcp invalid"}, + {"udp", "rtcp too big"}, + {"tcp", "rtcp too big"}, + } { + t.Run(ca.proto+" "+ca.name, func(t *testing.T) { + errorRecv := make(chan struct{}) + + track := &TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + PacketizationMode: 1, + } + + stream := NewServerStream(Tracks{track}) + defer stream.Close() + + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { + switch { + case ca.proto == "udp" && ca.name == "rtcp invalid": + require.EqualError(t, ctx.Error, "rtcp: packet too short") + + case ca.proto == "udp" && ca.name == "rtcp too big": + require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") + + case ca.proto == "tcp" && ca.name == "rtcp too big": + require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)") + } + close(errorRecv) + }, + }, + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + inTH := &headers.Transport{ + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + } + + if ca.proto == "udp" { + inTH.Protocol = headers.TransportProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} + } else { + inTH.Protocol = headers.TransportProtocolTCP + inTH.InterleavedIDs = &[2]int{0, 1} + } + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var resTH headers.Transport + err = resTH.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + var l1 net.PacketConn + var l2 net.PacketConn + + if ca.proto == "udp" { + l1, err = net.ListenPacket("udp", "127.0.0.1:35466") + require.NoError(t, err) + defer l1.Close() + + l2, err = net.ListenPacket("udp", "127.0.0.1:35467") + require.NoError(t, err) + defer l2.Close() + } + + var sx headers.Session + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Play, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": base.HeaderValue{sx.Session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + switch { //nolint:dupl + case ca.proto == "udp" && ca.name == "rtcp invalid": + l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[1], + }) + + case ca.proto == "udp" && ca.name == "rtcp too big": + l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[1], + }) + + case ca.proto == "tcp" && ca.name == "rtcp too big": + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), + }, make([]byte, 2048)) + require.NoError(t, err) + } + + <-errorRecv + }) + } +} + func TestServerReadRTCPReport(t *testing.T) { for _, ca := range []string{"udp", "tcp"} { t.Run(ca, func(t *testing.T) {