server: provide path to OnSetup()

This commit is contained in:
aler9
2021-03-14 16:41:17 +01:00
parent 378c5639bb
commit d902b7da93
9 changed files with 154 additions and 204 deletions

View File

@@ -72,7 +72,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

View File

@@ -71,7 +71,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

View File

@@ -71,7 +71,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}
// called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

View File

@@ -3,37 +3,9 @@ package base
import (
"fmt"
"net/url"
"strconv"
"strings"
)
func stringsReverseIndex(s, substr string) int {
for i := len(s) - 1 - len(substr); i >= 0; i-- {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}
// PathSplitControlAttribute splits a path and query from a control attribute.
func PathSplitControlAttribute(pathAndQuery string) (int, string, bool) {
i := stringsReverseIndex(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - we assume it's track 0
if i < 0 {
return 0, pathAndQuery, true
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", false
}
trackID := int(tmp)
return trackID, pathAndQuery[:i], true
}
// PathSplitQuery splits a path from a query.
func PathSplitQuery(pathAndQuery string) (string, string) {
i := strings.Index(pathAndQuery, "?")

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -30,6 +31,61 @@ var (
errServerCSeqMissing = errors.New("CSeq is missing")
)
func stringsReverseIndex(s, substr string) int {
for i := len(s) - 1 - len(substr); i >= 0; i-- {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}
func extractTrackIDAndPath(url *base.URL,
thMode *headers.TransportMode,
publishTracks []ServerConnAnnouncedTrack,
publishPath string) (int, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", fmt.Errorf("invalid URL (%s)", url)
}
if thMode == nil || *thMode == headers.TransportModePlay {
i := stringsReverseIndex(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - it's track zero
if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
}
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
// we assume it's track 0
return 0, pathAndQuery, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
}
trackID := int(tmp)
pathAndQuery = pathAndQuery[:i]
path, _ := base.PathSplitQuery(pathAndQuery)
return trackID, path, nil
}
for trackID, track := range publishTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, publishPath, nil
}
}
return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerConnState is the state of the connection.
type ServerConnState int
@@ -92,7 +148,7 @@ type ServerConnReadHandlers struct {
OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)
// called after receiving a SETUP request.
OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error)
OnSetup func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error)
// called after receiving a PLAY request.
OnPlay func(req *base.Request) (*base.Response, error)
@@ -127,20 +183,22 @@ type ServerConn struct {
br *bufio.Reader
bw *bufio.Writer
state ServerConnState
readHandlers ServerConnReadHandlers
tracks map[int]ServerConnTrack
streamProtocol *StreamProtocol
announcedTracks []ServerConnAnnouncedTrack
// frame mode only
doEnableFrames bool
framesEnabled bool
readTimeoutEnabled bool
// writer
frameRingBuffer *ringbuffer.RingBuffer
backgroundWriteDone chan struct{}
// background record
// read only
readHandlers ServerConnReadHandlers
// publish only
publishPath string
publishTracks []ServerConnAnnouncedTrack
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
udpTimeout int32
@@ -457,14 +515,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord
sc.publishPath = reqPath
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
clockRate, _ := track.ClockRate()
v := time.Now().Unix()
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
sc.publishTracks[trackID] = ServerConnAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
@@ -488,13 +546,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", req.URL)
}
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
@@ -524,25 +575,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, nil
}
trackID, err := func() (int, error) {
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
trackID, _, ok := base.PathSplitControlAttribute(pathAndQuery)
if !ok {
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
return trackID, nil
}
for trackID, track := range sc.announcedTracks {
u, _ := track.track.URL()
if u.String() == req.URL.String() {
return trackID, nil
}
}
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
}()
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.publishTracks, sc.publishPath)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
@@ -590,7 +624,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
}
res, err := sc.readHandlers.OnSetup(req, th, trackID)
res, err := sc.readHandlers.OnSetup(req, th, path, trackID)
if res.StatusCode == 200 {
sc.streamProtocol = &th.Protocol
@@ -697,7 +731,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("no tracks have been setup")
}
if len(sc.tracks) != len(sc.announcedTracks) {
if len(sc.tracks) != len(sc.publishTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("not all tracks have been setup")
@@ -860,7 +894,7 @@ outer:
// forward frame only if it has been set up
if _, ok := sc.tracks[frame.TrackID]; ok {
if sc.state == ServerConnStateRecord {
sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
sc.publishTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload)
}
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
@@ -961,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() {
}
now := time.Now()
for _, track := range sc.announcedTracks {
for _, track := range sc.publishTracks {
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
if now.Sub(last) >= sc.conf.ReadTimeout {
@@ -973,7 +1007,7 @@ func (sc *ServerConn) backgroundRecord() {
case <-receiverReportTicker.C:
now := time.Now()
for trackID, track := range sc.announcedTracks {
for trackID, track := range sc.publishTracks {
r := track.rtcpReceiver.Report(now)
sc.WriteFrame(trackID, StreamTypeRTP, r)
}

View File

@@ -145,32 +145,13 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}, nil
}
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
switch conn.State() {
case ServerConnStateInitial, ServerConnStatePrePlay:
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
if path != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", req.URL)
}
_, pathAndQuery, ok = base.PathSplitControlAttribute(pathAndQuery)
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", req.URL)
}
reqPath, _ := base.PathSplitQuery(pathAndQuery)
if reqPath != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", req.URL)
}
}
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{

View File

@@ -21,35 +21,58 @@ func TestServerConnPublishSetupPath(t *testing.T) {
name string
control string
url string
path string
trackID int
}{
{
"normal",
"trackID=0",
"rtsp://localhost:8554/teststream/trackID=0",
"teststream",
0,
},
{
"unordered id",
"trackID=2",
"rtsp://localhost:8554/teststream/trackID=2",
"teststream",
0,
},
{
"custom param name",
"testing=0",
"rtsp://localhost:8554/teststream/testing=0",
"teststream",
0,
},
{
"query",
"?testing=0",
"rtsp://localhost:8554/teststream?testing=0",
"teststream",
0,
},
{
"subpath",
"trackID=0",
"rtsp://localhost:8554/test/stream/trackID=0",
"test/stream",
0,
},
{
"subpath and query",
"?testing=0",
"rtsp://localhost:8554/test/stream?testing=0",
"test/stream",
0,
},
} {
t.Run(ca.name, func(t *testing.T) {
setupDone := make(chan int)
type pathTrackIDPair struct {
path string
trackID int
}
setupDone := make(chan pathTrackIDPair)
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
@@ -70,8 +93,8 @@ func TestServerConnPublishSetupPath(t *testing.T) {
}, nil
}
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
setupDone <- trackID
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
setupDone <- pathTrackIDPair{path, trackID}
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -116,7 +139,7 @@ func TestServerConnPublishSetupPath(t *testing.T) {
err = base.Request{
Method: base.Announce,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
URL: base.MustParseURL("rtsp://localhost:8554/" + ca.path),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
@@ -153,8 +176,9 @@ func TestServerConnPublishSetupPath(t *testing.T) {
}.Write(bconn.Writer)
require.NoError(t, err)
trackID := <-setupDone
require.Equal(t, ca.trackID, trackID)
pair := <-setupDone
require.Equal(t, ca.path, pair.path)
require.Equal(t, ca.trackID, pair.trackID)
err = res.Read(bconn.Reader)
require.NoError(t, err)
@@ -197,7 +221,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) {
}, nil
}
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil

View File

@@ -17,21 +17,47 @@ func TestServerConnReadSetupPath(t *testing.T) {
for _, ca := range []struct {
name string
url string
path string
trackID int
}{
{
"normal",
"rtsp://localhost:8554/teststream/trackID=0",
"teststream",
0,
},
{
"unordered id",
"rtsp://localhost:8554/teststream/trackID=2",
"teststream",
2,
},
{
// this is needed to support reading mpegts with ffmpeg
"without track id",
"rtsp://localhost:8554/teststream/",
"teststream",
0,
},
{
"subpath",
"rtsp://localhost:8554/test/stream/trackID=0",
"test/stream",
0,
},
{
"subpath without track id",
"rtsp://localhost:8554/test/stream/",
"test/stream",
0,
},
} {
t.Run(ca.name, func(t *testing.T) {
setupDone := make(chan int)
type pathTrackIDPair struct {
path string
trackID int
}
setupDone := make(chan pathTrackIDPair)
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
@@ -46,8 +72,8 @@ func TestServerConnReadSetupPath(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
setupDone <- trackID
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
setupDone <- pathTrackIDPair{path, trackID}
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -87,8 +113,9 @@ func TestServerConnReadSetupPath(t *testing.T) {
}.Write(bconn.Writer)
require.NoError(t, err)
trackID := <-setupDone
require.Equal(t, ca.trackID, trackID)
pair := <-setupDone
require.Equal(t, ca.path, pair.path)
require.Equal(t, ca.trackID, pair.trackID)
var res base.Response
err = res.Read(bconn.Reader)
@@ -124,7 +151,7 @@ func TestServerConnReadReceivePackets(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -231,94 +258,6 @@ func TestServerConnReadReceivePackets(t *testing.T) {
}
}
func TestServerConnReadWithoutSetupTrackID(t *testing.T) {
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := s.Accept()
require.NoError(t, err)
defer conn.Close()
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
onPlay := func(req *base.Request) (*base.Response, error) {
go func() {
time.Sleep(100 * time.Millisecond)
conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
err = <-conn.Read(ServerConnReadHandlers{
OnSetup: onSetup,
OnPlay: onPlay,
})
require.Equal(t, io.EOF, err)
}()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIds: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame
fr.Payload = make([]byte, 2048)
err = fr.Read(bconn.Reader)
require.NoError(t, err)
}
func TestServerConnReadTCPResponseBeforeFrames(t *testing.T) {
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
@@ -338,7 +277,7 @@ func TestServerConnReadTCPResponseBeforeFrames(t *testing.T) {
writerTerminate := make(chan struct{})
defer close(writerTerminate)
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -446,7 +385,7 @@ func TestServerConnReadPlayMultiple(t *testing.T) {
writerTerminate := make(chan struct{})
defer close(writerTerminate)
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -561,7 +500,7 @@ func TestServerConnReadPauseMultiple(t *testing.T) {
writerTerminate := make(chan struct{})
defer close(writerTerminate)
onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil

View File

@@ -121,8 +121,8 @@ func (s *serverUDPListener) run() {
if clientData.isPublishing {
now := time.Now()
atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
atomic.StoreInt64(clientData.sc.publishTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.sc.publishTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
}
clientData.sc.readHandlers.OnFrame(clientData.trackID, s.streamType, buf[:n])