mirror of https://github.com/aler9/gortsplib
server: fix bug that allowed two readers to use the same UDP ports
The UDP ports of a reader that had performed a SETUP request, but not yet a PLAY request, were not taken into account when checking port availability.
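The essence of the fix, as a minimal standalone Go sketch (the type and function names below are illustrative only, not the library's actual API): the port-availability check must consider every reader that has completed SETUP, not only readers that have also sent PLAY.

package main

import "fmt"

// reader models just the session state needed for the check; the real
// library tracks sessions, transports and addresses in far more detail.
type reader struct {
	ip          string
	clientPorts [2]int
	playing     bool // true once the reader has sent PLAY
}

// udpPortsAvailable reports whether the requested client ports are free
// for a given source IP. The fix amounts to iterating over all readers
// that completed SETUP (the equivalent of st.readers) instead of only
// the active ones (the equivalent of the old st.readersUnicast).
func udpPortsAvailable(readers []reader, ip string, ports [2]int) bool {
	for _, r := range readers {
		if r.ip == ip && r.clientPorts == ports {
			return false // ports taken, even if r never sent PLAY
		}
	}
	return true
}

func main() {
	// A reader that performed SETUP with ports 35466-35467 but never PLAY.
	readers := []reader{{ip: "127.0.0.1", clientPorts: [2]int{35466, 35467}}}

	// A second SETUP with the same ports from the same IP must be refused,
	// which is what the new test asserts via base.StatusBadRequest.
	fmt.Println(udpPortsAvailable(readers, "127.0.0.1", [2]int{35466, 35467})) // false
}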
@@ -280,6 +280,75 @@ func TestServerReadSetupErrors(t *testing.T) {
 	}
 }
 
+func TestServerReadSetupErrorSameUDPPorts(t *testing.T) {
+	track := &TrackH264{
+		PayloadType: 96,
+		SPS:         []byte{0x01, 0x02, 0x03, 0x04},
+		PPS:         []byte{0x01, 0x02, 0x03, 0x04},
+	}
+
+	stream := NewServerStream(Tracks{track})
+	defer stream.Close()
+
+	s := &Server{
+		Handler: &testServerHandler{
+			onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
+				return &base.Response{
+					StatusCode: base.StatusOK,
+				}, stream, nil
+			},
+			onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
+				return &base.Response{
+					StatusCode: base.StatusOK,
+				}, nil
+			},
+		},
+		UDPRTPAddress:  "127.0.0.1:8000",
+		UDPRTCPAddress: "127.0.0.1:8001",
+		RTSPAddress:    "localhost:8554",
+	}
+
+	err := s.Start()
+	require.NoError(t, err)
+	defer s.Close()
+
+	for i := 0; i < 2; i++ {
+		nconn, err := net.Dial("tcp", "localhost:8554")
+		require.NoError(t, err)
+		defer nconn.Close()
+		conn := conn.NewConn(nconn)
+
+		inTH := &headers.Transport{
+			Delivery: func() *headers.TransportDelivery {
+				v := headers.TransportDeliveryUnicast
+				return &v
+			}(),
+			Mode: func() *headers.TransportMode {
+				v := headers.TransportModePlay
+				return &v
+			}(),
+			Protocol:    headers.TransportProtocolUDP,
+			ClientPorts: &[2]int{35466, 35467},
+		}
+
+		res, err := writeReqReadRes(conn, base.Request{
+			Method: base.Setup,
+			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
+			Header: base.Header{
+				"CSeq":      base.HeaderValue{"1"},
+				"Transport": inTH.Marshal(),
+			},
+		})
+		require.NoError(t, err)
+
+		if i == 0 {
+			require.Equal(t, base.StatusOK, res.StatusCode)
+		} else {
+			require.Equal(t, base.StatusBadRequest, res.StatusCode)
+		}
+	}
+}
+
 func TestServerRead(t *testing.T) {
 	for _, transport := range []string{
 		"udp",
@@ -1871,113 +1940,3 @@ func TestServerReadAdditionalInfos(t *testing.T) {
 		}(),
 	}, ssrcs)
 }
-
-func TestServerReadErrorUDPSamePorts(t *testing.T) {
-	track := &TrackH264{
-		PayloadType: 96,
-		SPS:         []byte{0x01, 0x02, 0x03, 0x04},
-		PPS:         []byte{0x01, 0x02, 0x03, 0x04},
-	}
-
-	stream := NewServerStream(Tracks{track})
-	defer stream.Close()
-
-	s := &Server{
-		Handler: &testServerHandler{
-			onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
-				return &base.Response{
-					StatusCode: base.StatusOK,
-				}, stream, nil
-			},
-			onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
-				return &base.Response{
-					StatusCode: base.StatusOK,
-				}, nil
-			},
-		},
-		UDPRTPAddress:  "127.0.0.1:8000",
-		UDPRTCPAddress: "127.0.0.1:8001",
-		RTSPAddress:    "localhost:8554",
-	}
-
-	err := s.Start()
-	require.NoError(t, err)
-	defer s.Close()
-
-	func() {
-		nconn, err := net.Dial("tcp", "localhost:8554")
-		require.NoError(t, err)
-		defer nconn.Close()
-		conn := conn.NewConn(nconn)
-
-		inTH := &headers.Transport{
-			Delivery: func() *headers.TransportDelivery {
-				v := headers.TransportDeliveryUnicast
-				return &v
-			}(),
-			Mode: func() *headers.TransportMode {
-				v := headers.TransportModePlay
-				return &v
-			}(),
-			Protocol:    headers.TransportProtocolUDP,
-			ClientPorts: &[2]int{35466, 35467},
-		}
-
-		res, err := writeReqReadRes(conn, base.Request{
-			Method: base.Setup,
-			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
-			Header: base.Header{
-				"CSeq":      base.HeaderValue{"1"},
-				"Transport": inTH.Marshal(),
-			},
-		})
-		require.NoError(t, err)
-		require.Equal(t, base.StatusOK, res.StatusCode)
-
-		var sx headers.Session
-		err = sx.Unmarshal(res.Header["Session"])
-		require.NoError(t, err)
-
-		res, err = writeReqReadRes(conn, base.Request{
-			Method: base.Play,
-			URL:    mustParseURL("rtsp://localhost:8554/teststream"),
-			Header: base.Header{
-				"CSeq":    base.HeaderValue{"2"},
-				"Session": base.HeaderValue{sx.Session},
-			},
-		})
-		require.NoError(t, err)
-		require.Equal(t, base.StatusOK, res.StatusCode)
-	}()
-
-	func() {
-		nconn, err := net.Dial("tcp", "localhost:8554")
-		require.NoError(t, err)
-		defer nconn.Close()
-		conn := conn.NewConn(nconn)
-
-		inTH := &headers.Transport{
-			Delivery: func() *headers.TransportDelivery {
-				v := headers.TransportDeliveryUnicast
-				return &v
-			}(),
-			Mode: func() *headers.TransportMode {
-				v := headers.TransportModePlay
-				return &v
-			}(),
-			Protocol:    headers.TransportProtocolUDP,
-			ClientPorts: &[2]int{35466, 35467},
-		}
-
-		res, err := writeReqReadRes(conn, base.Request{
-			Method: base.Setup,
-			URL:    mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
-			Header: base.Header{
-				"CSeq":      base.HeaderValue{"1"},
-				"Transport": inTH.Marshal(),
-			},
-		})
-		require.NoError(t, err)
-		require.Equal(t, base.StatusBadRequest, res.StatusCode)
-	}()
-}
-
@@ -21,7 +21,7 @@ type serverStreamTrack struct {
 	udpRTCPSender *rtcpsender.RTCPSender
 }
 
-// ServerStream represents a single stream.
+// ServerStream represents a data stream.
 // This is in charge of
 // - distributing the stream to each reader
 // - allocating multicast listeners
@@ -31,10 +31,10 @@ type ServerStream struct {
 
 	mutex                   sync.RWMutex
 	s                       *Server
-	readersUnicast          map[*ServerSession]struct{}
+	activeUnicastReaders    map[*ServerSession]struct{}
 	readers                 map[*ServerSession]struct{}
 	serverMulticastHandlers []*serverMulticastHandler
-	stTracks                []*serverStreamTrack
+	ssTracks                []*serverStreamTrack
 }
 
 // NewServerStream allocates a ServerStream.
@@ -43,14 +43,14 @@ func NewServerStream(tracks Tracks) *ServerStream {
 	tracks.setControls()
 
 	st := &ServerStream{
-		tracks:         tracks,
-		readersUnicast: make(map[*ServerSession]struct{}),
-		readers:        make(map[*ServerSession]struct{}),
+		tracks:               tracks,
+		activeUnicastReaders: make(map[*ServerSession]struct{}),
+		readers:              make(map[*ServerSession]struct{}),
 	}
 
-	st.stTracks = make([]*serverStreamTrack, len(tracks))
-	for i := range st.stTracks {
-		st.stTracks[i] = &serverStreamTrack{}
+	st.ssTracks = make([]*serverStreamTrack, len(tracks))
+	for i := range st.ssTracks {
+		st.ssTracks[i] = &serverStreamTrack{}
 	}
 
 	return st
@@ -73,7 +73,7 @@ func (st *ServerStream) Close() error {
 	}
 
 	st.readers = nil
-	st.readersUnicast = nil
+	st.activeUnicastReaders = nil
 
 	return nil
 }
@@ -86,14 +86,14 @@ func (st *ServerStream) Tracks() Tracks {
 func (st *ServerStream) ssrc(trackID int) uint32 {
 	st.mutex.Lock()
 	defer st.mutex.Unlock()
-	return st.stTracks[trackID].lastSSRC
+	return st.ssTracks[trackID].lastSSRC
 }
 
 func (st *ServerStream) rtpInfo(trackID int, now time.Time) (uint16, uint32, bool) {
 	st.mutex.Lock()
 	defer st.mutex.Unlock()
 
-	track := st.stTracks[trackID]
+	track := st.ssTracks[trackID]
 
 	if !track.firstPacketSent {
 		return 0, 0, false
@@ -125,31 +125,10 @@ func (st *ServerStream) readerAdd(
 		return fmt.Errorf("stream is closed")
 	}
 
-	if st.s == nil {
-		st.s = ss.s
-
-		for trackID, track := range st.stTracks {
-			cTrackID := trackID
-
-			// always generate RTCP sender reports.
-			// they're mandatory needed when transport protocol is UDP or UDP-multicast.
-			// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
-			// since they're used to compute NTP timestamp of frames:
-			// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
-			track.udpRTCPSender = rtcpsender.New(
-				st.s.udpSenderReportPeriod,
-				st.tracks[trackID].ClockRate(),
-				func(pkt rtcp.Packet) {
-					st.WritePacketRTCP(cTrackID, pkt)
-				},
-			)
-		}
-	}
-
 	switch transport {
 	case TransportUDP:
-		// check if client ports are already in use by another reader.
-		for r := range st.readersUnicast {
+		// check if client ports are already in use by another reader
+		for r := range st.readers {
 			if *r.setuppedTransport == TransportUDP &&
 				r.author.ip().Equal(ss.author.ip()) &&
 				r.author.zone() == ss.author.zone() {
@@ -183,6 +162,27 @@ func (st *ServerStream) readerAdd(
 		}
 	}
 
+	if st.s == nil {
+		st.s = ss.s
+
+		for trackID, track := range st.ssTracks {
+			cTrackID := trackID
+
+			// always generate RTCP sender reports.
+			// they're mandatory when transport protocol is UDP or UDP-multicast.
+			// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
+			// since they're used to compute NTP timestamp of frames:
+			// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
+			track.udpRTCPSender = rtcpsender.New(
+				st.s.udpSenderReportPeriod,
+				st.tracks[trackID].ClockRate(),
+				func(pkt rtcp.Packet) {
+					st.WritePacketRTCP(cTrackID, pkt)
+				},
+			)
+		}
+	}
+
 	st.readers[ss] = struct{}{}
 
 	return nil
@@ -209,7 +209,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
 
 	switch *ss.setuppedTransport {
 	case TransportUDP, TransportTCP:
-		st.readersUnicast[ss] = struct{}{}
+		st.activeUnicastReaders[ss] = struct{}{}
 
 	default: // UDPMulticast
 		for trackID, track := range ss.setuppedTracks {
@@ -225,7 +225,7 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) {
 
 	switch *ss.setuppedTransport {
 	case TransportUDP, TransportTCP:
-		delete(st.readersUnicast, ss)
+		delete(st.activeUnicastReaders, ss)
 
 	default: // UDPMulticast
 		if st.serverMulticastHandlers != nil {
@@ -248,7 +248,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT
 	st.mutex.RLock()
 	defer st.mutex.RUnlock()
 
-	track := st.stTracks[trackID]
+	track := st.ssTracks[trackID]
 	now := time.Now()
 
 	if !track.firstPacketSent || ptsEqualsDTS {
@@ -265,7 +265,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT
 	}
 
 	// send unicast
-	for r := range st.readersUnicast {
+	for r := range st.activeUnicastReaders {
 		r.writePacketRTP(trackID, byts)
 	}
 
@@ -286,7 +286,7 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) {
 	defer st.mutex.RUnlock()
 
 	// send unicast
-	for r := range st.readersUnicast {
+	for r := range st.activeUnicastReaders {
 		r.writePacketRTCP(trackID, byts)
 	}
 