From b0947c133e5edbb6e0a5fb13d05f3d0c4efcce15 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Wed, 17 May 2023 21:14:00 +0200 Subject: [PATCH] move each goroutine in a dedicated struct (#285) --- client_play_test.go | 5 +- client_udpl.go => client_udp_listener.go | 19 +- pkg/formats/rtpav1/encoder_test.go | 25 +- pkg/formats/rtph264/encoder_test.go | 25 +- pkg/formats/rtph265/encoder_test.go | 25 +- pkg/formats/rtplpcm/encoder_test.go | 31 +-- pkg/formats/rtpmjpeg/encoder_test.go | 23 +- pkg/formats/rtpmpeg2audio/encoder_test.go | 25 +- pkg/formats/rtpmpeg4audio/encoder_test.go | 33 ++- pkg/formats/rtpmpeg4video/encoder_test.go | 27 +- pkg/formats/rtpsimpleaudio/encoder_test.go | 27 +- pkg/formats/rtpvp8/encoder_test.go | 25 +- pkg/formats/rtpvp9/encoder_test.go | 30 +- pkg/headers/rtpinfo_test.go | 122 +++------ pkg/headers/session_test.go | 10 +- pkg/headers/transport_test.go | 30 +- server.go | 304 ++++++++++++--------- server_conn.go | 212 ++++---------- server_conn_reader.go | 135 +++++++++ server_multicast_writer.go | 10 +- server_play_test.go | 48 ++-- server_session.go | 78 +++--- server_tcp_listener.go | 47 ++++ server_udpl.go => server_udp_listener.go | 0 24 files changed, 675 insertions(+), 641 deletions(-) rename client_udpl.go => client_udp_listener.go (95%) create mode 100644 server_conn_reader.go create mode 100644 server_tcp_listener.go rename server_udpl.go => server_udp_listener.go (100%) diff --git a/client_play_test.go b/client_play_test.go index 83b96fa0..7a923931 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -2693,10 +2693,7 @@ func TestClientPlayKeepaliveFromSession(t *testing.T) { }.Marshal(), "Session": headers.Session{ Session: "ABCDE", - Timeout: func() *uint { - v := uint(1) - return &v - }(), + Timeout: uintPtr(1), }.Marshal(), }, }) diff --git a/client_udpl.go b/client_udp_listener.go similarity index 95% rename from client_udpl.go rename to client_udp_listener.go index 53bd96d7..98d48b1b 100644 --- a/client_udpl.go +++ b/client_udp_listener.go @@ -11,6 +11,10 @@ import ( "golang.org/x/net/ipv4" ) +func int64Ptr(v int64) *int64 { + return &v +} + func randInRange(max int) int { b := big.NewInt(int64(max + 1)) n, _ := rand.Int(rand.Reader, b) @@ -129,15 +133,12 @@ func newClientUDPListener( } return &clientUDPListener{ - anyPortEnable: anyPortEnable, - writeTimeout: writeTimeout, - pc: pc, - cm: cm, - isRTP: isRTP, - lastPacketTime: func() *int64 { - v := int64(0) - return &v - }(), + anyPortEnable: anyPortEnable, + writeTimeout: writeTimeout, + pc: pc, + cm: cm, + isRTP: isRTP, + lastPacketTime: int64Ptr(0), }, nil } diff --git a/pkg/formats/rtpav1/encoder_test.go b/pkg/formats/rtpav1/encoder_test.go index 54f1cc42..4fa4db57 100644 --- a/pkg/formats/rtpav1/encoder_test.go +++ b/pkg/formats/rtpav1/encoder_test.go @@ -7,6 +7,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var shortOBU = []byte{ 0x0a, 0x0e, 0x00, 0x00, 0x00, 0x4a, 0xab, 0xbf, 0xc3, 0x77, 0x6b, 0xe4, 0x40, 0x40, 0x40, 0x41, @@ -977,19 +985,10 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), } e.Init() diff --git a/pkg/formats/rtph264/encoder_test.go b/pkg/formats/rtph264/encoder_test.go index 3f3ce3b2..6b0770dd 100644 --- a/pkg/formats/rtph264/encoder_test.go +++ b/pkg/formats/rtph264/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func mergeBytes(vals ...[]byte) []byte { size := 0 for _, v := range vals { @@ -274,19 +282,10 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), } e.Init() diff --git a/pkg/formats/rtph265/encoder_test.go b/pkg/formats/rtph265/encoder_test.go index cf9561be..266c58e8 100644 --- a/pkg/formats/rtph265/encoder_test.go +++ b/pkg/formats/rtph265/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func mergeBytes(vals ...[]byte) []byte { size := 0 for _, v := range vals { @@ -127,19 +135,10 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), } e.Init() diff --git a/pkg/formats/rtplpcm/encoder_test.go b/pkg/formats/rtplpcm/encoder_test.go index a2d6a83b..4afe7d5c 100644 --- a/pkg/formats/rtplpcm/encoder_test.go +++ b/pkg/formats/rtplpcm/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var cases = []struct { name string samples []byte @@ -64,22 +72,13 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - BitDepth: 24, - SampleRate: 48000, - ChannelCount: 2, + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), + BitDepth: 24, + SampleRate: 48000, + ChannelCount: 2, } e.Init() diff --git a/pkg/formats/rtpmjpeg/encoder_test.go b/pkg/formats/rtpmjpeg/encoder_test.go index 2af524db..c1a2e2ee 100644 --- a/pkg/formats/rtpmjpeg/encoder_test.go +++ b/pkg/formats/rtpmjpeg/encoder_test.go @@ -7,6 +7,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var cases = []struct { name string image []byte @@ -509,18 +517,9 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(2289528607) - return &v - }(), + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(2289528607), } e.Init() diff --git a/pkg/formats/rtpmpeg2audio/encoder_test.go b/pkg/formats/rtpmpeg2audio/encoder_test.go index 1fbde0fb..e8e1c6f5 100644 --- a/pkg/formats/rtpmpeg2audio/encoder_test.go +++ b/pkg/formats/rtpmpeg2audio/encoder_test.go @@ -7,6 +7,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var cases = []struct { name string frames [][]byte @@ -464,19 +472,10 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - PayloadMaxSize: 400, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), + PayloadMaxSize: 400, } e.Init() diff --git a/pkg/formats/rtpmpeg4audio/encoder_test.go b/pkg/formats/rtpmpeg4audio/encoder_test.go index f89ff2ce..3867a279 100644 --- a/pkg/formats/rtpmpeg4audio/encoder_test.go +++ b/pkg/formats/rtpmpeg4audio/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func mergeBytes(vals ...[]byte) []byte { size := 0 for _, v := range vals { @@ -474,23 +482,14 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SampleRate: 48000, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - SizeLength: ca.sizeLength, - IndexLength: ca.indexLength, - IndexDeltaLength: ca.indexDeltaLength, + PayloadType: 96, + SampleRate: 48000, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), + SizeLength: ca.sizeLength, + IndexLength: ca.indexLength, + IndexDeltaLength: ca.indexDeltaLength, } e.Init() diff --git a/pkg/formats/rtpmpeg4video/encoder_test.go b/pkg/formats/rtpmpeg4video/encoder_test.go index e86554de..f66519cb 100644 --- a/pkg/formats/rtpmpeg4video/encoder_test.go +++ b/pkg/formats/rtpmpeg4video/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var cases = []struct { name string frame []byte @@ -66,20 +74,11 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - PayloadMaxSize: 100, + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), + PayloadMaxSize: 100, } e.Init() diff --git a/pkg/formats/rtpsimpleaudio/encoder_test.go b/pkg/formats/rtpsimpleaudio/encoder_test.go index 6b19843d..09756394 100644 --- a/pkg/formats/rtpsimpleaudio/encoder_test.go +++ b/pkg/formats/rtpsimpleaudio/encoder_test.go @@ -7,6 +7,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + var cases = []struct { name string frame []byte @@ -33,20 +41,11 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 0, - SampleRate: 8000, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), + PayloadType: 0, + SampleRate: 8000, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), } e.Init() diff --git a/pkg/formats/rtpvp8/encoder_test.go b/pkg/formats/rtpvp8/encoder_test.go index f99d0071..d5674e35 100644 --- a/pkg/formats/rtpvp8/encoder_test.go +++ b/pkg/formats/rtpvp8/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func mergeBytes(vals ...[]byte) []byte { size := 0 for _, v := range vals { @@ -91,19 +99,10 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), } e.Init() diff --git a/pkg/formats/rtpvp9/encoder_test.go b/pkg/formats/rtpvp9/encoder_test.go index 473ad2cb..749ddab2 100644 --- a/pkg/formats/rtpvp9/encoder_test.go +++ b/pkg/formats/rtpvp9/encoder_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func mergeBytes(vals ...[]byte) []byte { size := 0 for _, v := range vals { @@ -92,23 +100,11 @@ func TestEncode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - InitialPictureID: func() *uint16 { - v := uint16(0x35af) - return &v - }(), + PayloadType: 96, + SSRC: uint32Ptr(0x9dbb7812), + InitialSequenceNumber: uint16Ptr(0x44ed), + InitialTimestamp: uint32Ptr(0x88776655), + InitialPictureID: uint16Ptr(0x35af), } e.Init() diff --git a/pkg/headers/rtpinfo_test.go b/pkg/headers/rtpinfo_test.go index db8125db..fddf759d 100644 --- a/pkg/headers/rtpinfo_test.go +++ b/pkg/headers/rtpinfo_test.go @@ -8,6 +8,18 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/base" ) +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + +func uintPtr(v uint) *uint { + return &v +} + var casesRTPInfo = []struct { name string vin base.HeaderValue @@ -20,15 +32,9 @@ var casesRTPInfo = []struct { base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243;rtptime=717574556`}, RTPInfo{ { - URL: "rtsp://127.0.0.1/test.mkv/track1", - SequenceNumber: func() *uint16 { - v := uint16(35243) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(717574556) - return &v - }(), + URL: "rtsp://127.0.0.1/test.mkv/track1", + SequenceNumber: uint16Ptr(35243), + Timestamp: uint32Ptr(717574556), }, }, }, @@ -40,26 +46,14 @@ var casesRTPInfo = []struct { `url=rtsp://127.0.0.1/test.mkv/track2;seq=13655;rtptime=2848846950`}, RTPInfo{ { - URL: "rtsp://127.0.0.1/test.mkv/track1", - SequenceNumber: func() *uint16 { - v := uint16(35243) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(717574556) - return &v - }(), + URL: "rtsp://127.0.0.1/test.mkv/track1", + SequenceNumber: uint16Ptr(35243), + Timestamp: uint32Ptr(717574556), }, { - URL: "rtsp://127.0.0.1/test.mkv/track2", - SequenceNumber: func() *uint16 { - v := uint16(13655) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(2848846950) - return &v - }(), + URL: "rtsp://127.0.0.1/test.mkv/track2", + SequenceNumber: uint16Ptr(13655), + Timestamp: uint32Ptr(2848846950), }, }, }, @@ -69,11 +63,8 @@ var casesRTPInfo = []struct { base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;seq=35243`}, RTPInfo{ { - URL: "rtsp://127.0.0.1/test.mkv/track1", - SequenceNumber: func() *uint16 { - v := uint16(35243) - return &v - }(), + URL: "rtsp://127.0.0.1/test.mkv/track1", + SequenceNumber: uint16Ptr(35243), }, }, }, @@ -83,11 +74,8 @@ var casesRTPInfo = []struct { base.HeaderValue{`url=rtsp://127.0.0.1/test.mkv/track1;rtptime=717574556`}, RTPInfo{ { - URL: "rtsp://127.0.0.1/test.mkv/track1", - Timestamp: func() *uint32 { - v := uint32(717574556) - return &v - }(), + URL: "rtsp://127.0.0.1/test.mkv/track1", + Timestamp: uint32Ptr(717574556), }, }, }, @@ -97,15 +85,9 @@ var casesRTPInfo = []struct { base.HeaderValue{`url=trackID=0;seq=12447;rtptime=12447`}, RTPInfo{ { - URL: "trackID=0", - SequenceNumber: func() *uint16 { - v := uint16(12447) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(12447) - return &v - }(), + URL: "trackID=0", + SequenceNumber: uint16Ptr(12447), + Timestamp: uint32Ptr(12447), }, }, }, @@ -117,26 +99,14 @@ var casesRTPInfo = []struct { `seq=58477;rtptime=1020884293,url=rtsp://10.13.146.53/axis-media/media.amp/trackID=2;seq=15727;rtptime=1171661503`}, RTPInfo{ { - URL: "rtsp://10.13.146.53/axis-media/media.amp/trackID=1", - SequenceNumber: func() *uint16 { - v := uint16(58477) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(1020884293) - return &v - }(), + URL: "rtsp://10.13.146.53/axis-media/media.amp/trackID=1", + SequenceNumber: uint16Ptr(58477), + Timestamp: uint32Ptr(1020884293), }, { - URL: "rtsp://10.13.146.53/axis-media/media.amp/trackID=2", - SequenceNumber: func() *uint16 { - v := uint16(15727) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(1171661503) - return &v - }(), + URL: "rtsp://10.13.146.53/axis-media/media.amp/trackID=2", + SequenceNumber: uint16Ptr(15727), + Timestamp: uint32Ptr(1171661503), }, }, }, @@ -148,26 +118,14 @@ var casesRTPInfo = []struct { `url=trackID=2;seq=43807;rtptime=1702259566`}, RTPInfo{ { - URL: "trackID=1", - SequenceNumber: func() *uint16 { - v := uint16(55664) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(254718369) - return &v - }(), + URL: "trackID=1", + SequenceNumber: uint16Ptr(55664), + Timestamp: uint32Ptr(254718369), }, { - URL: "trackID=2", - SequenceNumber: func() *uint16 { - v := uint16(43807) - return &v - }(), - Timestamp: func() *uint32 { - v := uint32(1702259566) - return &v - }(), + URL: "trackID=2", + SequenceNumber: uint16Ptr(43807), + Timestamp: uint32Ptr(1702259566), }, }, }, diff --git a/pkg/headers/session_test.go b/pkg/headers/session_test.go index b670d938..0c4f2643 100644 --- a/pkg/headers/session_test.go +++ b/pkg/headers/session_test.go @@ -28,10 +28,7 @@ var casesSession = []struct { base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, Session{ Session: "A3eqwsafq3rFASqew", - Timeout: func() *uint { - v := uint(47) - return &v - }(), + Timeout: uintPtr(47), }, }, { @@ -40,10 +37,7 @@ var casesSession = []struct { base.HeaderValue{`A3eqwsafq3rFASqew;timeout=47`}, Session{ Session: "A3eqwsafq3rFASqew", - Timeout: func() *uint { - v := uint(47) - return &v - }(), + Timeout: uintPtr(47), }, }, } diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index b9e78d11..42f44c3d 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -60,10 +60,7 @@ var casesTransport = []struct { v := net.ParseIP("225.219.201.15") return &v }(), - TTL: func() *uint { - v := uint(127) - return &v - }(), + TTL: uintPtr(127), Ports: &[2]int{7000, 7001}, }, }, @@ -92,10 +89,7 @@ var casesTransport = []struct { }(), ClientPorts: &[2]int{14186, 14187}, ServerPorts: &[2]int{8052, 8053}, - SSRC: func() *uint32 { - v := uint32(0x0B6020AD) - return &v - }(), + SSRC: uint32Ptr(0x0B6020AD), }, }, { @@ -153,10 +147,7 @@ var casesTransport = []struct { }(), ClientPorts: &[2]int{14186, 14187}, ServerPorts: &[2]int{8052, 8053}, - SSRC: func() *uint32 { - v := uint32(0x04317f) - return &v - }(), + SSRC: uint32Ptr(0x04317f), }, }, { @@ -175,10 +166,7 @@ var casesTransport = []struct { }(), ClientPorts: &[2]int{14186, 14187}, ServerPorts: &[2]int{8052, 8053}, - SSRC: func() *uint32 { - v := uint32(0x04317f) - return &v - }(), + SSRC: uint32Ptr(0x04317f), }, }, { @@ -192,10 +180,7 @@ var casesTransport = []struct { return &v }(), InterleavedIDs: &[2]int{0, 1}, - SSRC: func() *uint32 { - v := uint32(0xD93FF) - return &v - }(), + SSRC: uint32Ptr(0xD93FF), }, }, { @@ -208,10 +193,7 @@ var casesTransport = []struct { v := TransportDeliveryUnicast return &v }(), - SSRC: func() *uint32 { - v := uint32(0x45dcb578) - return &v - }(), + SSRC: uint32Ptr(0x45dcb578), ClientPorts: &[2]int{32560, 32561}, ServerPorts: &[2]int{3046, 3047}, }, diff --git a/server.go b/server.go index 0dd14516..f2c185e3 100644 --- a/server.go +++ b/server.go @@ -41,7 +41,7 @@ type sessionRequestReq struct { res chan sessionRequestRes } -type streamMulticastIPReq struct { +type chGetMulticastIPReq struct { res chan net.IP } @@ -124,7 +124,7 @@ type Server struct { wg sync.WaitGroup multicastNet *net.IPNet multicastNextIP net.IP - tcpListener net.Listener + tcpListener *serverTCPListener udpRTPListener *serverUDPListener udpRTCPListener *serverUDPListener sessions map[string]*ServerSession @@ -132,10 +132,12 @@ type Server struct { closeError error // in - connClose chan *ServerConn - sessionRequest chan sessionRequestReq - sessionClose chan *ServerSession - streamMulticastIP chan streamMulticastIPReq + chNewConn chan net.Conn + chAcceptErr chan error + chCloseConn chan *ServerConn + chHandleRequest chan sessionRequestReq + chCloseSession chan *ServerSession + chGetMulticastIP chan chGetMulticastIPReq } // Start starts the server. @@ -287,8 +289,19 @@ func (s *Server) Start() error { s.multicastNextIP = s.multicastNet.IP } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + + s.sessions = make(map[string]*ServerSession) + s.conns = make(map[*ServerConn]struct{}) + s.chNewConn = make(chan net.Conn) + s.chAcceptErr = make(chan error) + s.chCloseConn = make(chan *ServerConn) + s.chHandleRequest = make(chan sessionRequestReq) + s.chCloseSession = make(chan *ServerSession) + s.chGetMulticastIP = make(chan chGetMulticastIPReq) + var err error - s.tcpListener, err = s.Listen(restrictNetwork("tcp", s.RTSPAddress)) + s.tcpListener, err = newServerTCPListener(s) if err != nil { if s.udpRTPListener != nil { s.udpRTPListener.close() @@ -296,11 +309,10 @@ func (s *Server) Start() error { if s.udpRTCPListener != nil { s.udpRTCPListener.close() } + s.ctxCancel() return err } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) - s.wg.Add(1) go s.run() @@ -324,131 +336,7 @@ func (s *Server) Wait() error { func (s *Server) run() { defer s.wg.Done() - s.sessions = make(map[string]*ServerSession) - s.conns = make(map[*ServerConn]struct{}) - s.connClose = make(chan *ServerConn) - s.sessionRequest = make(chan sessionRequestReq) - s.sessionClose = make(chan *ServerSession) - s.streamMulticastIP = make(chan streamMulticastIPReq) - - s.wg.Add(1) - connNew := make(chan net.Conn) - acceptErr := make(chan error) - go func() { - defer s.wg.Done() - err := func() error { - for { - nconn, err := s.tcpListener.Accept() - if err != nil { - return err - } - - select { - case connNew <- nconn: - case <-s.ctx.Done(): - nconn.Close() - } - } - }() - - select { - case acceptErr <- err: - case <-s.ctx.Done(): - } - }() - - s.closeError = func() error { - for { - select { - case err := <-acceptErr: - return err - - case nconn := <-connNew: - sc := newServerConn(s, nconn) - s.conns[sc] = struct{}{} - - case sc := <-s.connClose: - if _, ok := s.conns[sc]; !ok { - continue - } - delete(s.conns, sc) - sc.Close() - - case req := <-s.sessionRequest: - if ss, ok := s.sessions[req.id]; ok { - if !req.sc.ip().Equal(ss.author.ip()) || - req.sc.zone() != ss.author.zone() { - req.res <- sessionRequestRes{ - res: &base.Response{ - StatusCode: base.StatusBadRequest, - }, - err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{}, - } - continue - } - - select { - case ss.request <- req: - case <-ss.ctx.Done(): - req.res <- sessionRequestRes{ - res: &base.Response{ - StatusCode: base.StatusBadRequest, - }, - err: liberrors.ErrServerTerminated{}, - } - } - } else { - if !req.create { - req.res <- sessionRequestRes{ - res: &base.Response{ - StatusCode: base.StatusSessionNotFound, - }, - err: liberrors.ErrServerSessionNotFound{}, - } - continue - } - - ss := newServerSession(s, req.sc) - s.sessions[ss.secretID] = ss - - select { - case ss.request <- req: - case <-ss.ctx.Done(): - req.res <- sessionRequestRes{ - res: &base.Response{ - StatusCode: base.StatusBadRequest, - }, - err: liberrors.ErrServerTerminated{}, - } - } - } - - case ss := <-s.sessionClose: - if sss, ok := s.sessions[ss.secretID]; !ok || sss != ss { - continue - } - delete(s.sessions, ss.secretID) - ss.Close() - - case req := <-s.streamMulticastIP: - ip32 := uint32(s.multicastNextIP[0])<<24 | uint32(s.multicastNextIP[1])<<16 | - uint32(s.multicastNextIP[2])<<8 | uint32(s.multicastNextIP[3]) - mask := uint32(s.multicastNet.Mask[0])<<24 | uint32(s.multicastNet.Mask[1])<<16 | - uint32(s.multicastNet.Mask[2])<<8 | uint32(s.multicastNet.Mask[3]) - ip32 = (ip32 & mask) | ((ip32 + 1) & ^mask) - ip := make(net.IP, 4) - ip[0] = byte(ip32 >> 24) - ip[1] = byte(ip32 >> 16) - ip[2] = byte(ip32 >> 8) - ip[3] = byte(ip32) - s.multicastNextIP = ip - req.res <- ip - - case <-s.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - }() + s.closeError = s.runInner() s.ctxCancel() @@ -460,7 +348,100 @@ func (s *Server) run() { s.udpRTPListener.close() } - s.tcpListener.Close() + s.tcpListener.close() +} + +func (s *Server) runInner() error { + for { + select { + case err := <-s.chAcceptErr: + return err + + case nconn := <-s.chNewConn: + sc := newServerConn(s, nconn) + s.conns[sc] = struct{}{} + + case sc := <-s.chCloseConn: + if _, ok := s.conns[sc]; !ok { + continue + } + delete(s.conns, sc) + sc.Close() + + case req := <-s.chHandleRequest: + if ss, ok := s.sessions[req.id]; ok { + if !req.sc.ip().Equal(ss.author.ip()) || + req.sc.zone() != ss.author.zone() { + req.res <- sessionRequestRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{}, + } + continue + } + + select { + case ss.chHandleRequest <- req: + case <-ss.ctx.Done(): + req.res <- sessionRequestRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerTerminated{}, + } + } + } else { + if !req.create { + req.res <- sessionRequestRes{ + res: &base.Response{ + StatusCode: base.StatusSessionNotFound, + }, + err: liberrors.ErrServerSessionNotFound{}, + } + continue + } + + ss := newServerSession(s, req.sc) + s.sessions[ss.secretID] = ss + + select { + case ss.chHandleRequest <- req: + case <-ss.ctx.Done(): + req.res <- sessionRequestRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerTerminated{}, + } + } + } + + case ss := <-s.chCloseSession: + if sss, ok := s.sessions[ss.secretID]; !ok || sss != ss { + continue + } + delete(s.sessions, ss.secretID) + ss.Close() + + case req := <-s.chGetMulticastIP: + ip32 := uint32(s.multicastNextIP[0])<<24 | uint32(s.multicastNextIP[1])<<16 | + uint32(s.multicastNextIP[2])<<8 | uint32(s.multicastNextIP[3]) + mask := uint32(s.multicastNet.Mask[0])<<24 | uint32(s.multicastNet.Mask[1])<<16 | + uint32(s.multicastNet.Mask[2])<<8 | uint32(s.multicastNet.Mask[3]) + ip32 = (ip32 & mask) | ((ip32 + 1) & ^mask) + ip := make(net.IP, 4) + ip[0] = byte(ip32 >> 24) + ip[1] = byte(ip32 >> 16) + ip[2] = byte(ip32 >> 8) + ip[3] = byte(ip32) + s.multicastNextIP = ip + req.res <- ip + + case <-s.ctx.Done(): + return liberrors.ErrServerTerminated{} + } + } } // StartAndWait starts the server and waits until a fatal error. @@ -472,3 +453,56 @@ func (s *Server) StartAndWait() error { return s.Wait() } + +func (s *Server) getMulticastIP() (net.IP, error) { + res := make(chan net.IP) + select { + case s.chGetMulticastIP <- chGetMulticastIPReq{res: res}: + return <-res, nil + + case <-s.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +func (s *Server) newConn(nconn net.Conn) { + select { + case s.chNewConn <- nconn: + case <-s.ctx.Done(): + nconn.Close() + } +} + +func (s *Server) acceptErr(err error) { + select { + case s.chAcceptErr <- err: + case <-s.ctx.Done(): + } +} + +func (s *Server) closeConn(sc *ServerConn) { + select { + case s.chCloseConn <- sc: + case <-s.ctx.Done(): + } +} + +func (s *Server) closeSession(ss *ServerSession) { + select { + case s.chCloseSession <- ss: + case <-s.ctx.Done(): + } +} + +func (s *Server) handleRequest(req sessionRequestReq) (*base.Response, *ServerSession, error) { + select { + case s.chHandleRequest <- req: + res := <-req.res + return res.res, res.ss, res.err + + case <-s.ctx.Done(): + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, req.sc.session, liberrors.ErrServerTerminated{} + } +} diff --git a/server_conn.go b/server_conn.go index 5cd7dc0f..356ec22f 100644 --- a/server_conn.go +++ b/server_conn.go @@ -3,12 +3,10 @@ package gortsplib import ( "context" "crypto/tls" - "errors" "net" gourl "net/url" "strconv" "strings" - "sync/atomic" "time" "github.com/bluenviron/gortsplib/v3/pkg/base" @@ -71,10 +69,11 @@ type ServerConn struct { bc *bytecounter.ByteCounter conn *conn.Conn session *ServerSession - readFunc func(readRequest chan readReq) error // in - sessionRemove chan *ServerSession + chHandleRequest chan readReq + chReadErr chan error + chRemoveSession chan *ServerSession // out done chan struct{} @@ -91,18 +90,18 @@ func newServerConn( } sc := &ServerConn{ - s: s, - nconn: nconn, - bc: bytecounter.New(nconn, nil, nil), - ctx: ctx, - ctxCancel: ctxCancel, - remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), - sessionRemove: make(chan *ServerSession), - done: make(chan struct{}), + s: s, + nconn: nconn, + bc: bytecounter.New(nconn, nil, nil), + ctx: ctx, + ctxCancel: ctxCancel, + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), + chHandleRequest: make(chan readReq), + chReadErr: make(chan error), + chRemoveSession: make(chan *ServerSession), + done: make(chan struct{}), } - sc.readFunc = sc.readFuncStandard - s.wg.Add(1) go sc.run() @@ -159,30 +158,21 @@ func (sc *ServerConn) run() { } sc.conn = conn.NewConn(sc.bc) + cr := newServerConnReader(sc) - readRequest := make(chan readReq) - readErr := make(chan error) - readDone := make(chan struct{}) - go sc.runReader(readRequest, readErr, readDone) - - err := sc.runInner(readRequest, readErr) + err := sc.runInner() sc.ctxCancel() sc.nconn.Close() - <-readDone + + cr.wait() if sc.session != nil { - select { - case sc.session.connRemove <- sc: - case <-sc.session.ctx.Done(): - } + sc.session.removeConn(sc) } - select { - case sc.s.connClose <- sc: - case <-sc.s.ctx.Done(): - } + sc.s.closeConn(sc) if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok { h.OnConnClose(&ServerHandlerOnConnCloseCtx{ @@ -192,16 +182,16 @@ func (sc *ServerConn) run() { } } -func (sc *ServerConn) runInner(readRequest chan readReq, readErr chan error) error { +func (sc *ServerConn) runInner() error { for { select { - case req := <-readRequest: + case req := <-sc.chHandleRequest: req.res <- sc.handleRequestOuter(req.req) - case err := <-readErr: + case err := <-sc.chReadErr: return err - case ss := <-sc.sessionRemove: + case ss := <-sc.chRemoveSession: if sc.session == ss { sc.session = nil } @@ -212,111 +202,7 @@ func (sc *ServerConn) runInner(readRequest chan readReq, readErr chan error) err } } -var errSwitchReadFunc = errors.New("switch read function") - -func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) { - defer close(readDone) - - for { - err := sc.readFunc(readRequest) - - if err == errSwitchReadFunc { - continue - } - - select { - case readErr <- err: - case <-sc.ctx.Done(): - } - break - } -} - -func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { - // reset deadline - sc.nconn.SetReadDeadline(time.Time{}) - - for { - any, err := sc.conn.ReadInterleavedFrameOrRequest() - if err != nil { - return err - } - - switch what := any.(type) { - case *base.Request: - cres := make(chan error) - select { - case readRequest <- readReq{req: what, res: cres}: - err = <-cres - if err != nil { - return err - } - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - - default: - return liberrors.ErrServerUnexpectedFrame{} - } - } -} - -func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { - // reset deadline - sc.nconn.SetReadDeadline(time.Time{}) - - select { - case sc.session.startWriter <- struct{}{}: - case <-sc.session.ctx.Done(): - } - - for { - if sc.session.state == ServerSessionStateRecord { - sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) - } - - what, err := sc.conn.ReadInterleavedFrameOrRequest() - if err != nil { - return err - } - - switch twhat := what.(type) { - case *base.InterleavedFrame: - channel := twhat.Channel - isRTP := true - if (channel % 2) != 0 { - channel-- - isRTP = false - } - - atomic.AddUint64(sc.session.bytesReceived, uint64(len(twhat.Payload))) - - if sm, ok := sc.session.tcpMediasByChannel[channel]; ok { - if isRTP { - sm.readRTP(twhat.Payload) - } else { - sm.readRTCP(twhat.Payload) - } - } - - case *base.Request: - cres := make(chan error) - select { - case readRequest <- readReq{req: twhat, res: cres}: - err := <-cres - if err != nil { - return err - } - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - } -} - -func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { +func (sc *ServerConn) handleRequestInner(req *base.Request) (*base.Response, error) { if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -491,7 +377,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { h.OnRequest(sc, req) } - res, err := sc.handleRequest(req) + res, err := sc.handleRequestInner(req) if res.Header == nil { res.Header = make(base.Header) @@ -544,17 +430,9 @@ func (sc *ServerConn) handleRequestInSession( res: cres, } - select { - case sc.session.request <- sreq: - res := <-cres - sc.session = res.ss - return res.res, res.err - - case <-sc.session.ctx.Done(): - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTerminated{} - } + res, session, err := sc.session.handleRequest(sreq) + sc.session = session + return res, err } // otherwise, pass through Server @@ -567,15 +445,31 @@ func (sc *ServerConn) handleRequestInSession( res: cres, } - select { - case sc.s.sessionRequest <- sreq: - res := <-cres - sc.session = res.ss - return res.res, res.err + res, session, err := sc.s.handleRequest(sreq) + sc.session = session + return res, err +} - case <-sc.s.ctx.Done(): - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTerminated{} +func (sc *ServerConn) removeSession(ss *ServerSession) { + select { + case sc.chRemoveSession <- ss: + case <-sc.ctx.Done(): + } +} + +func (sc *ServerConn) handleRequest(req readReq) error { + select { + case sc.chHandleRequest <- req: + return <-req.res + + case <-sc.ctx.Done(): + return liberrors.ErrServerTerminated{} + } +} + +func (sc *ServerConn) readErr(err error) { + select { + case sc.chReadErr <- err: + case <-sc.ctx.Done(): } } diff --git a/server_conn_reader.go b/server_conn_reader.go new file mode 100644 index 00000000..5566e315 --- /dev/null +++ b/server_conn_reader.go @@ -0,0 +1,135 @@ +package gortsplib + +import ( + "sync/atomic" + "time" + + "github.com/bluenviron/gortsplib/v3/pkg/base" + "github.com/bluenviron/gortsplib/v3/pkg/liberrors" +) + +type errSwitchReadFunc struct { + tcp bool +} + +func (errSwitchReadFunc) Error() string { + return "switching read function" +} + +func isErrSwitchReadFunc(err error) bool { + _, ok := err.(errSwitchReadFunc) + return ok +} + +type serverConnReader struct { + sc *ServerConn + + chReadDone chan struct{} +} + +func newServerConnReader(sc *ServerConn) *serverConnReader { + cr := &serverConnReader{ + sc: sc, + chReadDone: make(chan struct{}), + } + + go cr.run() + + return cr +} + +func (cr *serverConnReader) wait() { + <-cr.chReadDone +} + +func (cr *serverConnReader) run() { + defer close(cr.chReadDone) + + readFunc := cr.readFuncStandard + + for { + err := readFunc() + if err, ok := err.(errSwitchReadFunc); ok { + if err.tcp { + readFunc = cr.readFuncTCP + } else { + readFunc = cr.readFuncStandard + } + continue + } + + cr.sc.readErr(err) + break + } +} + +func (cr *serverConnReader) readFuncStandard() error { + // reset deadline + cr.sc.nconn.SetReadDeadline(time.Time{}) + + for { + any, err := cr.sc.conn.ReadInterleavedFrameOrRequest() + if err != nil { + return err + } + + switch what := any.(type) { + case *base.Request: + cres := make(chan error) + req := readReq{req: what, res: cres} + err := cr.sc.handleRequest(req) + if err != nil { + return err + } + + default: + return liberrors.ErrServerUnexpectedFrame{} + } + } +} + +func (cr *serverConnReader) readFuncTCP() error { + // reset deadline + cr.sc.nconn.SetReadDeadline(time.Time{}) + + cr.sc.session.startWriter() + + for { + if cr.sc.session.state == ServerSessionStateRecord { + cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout)) + } + + what, err := cr.sc.conn.ReadInterleavedFrameOrRequest() + if err != nil { + return err + } + + switch twhat := what.(type) { + case *base.InterleavedFrame: + channel := twhat.Channel + isRTP := true + if (channel % 2) != 0 { + channel-- + isRTP = false + } + + atomic.AddUint64(cr.sc.session.bytesReceived, uint64(len(twhat.Payload))) + + if sm, ok := cr.sc.session.tcpMediasByChannel[channel]; ok { + if isRTP { + sm.readRTP(twhat.Payload) + } else { + sm.readRTCP(twhat.Payload) + } + } + + case *base.Request: + cres := make(chan error) + req := readReq{req: twhat, res: cres} + err := cr.sc.handleRequest(req) + if err != nil { + return err + } + } + } +} diff --git a/server_multicast_writer.go b/server_multicast_writer.go index 721896e2..acc1860a 100644 --- a/server_multicast_writer.go +++ b/server_multicast_writer.go @@ -1,7 +1,6 @@ package gortsplib import ( - "fmt" "net" "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" @@ -21,13 +20,10 @@ type serverMulticastWriter struct { } func newServerMulticastWriter(s *Server) (*serverMulticastWriter, error) { - res := make(chan net.IP) - select { - case s.streamMulticastIP <- streamMulticastIPReq{res: res}: - case <-s.ctx.Done(): - return nil, fmt.Errorf("terminated") + ip, err := s.getMulticastIP() + if err != nil { + return nil, err } - ip := <-res rtpl, rtcpl, err := newServerUDPListenerMulticastPair( s.ListenPacket, diff --git a/server_play_test.go b/server_play_test.go index ead9d82e..a3b03fad 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -25,6 +25,18 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/url" ) +func uintPtr(v uint) *uint { + return &v +} + +func uint16Ptr(v uint16) *uint16 { + return &v +} + +func uint32Ptr(v uint32) *uint32 { + return &v +} + func multicastCapableIP(t *testing.T) string { intfs, err := net.Interfaces() require.NoError(t, err) @@ -1895,18 +1907,12 @@ func TestServerPlayAdditionalInfos(t *testing.T) { Host: "localhost:8554", Path: mustParseURL((*rtpInfo)[0].URL).Path, }).String(), - SequenceNumber: func() *uint16 { - v := uint16(557) - return &v - }(), - Timestamp: (*rtpInfo)[0].Timestamp, + SequenceNumber: uint16Ptr(557), + Timestamp: (*rtpInfo)[0].Timestamp, }, }, rtpInfo) require.Equal(t, []*uint32{ - func() *uint32 { - v := uint32(96342362) - return &v - }(), + uint32Ptr(96342362), nil, }, ssrcs) @@ -1930,11 +1936,8 @@ func TestServerPlayAdditionalInfos(t *testing.T) { Host: "localhost:8554", Path: mustParseURL((*rtpInfo)[0].URL).Path, }).String(), - SequenceNumber: func() *uint16 { - v := uint16(557) - return &v - }(), - Timestamp: (*rtpInfo)[0].Timestamp, + SequenceNumber: uint16Ptr(557), + Timestamp: (*rtpInfo)[0].Timestamp, }, &headers.RTPInfoEntry{ URL: (&url.URL{ @@ -1942,22 +1945,13 @@ func TestServerPlayAdditionalInfos(t *testing.T) { Host: "localhost:8554", Path: mustParseURL((*rtpInfo)[1].URL).Path, }).String(), - SequenceNumber: func() *uint16 { - v := uint16(88) - return &v - }(), - Timestamp: (*rtpInfo)[1].Timestamp, + SequenceNumber: uint16Ptr(88), + Timestamp: (*rtpInfo)[1].Timestamp, }, }, rtpInfo) require.Equal(t, []*uint32{ - func() *uint32 { - v := uint32(96342362) - return &v - }(), - func() *uint32 { - v := uint32(536474323) - return &v - }(), + uint32Ptr(96342362), + uint32Ptr(536474323), }, ssrcs) } diff --git a/server_session.go b/server_session.go index afd7cbde..424e9ad2 100644 --- a/server_session.go +++ b/server_session.go @@ -192,9 +192,9 @@ type ServerSession struct { writer writer // in - request chan sessionRequestReq - connRemove chan *ServerConn - startWriter chan struct{} + chHandleRequest chan sessionRequestReq + chRemoveConn chan *ServerConn + chStartWriter chan struct{} } func newServerSession( @@ -217,9 +217,9 @@ func newServerSession( conns: make(map[*ServerConn]struct{}), lastRequestTime: time.Now(), udpCheckStreamTimer: emptyTimer(), - request: make(chan sessionRequestReq), - connRemove: make(chan *ServerConn), - startWriter: make(chan struct{}), + chHandleRequest: make(chan sessionRequestReq), + chRemoveConn: make(chan *ServerConn), + chStartWriter: make(chan struct{}), } s.wg.Add(1) @@ -354,16 +354,10 @@ func (ss *ServerSession) run() { // make sure that OnFrame() is never called after OnSessionClose() <-sc.done - select { - case sc.sessionRemove <- ss: - case <-sc.ctx.Done(): - } + sc.removeSession(ss) } - select { - case ss.s.sessionClose <- ss: - case <-ss.s.ctx.Done(): - } + ss.s.closeSession(ss) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { h.OnSessionClose(&ServerHandlerOnSessionCloseCtx{ @@ -376,18 +370,18 @@ func (ss *ServerSession) run() { func (ss *ServerSession) runInner() error { for { select { - case req := <-ss.request: + case req := <-ss.chHandleRequest: ss.lastRequestTime = time.Now() if _, ok := ss.conns[req.sc]; !ok { ss.conns[req.sc] = struct{}{} } - res, err := ss.handleRequest(req.sc, req.req) + res, err := ss.handleRequestInner(req.sc, req.req) returnedSession := ss - if err == nil || err == errSwitchReadFunc { + if err == nil || isErrSwitchReadFunc(err) { // ANNOUNCE responses don't contain the session header. if req.req.Method != base.Announce && req.req.Method != base.Teardown { @@ -428,11 +422,11 @@ func (ss *ServerSession) runInner() error { ss: returnedSession, } - if (err == nil || err == errSwitchReadFunc) && savedMethod == base.Teardown { + if (err == nil || isErrSwitchReadFunc(err)) && savedMethod == base.Teardown { return liberrors.ErrServerSessionTornDown{Author: req.sc.NetConn().RemoteAddr()} } - case sc := <-ss.connRemove: + case sc := <-ss.chRemoveConn: delete(ss.conns, sc) // if session is not in state RECORD or PLAY, or transport is TCP, @@ -445,7 +439,7 @@ func (ss *ServerSession) runInner() error { return liberrors.ErrServerSessionNotInUse{} } - case <-ss.startWriter: + case <-ss.chStartWriter: if (ss.state == ServerSessionStateRecord || ss.state == ServerSessionStatePlay) && *ss.setuppedTransport == TransportTCP { @@ -477,7 +471,7 @@ func (ss *ServerSession) runInner() error { } } -func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { +func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (*base.Response, error) { if ss.tcpConn != nil && sc != ss.tcpConn { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -926,8 +920,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn = sc - ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP - err = errSwitchReadFunc + err = errSwitchReadFunc{true} // writer.start() is called by ServerConn after the response has been sent } @@ -1014,8 +1007,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn = sc - ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP - err = errSwitchReadFunc + err = errSwitchReadFunc{true} // runWriter() is called by conn after sending the response } @@ -1068,8 +1060,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpCheckStreamTimer = emptyTimer() default: // TCP - ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard - err = errSwitchReadFunc + err = errSwitchReadFunc{false} ss.tcpConn = nil } @@ -1079,8 +1070,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpCheckStreamTimer = emptyTimer() default: // TCP - ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard - err = errSwitchReadFunc + err = errSwitchReadFunc{false} ss.tcpConn = nil } @@ -1093,8 +1083,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base var err error if (ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord) && *ss.setuppedTransport == TransportTCP { - ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard - err = errSwitchReadFunc + err = errSwitchReadFunc{false} } return &base.Response{ @@ -1203,3 +1192,30 @@ func (ss *ServerSession) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { ss.writePacketRTCP(medi, byts) } + +func (ss *ServerSession) handleRequest(req sessionRequestReq) (*base.Response, *ServerSession, error) { + select { + case ss.chHandleRequest <- req: + res := <-req.res + return res.res, res.ss, res.err + + case <-ss.ctx.Done(): + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, req.sc.session, liberrors.ErrServerTerminated{} + } +} + +func (ss *ServerSession) removeConn(sc *ServerConn) { + select { + case ss.chRemoveConn <- sc: + case <-ss.ctx.Done(): + } +} + +func (ss *ServerSession) startWriter() { + select { + case ss.chStartWriter <- struct{}{}: + case <-ss.ctx.Done(): + } +} diff --git a/server_tcp_listener.go b/server_tcp_listener.go new file mode 100644 index 00000000..7462020e --- /dev/null +++ b/server_tcp_listener.go @@ -0,0 +1,47 @@ +package gortsplib + +import ( + "net" +) + +type serverTCPListener struct { + s *Server + ln net.Listener +} + +func newServerTCPListener( + s *Server, +) (*serverTCPListener, error) { + ln, err := s.Listen(restrictNetwork("tcp", s.RTSPAddress)) + if err != nil { + return nil, err + } + + sl := &serverTCPListener{ + s: s, + ln: ln, + } + + s.wg.Add(1) + go sl.run() + + return sl, nil +} + +func (sl *serverTCPListener) close() { + sl.ln.Close() +} + +func (sl *serverTCPListener) run() { + defer sl.s.wg.Done() + + for { + nconn, err := sl.ln.Accept() + if err != nil { + sl.s.acceptErr(err) + return + } + + sl.s.newConn(nconn) + } +} diff --git a/server_udpl.go b/server_udp_listener.go similarity index 100% rename from server_udpl.go rename to server_udp_listener.go