server: fix bug that allowed two readers to use the same UDP ports

UDP ports of a reader that performed a SETUP request, but not a PLAY
request, were not taken into account when checking port availability.
This commit is contained in:
aler9
2022-10-28 14:37:28 +02:00
parent 145b21ef3d
commit 0b75c240c7
2 changed files with 109 additions and 150 deletions

View File

@@ -280,6 +280,75 @@ func TestServerReadSetupErrors(t *testing.T) {
} }
} }
// TestServerReadSetupErrorSameUDPPorts verifies that a second reader asking
// for client UDP ports already claimed by another reader (even one that has
// only performed SETUP, not PLAY) is rejected with 400 Bad Request.
func TestServerReadSetupErrorSameUDPPorts(t *testing.T) {
	track := &TrackH264{
		PayloadType: 96,
		SPS:         []byte{0x01, 0x02, 0x03, 0x04},
		PPS:         []byte{0x01, 0x02, 0x03, 0x04},
	}

	stream := NewServerStream(Tracks{track})
	defer stream.Close()

	s := &Server{
		Handler: &testServerHandler{
			onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
				return &base.Response{
					StatusCode: base.StatusOK,
				}, stream, nil
			},
			onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
				return &base.Response{
					StatusCode: base.StatusOK,
				}, nil
			},
		},
		UDPRTPAddress:  "127.0.0.1:8000",
		UDPRTCPAddress: "127.0.0.1:8001",
		RTSPAddress:    "localhost:8554",
	}

	err := s.Start()
	require.NoError(t, err)
	defer s.Close()

	// Two clients in sequence request the same client ports. The deferred
	// Close calls intentionally run only at test end, so the first
	// connection's SETUP is still holding the ports when the second arrives.
	for i := 0; i < 2; i++ {
		nconn, err := net.Dial("tcp", "localhost:8554")
		require.NoError(t, err)
		defer nconn.Close()
		rconn := conn.NewConn(nconn)

		delivery := headers.TransportDeliveryUnicast
		mode := headers.TransportModePlay
		th := &headers.Transport{
			Delivery:    &delivery,
			Mode:        &mode,
			Protocol:    headers.TransportProtocolUDP,
			ClientPorts: &[2]int{35466, 35467},
		}

		res, err := writeReqReadRes(rconn, base.Request{
			Method: base.Setup,
			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
			Header: base.Header{
				"CSeq":      base.HeaderValue{"1"},
				"Transport": th.Marshal(),
			},
		})
		require.NoError(t, err)

		// first SETUP succeeds; the duplicate-port SETUP must be refused.
		expected := base.StatusOK
		if i != 0 {
			expected = base.StatusBadRequest
		}
		require.Equal(t, expected, res.StatusCode)
	}
}
func TestServerRead(t *testing.T) { func TestServerRead(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
@@ -1871,113 +1940,3 @@ func TestServerReadAdditionalInfos(t *testing.T) {
}(), }(),
}, ssrcs) }, ssrcs)
} }
// TestServerReadErrorUDPSamePorts checks that after one reader completes
// SETUP and PLAY over UDP, a second reader requesting the very same client
// ports is answered with 400 Bad Request.
func TestServerReadErrorUDPSamePorts(t *testing.T) {
	track := &TrackH264{
		PayloadType: 96,
		SPS:         []byte{0x01, 0x02, 0x03, 0x04},
		PPS:         []byte{0x01, 0x02, 0x03, 0x04},
	}

	stream := NewServerStream(Tracks{track})
	defer stream.Close()

	s := &Server{
		Handler: &testServerHandler{
			onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
				return &base.Response{
					StatusCode: base.StatusOK,
				}, stream, nil
			},
			onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
				return &base.Response{
					StatusCode: base.StatusOK,
				}, nil
			},
		},
		UDPRTPAddress:  "127.0.0.1:8000",
		UDPRTCPAddress: "127.0.0.1:8001",
		RTSPAddress:    "localhost:8554",
	}

	err := s.Start()
	require.NoError(t, err)
	defer s.Close()

	// first reader: SETUP + PLAY, claiming ports 35466/35467.
	func() {
		nconn, err := net.Dial("tcp", "localhost:8554")
		require.NoError(t, err)
		defer nconn.Close()
		rconn := conn.NewConn(nconn)

		delivery := headers.TransportDeliveryUnicast
		mode := headers.TransportModePlay
		th := &headers.Transport{
			Delivery:    &delivery,
			Mode:        &mode,
			Protocol:    headers.TransportProtocolUDP,
			ClientPorts: &[2]int{35466, 35467},
		}

		res, err := writeReqReadRes(rconn, base.Request{
			Method: base.Setup,
			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
			Header: base.Header{
				"CSeq":      base.HeaderValue{"1"},
				"Transport": th.Marshal(),
			},
		})
		require.NoError(t, err)
		require.Equal(t, base.StatusOK, res.StatusCode)

		var sx headers.Session
		err = sx.Unmarshal(res.Header["Session"])
		require.NoError(t, err)

		res, err = writeReqReadRes(rconn, base.Request{
			Method: base.Play,
			URL:    mustParseURL("rtsp://localhost:8554/teststream"),
			Header: base.Header{
				"CSeq":    base.HeaderValue{"2"},
				"Session": base.HeaderValue{sx.Session},
			},
		})
		require.NoError(t, err)
		require.Equal(t, base.StatusOK, res.StatusCode)
	}()

	// second reader: same ports, SETUP must be rejected.
	func() {
		nconn, err := net.Dial("tcp", "localhost:8554")
		require.NoError(t, err)
		defer nconn.Close()
		rconn := conn.NewConn(nconn)

		delivery := headers.TransportDeliveryUnicast
		mode := headers.TransportModePlay
		th := &headers.Transport{
			Delivery:    &delivery,
			Mode:        &mode,
			Protocol:    headers.TransportProtocolUDP,
			ClientPorts: &[2]int{35466, 35467},
		}

		res, err := writeReqReadRes(rconn, base.Request{
			Method: base.Setup,
			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
			Header: base.Header{
				"CSeq":      base.HeaderValue{"1"},
				"Transport": th.Marshal(),
			},
		})
		require.NoError(t, err)
		require.Equal(t, base.StatusBadRequest, res.StatusCode)
	}()
}

View File

@@ -21,7 +21,7 @@ type serverStreamTrack struct {
udpRTCPSender *rtcpsender.RTCPSender udpRTCPSender *rtcpsender.RTCPSender
} }
// ServerStream represents a single stream. // ServerStream represents a data stream.
// This is in charge of // This is in charge of
// - distributing the stream to each reader // - distributing the stream to each reader
// - allocating multicast listeners // - allocating multicast listeners
@@ -31,10 +31,10 @@ type ServerStream struct {
mutex sync.RWMutex mutex sync.RWMutex
s *Server s *Server
readersUnicast map[*ServerSession]struct{} activeUnicastReaders map[*ServerSession]struct{}
readers map[*ServerSession]struct{} readers map[*ServerSession]struct{}
serverMulticastHandlers []*serverMulticastHandler serverMulticastHandlers []*serverMulticastHandler
stTracks []*serverStreamTrack ssTracks []*serverStreamTrack
} }
// NewServerStream allocates a ServerStream. // NewServerStream allocates a ServerStream.
@@ -43,14 +43,14 @@ func NewServerStream(tracks Tracks) *ServerStream {
tracks.setControls() tracks.setControls()
st := &ServerStream{ st := &ServerStream{
tracks: tracks, tracks: tracks,
readersUnicast: make(map[*ServerSession]struct{}), activeUnicastReaders: make(map[*ServerSession]struct{}),
readers: make(map[*ServerSession]struct{}), readers: make(map[*ServerSession]struct{}),
} }
st.stTracks = make([]*serverStreamTrack, len(tracks)) st.ssTracks = make([]*serverStreamTrack, len(tracks))
for i := range st.stTracks { for i := range st.ssTracks {
st.stTracks[i] = &serverStreamTrack{} st.ssTracks[i] = &serverStreamTrack{}
} }
return st return st
@@ -73,7 +73,7 @@ func (st *ServerStream) Close() error {
} }
st.readers = nil st.readers = nil
st.readersUnicast = nil st.activeUnicastReaders = nil
return nil return nil
} }
@@ -86,14 +86,14 @@ func (st *ServerStream) Tracks() Tracks {
func (st *ServerStream) ssrc(trackID int) uint32 { func (st *ServerStream) ssrc(trackID int) uint32 {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
return st.stTracks[trackID].lastSSRC return st.ssTracks[trackID].lastSSRC
} }
func (st *ServerStream) rtpInfo(trackID int, now time.Time) (uint16, uint32, bool) { func (st *ServerStream) rtpInfo(trackID int, now time.Time) (uint16, uint32, bool) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
track := st.stTracks[trackID] track := st.ssTracks[trackID]
if !track.firstPacketSent { if !track.firstPacketSent {
return 0, 0, false return 0, 0, false
@@ -125,31 +125,10 @@ func (st *ServerStream) readerAdd(
return fmt.Errorf("stream is closed") return fmt.Errorf("stream is closed")
} }
if st.s == nil {
st.s = ss.s
for trackID, track := range st.stTracks {
cTrackID := trackID
// always generate RTCP sender reports.
// they're mandatory needed when transport protocol is UDP or UDP-multicast.
// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
// since they're used to compute NTP timestamp of frames:
// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
track.udpRTCPSender = rtcpsender.New(
st.s.udpSenderReportPeriod,
st.tracks[trackID].ClockRate(),
func(pkt rtcp.Packet) {
st.WritePacketRTCP(cTrackID, pkt)
},
)
}
}
switch transport { switch transport {
case TransportUDP: case TransportUDP:
// check if client ports are already in use by another reader. // check if client ports are already in use by another reader
for r := range st.readersUnicast { for r := range st.readers {
if *r.setuppedTransport == TransportUDP && if *r.setuppedTransport == TransportUDP &&
r.author.ip().Equal(ss.author.ip()) && r.author.ip().Equal(ss.author.ip()) &&
r.author.zone() == ss.author.zone() { r.author.zone() == ss.author.zone() {
@@ -183,6 +162,27 @@ func (st *ServerStream) readerAdd(
} }
} }
if st.s == nil {
st.s = ss.s
for trackID, track := range st.ssTracks {
cTrackID := trackID
// always generate RTCP sender reports.
// they're mandatory when transport protocol is UDP or UDP-multicast.
// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
// since they're used to compute NTP timestamp of frames:
// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
track.udpRTCPSender = rtcpsender.New(
st.s.udpSenderReportPeriod,
st.tracks[trackID].ClockRate(),
func(pkt rtcp.Packet) {
st.WritePacketRTCP(cTrackID, pkt)
},
)
}
}
st.readers[ss] = struct{}{} st.readers[ss] = struct{}{}
return nil return nil
@@ -209,7 +209,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
switch *ss.setuppedTransport { switch *ss.setuppedTransport {
case TransportUDP, TransportTCP: case TransportUDP, TransportTCP:
st.readersUnicast[ss] = struct{}{} st.activeUnicastReaders[ss] = struct{}{}
default: // UDPMulticast default: // UDPMulticast
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
@@ -225,7 +225,7 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) {
switch *ss.setuppedTransport { switch *ss.setuppedTransport {
case TransportUDP, TransportTCP: case TransportUDP, TransportTCP:
delete(st.readersUnicast, ss) delete(st.activeUnicastReaders, ss)
default: // UDPMulticast default: // UDPMulticast
if st.serverMulticastHandlers != nil { if st.serverMulticastHandlers != nil {
@@ -248,7 +248,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
track := st.stTracks[trackID] track := st.ssTracks[trackID]
now := time.Now() now := time.Now()
if !track.firstPacketSent || ptsEqualsDTS { if !track.firstPacketSent || ptsEqualsDTS {
@@ -265,7 +265,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT
} }
// send unicast // send unicast
for r := range st.readersUnicast { for r := range st.activeUnicastReaders {
r.writePacketRTP(trackID, byts) r.writePacketRTP(trackID, byts)
} }
@@ -286,7 +286,7 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) {
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
// send unicast // send unicast
for r := range st.readersUnicast { for r := range st.activeUnicastReaders {
r.writePacketRTCP(trackID, byts) r.writePacketRTCP(trackID, byts)
} }