fix various race conditions when writing packets to closed clients or server sessions (#684)

This commit is contained in:
Alessandro Ros
2025-01-19 12:07:59 +01:00
committed by GitHub
parent b2cfa93d13
commit ca6286321d
12 changed files with 438 additions and 219 deletions

View File

@@ -482,7 +482,7 @@ func TestClientRecordSocketError(t *testing.T) {
}
}
func TestClientRecordPauseSerial(t *testing.T) {
func TestClientRecordPauseRecordSerial(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
@@ -618,6 +618,9 @@ func TestClientRecordPauseSerial(t *testing.T) {
_, err = c.Pause()
require.NoError(t, err)
err = c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err)
_, err = c.Record()
require.NoError(t, err)
@@ -627,6 +630,187 @@ func TestClientRecordPauseSerial(t *testing.T) {
}
}
func TestClientRecordPauseRecordParallel(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
} {
t.Run(transport, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
string(base.Pause),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
}
if transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
if transport == "tcp" {
_, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
}
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Pause, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
if transport == "tcp" {
_, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
}
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
Transport: func() *Transport {
if transport == "udp" {
v := TransportUDP
return &v
}
v := TransportTCP
return &v
}(),
}
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
defer c.Close()
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})
defer func() {
close(writerTerminate)
<-writerDone
}()
go func() {
defer close(writerDone)
ti := time.NewTicker(50 * time.Millisecond)
defer ti.Stop()
for {
select {
case <-ti.C:
err2 := c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err2)
case <-writerTerminate:
return
}
}
}()
time.Sleep(500 * time.Millisecond)
_, err = c.Pause()
require.NoError(t, err)
time.Sleep(500 * time.Millisecond)
_, err = c.Record()
require.NoError(t, err)
time.Sleep(500 * time.Millisecond)
})
}
}
func TestClientRecordAutomaticProtocol(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)