server: allow setupping one path at a time

This commit is contained in:
aler9
2021-03-15 21:00:19 +01:00
parent d2cd127695
commit f4baab63e8
3 changed files with 214 additions and 39 deletions

View File

@@ -43,7 +43,7 @@ func stringsReverseIndex(s, substr string) int {
func extractTrackIDAndPath(url *base.URL,
thMode *headers.TransportMode,
publishTracks []ServerConnAnnouncedTrack,
publishPath string) (int, string, error) {
setupPath *string) (int, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
@@ -73,13 +73,17 @@ func extractTrackIDAndPath(url *base.URL,
path, _ := base.PathSplitQuery(pathAndQuery)
if setupPath != nil && path != *setupPath {
return 0, "", fmt.Errorf("can't setup tracks with different paths")
}
return trackID, path, nil
}
for trackID, track := range publishTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, publishPath, nil
return trackID, *setupPath, nil
}
}
@@ -184,7 +188,8 @@ type ServerConn struct {
bw *bufio.Writer
state ServerConnState
tracks map[int]ServerConnTrack
streamProtocol *StreamProtocol
setupProtocol *StreamProtocol
setupPath *string
// frame mode only
doEnableFrames bool
@@ -197,7 +202,6 @@ type ServerConn struct {
readHandlers ServerConnReadHandlers
// publish only
publishPath string
publishTracks []ServerConnAnnouncedTrack
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
@@ -245,7 +249,7 @@ func (sc *ServerConn) State() ServerConnState {
// StreamProtocol returns the setupped tracks protocol.
func (sc *ServerConn) StreamProtocol() *StreamProtocol {
return sc.streamProtocol
return sc.setupProtocol
}
// HasTrack checks whether a track has been setup.
@@ -312,7 +316,7 @@ func (sc *ServerConn) zone() string {
func (sc *ServerConn) frameModeEnable() {
switch sc.state {
case ServerConnStatePlay:
if *sc.streamProtocol == StreamProtocolTCP {
if *sc.setupProtocol == StreamProtocolTCP {
sc.doEnableFrames = true
} else {
// readers can send RTCP frames, they cannot sent RTP frames
@@ -322,7 +326,7 @@ func (sc *ServerConn) frameModeEnable() {
}
case ServerConnStateRecord:
if *sc.streamProtocol == StreamProtocolTCP {
if *sc.setupProtocol == StreamProtocolTCP {
sc.doEnableFrames = true
sc.readTimeoutEnabled = true
@@ -348,7 +352,7 @@ func (sc *ServerConn) frameModeEnable() {
func (sc *ServerConn) frameModeDisable() {
switch sc.state {
case ServerConnStatePlay:
if *sc.streamProtocol == StreamProtocolTCP {
if *sc.setupProtocol == StreamProtocolTCP {
sc.framesEnabled = false
sc.frameRingBuffer.Close()
<-sc.backgroundWriteDone
@@ -363,7 +367,7 @@ func (sc *ServerConn) frameModeDisable() {
close(sc.backgroundRecordTerminate)
<-sc.backgroundRecordDone
if *sc.streamProtocol == StreamProtocolTCP {
if *sc.setupProtocol == StreamProtocolTCP {
sc.readTimeoutEnabled = false
sc.nconn.SetReadDeadline(time.Time{})
@@ -515,7 +519,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord
sc.publishPath = reqPath
sc.setupPath = &reqPath
sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
@@ -553,6 +557,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("transport header: %s", err)
}
if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.publishTracks, sc.setupPath)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := sc.tracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID)
}
switch sc.state {
case ServerConnStateInitial, ServerConnStatePrePlay: // play
if th.Mode != nil && *th.Mode != headers.TransportModePlay {
@@ -569,32 +593,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
}
if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
}, nil
}
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.publishTracks, sc.publishPath)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := sc.tracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID)
}
if sc.streamProtocol != nil && *sc.streamProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't setup tracks with different protocols")
}
if th.Protocol == StreamProtocolUDP {
if sc.udpRTPListener == nil {
return &base.Response{
@@ -624,10 +622,16 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
}
if sc.setupProtocol != nil && *sc.setupProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't setup tracks with different protocols")
}
res, err := sc.readHandlers.OnSetup(req, th, path, trackID)
if res.StatusCode == 200 {
sc.streamProtocol = &th.Protocol
sc.setupProtocol = &th.Protocol
if sc.tracks == nil {
sc.tracks = make(map[int]ServerConnTrack)
@@ -668,6 +672,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
switch sc.state {
case ServerConnStateInitial:
sc.state = ServerConnStatePrePlay
sc.setupPath = &path
}
// workaround to prevent a bug in rtspclientsink
@@ -949,7 +954,7 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
// WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *sc.streamProtocol == StreamProtocolUDP {
if *sc.setupProtocol == StreamProtocolUDP {
track := sc.tracks[trackID]
if streamType == StreamTypeRTP {
@@ -990,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() {
for {
select {
case <-checkStreamTicker.C:
if *sc.streamProtocol != StreamProtocolUDP {
if *sc.setupProtocol != StreamProtocolUDP {
continue
}

View File

@@ -187,6 +187,101 @@ func TestServerConnPublishSetupPath(t *testing.T) {
}
}
func TestServerConnPublishSetupDifferentPaths(t *testing.T) {
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := s.Accept()
require.NoError(t, err)
defer conn.Close()
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
<-conn.Read(ServerConnReadHandlers{
OnSetup: onSetup,
})
}()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err)
track.Media.Attributes = append(track.Media.Attributes, psdp.Attribute{
Key: "control",
Value: "trackID=0",
})
sout := &psdp.SessionDescription{
SessionName: psdp.SessionName("Stream"),
Origin: psdp.Origin{
Username: "-",
NetworkType: "IN",
AddressType: "IP4",
UnicastAddress: "127.0.0.1",
},
TimeDescriptions: []psdp.TimeDescription{
{Timing: psdp.Timing{0, 0}}, //nolint:govet
},
MediaDescriptions: []*psdp.MediaDescription{
track.Media,
},
}
byts, _ := sout.Marshal()
err = base.Request{
Method: base.Announce,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: byts,
}.Write(bconn.Writer)
require.NoError(t, err)
th := &headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIds: &[2]int{0, 1},
}
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/test2stream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
func TestServerConnPublishReceivePackets(t *testing.T) {
for _, proto := range []string{
"udp",

View File

@@ -131,6 +131,81 @@ func TestServerConnReadSetupPath(t *testing.T) {
}
}
func TestServerConnReadSetupDifferentPaths(t *testing.T) {
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := s.Accept()
require.NoError(t, err)
defer conn.Close()
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
<-conn.Read(ServerConnReadHandlers{
OnSetup: onSetup,
})
}()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
th := &headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIds: &[2]int{0, 1},
}
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": th.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
th.InterleavedIds = &[2]int{2, 3}
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/test12stream/trackID=1"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": th.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
func TestServerConnReadReceivePackets(t *testing.T) {
for _, proto := range []string{
"udp",