client: support server-sent requests (#93) (#378)

This commit is contained in:
Alessandro Ros
2023-08-24 16:07:09 +02:00
committed by GitHub
parent bff0b57fbe
commit ed4bbe3986
16 changed files with 659 additions and 492 deletions

373
client.go
View File

@@ -315,18 +315,19 @@ type Client struct {
closeError error closeError error
writer asyncProcessor writer asyncProcessor
reader *clientReader reader *clientReader
connCloser *clientConnCloser
timeDecoder *rtptime.GlobalDecoder timeDecoder *rtptime.GlobalDecoder
// in // in
options chan optionsReq chOptions chan optionsReq
describe chan describeReq chDescribe chan describeReq
announce chan announceReq chAnnounce chan announceReq
setup chan setupReq chSetup chan setupReq
play chan playReq chPlay chan playReq
record chan recordReq chRecord chan recordReq
pause chan pauseReq chPause chan pauseReq
readError chan error chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
// out // out
done chan struct{} done chan struct{}
@@ -425,14 +426,16 @@ func (c *Client) Start(scheme string, host string) error {
c.ctxCancel = ctxCancel c.ctxCancel = ctxCancel
c.checkTimeoutTimer = emptyTimer() c.checkTimeoutTimer = emptyTimer()
c.keepaliveTimer = emptyTimer() c.keepaliveTimer = emptyTimer()
c.options = make(chan optionsReq) c.chOptions = make(chan optionsReq)
c.describe = make(chan describeReq) c.chDescribe = make(chan describeReq)
c.announce = make(chan announceReq) c.chAnnounce = make(chan announceReq)
c.setup = make(chan setupReq) c.chSetup = make(chan setupReq)
c.play = make(chan playReq) c.chPlay = make(chan playReq)
c.record = make(chan recordReq) c.chRecord = make(chan recordReq)
c.pause = make(chan pauseReq) c.chPause = make(chan pauseReq)
c.readError = make(chan error) c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{}) c.done = make(chan struct{})
go c.run() go c.run()
@@ -499,76 +502,133 @@ func (c *Client) run() {
func (c *Client) runInner() error { func (c *Client) runInner() error {
for { for {
select { select {
case req := <-c.options: case req := <-c.chOptions:
res, err := c.doOptions(req.url) res, err := c.doOptions(req.url)
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case req := <-c.describe: case req := <-c.chDescribe:
sd, res, err := c.doDescribe(req.url) sd, res, err := c.doDescribe(req.url)
req.res <- clientRes{sd: sd, res: res, err: err} req.res <- clientRes{sd: sd, res: res, err: err}
case req := <-c.announce: case req := <-c.chAnnounce:
res, err := c.doAnnounce(req.url, req.desc) res, err := c.doAnnounce(req.url, req.desc)
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case req := <-c.setup: case req := <-c.chSetup:
res, err := c.doSetup(req.baseURL, req.media, req.rtpPort, req.rtcpPort) res, err := c.doSetup(req.baseURL, req.media, req.rtpPort, req.rtcpPort)
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case req := <-c.play: case req := <-c.chPlay:
res, err := c.doPlay(req.ra) res, err := c.doPlay(req.ra)
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case req := <-c.record: case req := <-c.chRecord:
res, err := c.doRecord() res, err := c.doRecord()
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case req := <-c.pause: case req := <-c.chPause:
res, err := c.doPause() res, err := c.doPause()
req.res <- clientRes{res: res, err: err} req.res <- clientRes{res: res, err: err}
case <-c.checkTimeoutTimer.C: case <-c.checkTimeoutTimer.C:
err := c.checkTimeout() err := c.doCheckTimeout()
if err != nil { if err != nil {
return err return err
} }
c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod) c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod)
case <-c.keepaliveTimer.C: case <-c.keepaliveTimer.C:
err := c.doKeepalive() err := c.doKeepAlive()
if err != nil { if err != nil {
return err return err
} }
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
case err := <-c.readError: case err := <-c.chReadError:
c.reader = nil c.reader = nil
return err return err
case <-c.chReadResponse:
return liberrors.ErrClientUnexpectedResponse{}
case req := <-c.chReadRequest:
err := c.handleServerRequest(req)
if err != nil {
return err
}
case <-c.ctx.Done(): case <-c.ctx.Done():
return liberrors.ErrClientTerminated{} return liberrors.ErrClientTerminated{}
} }
} }
} }
func (c *Client) waitResponse() (*base.Response, error) {
for {
select {
case <-time.After(c.ReadTimeout):
return nil, liberrors.ErrClientRequestTimedOut{}
case err := <-c.chReadError:
c.reader = nil
return nil, err
case res := <-c.chReadResponse:
return res, nil
case req := <-c.chReadRequest:
err := c.handleServerRequest(req)
if err != nil {
return nil, err
}
case <-c.ctx.Done():
return nil, liberrors.ErrClientTerminated{}
}
}
}
func (c *Client) handleServerRequest(req *base.Request) error {
if req.Method != base.Options {
return liberrors.ErrClientUnhandledMethod{Method: req.Method}
}
if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 {
return liberrors.ErrClientMissingCSeq{}
}
res := &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"User-Agent": base.HeaderValue{c.UserAgent},
"CSeq": req.Header["CSeq"],
},
}
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
return c.conn.WriteResponse(res)
}
func (c *Client) doClose() { func (c *Client) doClose() {
if c.connCloser != nil {
c.connCloser.close()
c.connCloser = nil
}
if c.state == clientStatePlay || c.state == clientStateRecord { if c.state == clientStatePlay || c.state == clientStateRecord {
c.playRecordStop(true) c.stopWriter()
c.stopReadRoutines()
} }
if c.baseURL != nil { if c.nconn != nil && c.baseURL != nil {
c.do(&base.Request{ //nolint:errcheck c.do(&base.Request{ //nolint:errcheck
Method: base.Teardown, Method: base.Teardown,
URL: c.baseURL, URL: c.baseURL,
}, true, false) }, true)
} }
if c.nconn != nil { if c.reader != nil {
c.nconn.Close()
c.reader.wait()
c.reader = nil
c.nconn = nil
c.conn = nil
} else if c.nconn != nil {
c.nconn.Close() c.nconn.Close()
c.nconn = nil c.nconn = nil
c.conn = nil c.conn = nil
@@ -668,12 +728,8 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *url.URL
return c.doSetup(baseURL, medi, 0, 0) return c.doSetup(baseURL, medi, 0, 0)
} }
func (c *Client) playRecordStart() { func (c *Client) startReadRoutines() {
c.connCloser.close() // allocate writer here because it's needed by RTCP receiver / sender
c.connCloser = nil
c.timeDecoder = rtptime.NewGlobalDecoder()
if c.state == clientStatePlay { if c.state == clientStatePlay {
// when reading, buffer is only used to send RTCP receiver reports, // when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
@@ -683,7 +739,7 @@ func (c *Client) playRecordStart() {
c.writer.allocateBuffer(c.WriteBufferCount) c.writer.allocateBuffer(c.WriteBufferCount)
} }
c.writer.start() c.timeDecoder = rtptime.NewGlobalDecoder()
for _, cm := range c.medias { for _, cm := range c.medias {
cm.start() cm.start()
@@ -707,14 +763,14 @@ func (c *Client) playRecordStart() {
} }
} }
c.reader = newClientReader(c) if *c.effectiveTransport == TransportTCP {
c.reader.setAllowInterleavedFrames(true)
}
} }
func (c *Client) playRecordStop(isClosing bool) { func (c *Client) stopReadRoutines() {
if c.reader != nil { if c.reader != nil {
c.reader.close() c.reader.setAllowInterleavedFrames(false)
<-c.readError
c.reader = nil
} }
c.checkTimeoutTimer = emptyTimer() c.checkTimeoutTimer = emptyTimer()
@@ -724,22 +780,28 @@ func (c *Client) playRecordStop(isClosing bool) {
cm.stop() cm.stop()
} }
c.writer.stop()
c.timeDecoder = nil c.timeDecoder = nil
if !isClosing {
c.connCloser = newClientConnCloser(c.ctx, c.nconn)
} }
func (c *Client) startWriter() {
c.writer.start()
}
func (c *Client) stopWriter() {
c.writer.stop()
} }
func (c *Client) connOpen() error { func (c *Client) connOpen() error {
if c.nconn != nil {
return nil
}
if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" { if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" {
return fmt.Errorf("unsupported scheme '%s'", c.connURL.Scheme) return liberrors.ErrClientUnsupportedScheme{Scheme: c.connURL.Scheme}
} }
if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP { if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP {
return fmt.Errorf("RTSPS can be used only with TCP") return liberrors.ErrClientRTSPSTCP{}
} }
dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout) dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout)
@@ -752,11 +814,9 @@ func (c *Client) connOpen() error {
if c.connURL.Scheme == "rtsps" { if c.connURL.Scheme == "rtsps" {
tlsConfig := c.TLSConfig tlsConfig := c.TLSConfig
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = &tls.Config{} tlsConfig = &tls.Config{}
} }
tlsConfig.ServerName = c.connURL.Hostname() tlsConfig.ServerName = c.connURL.Hostname()
nconn = tls.Client(nconn, tlsConfig) nconn = tls.Client(nconn, tlsConfig)
@@ -765,19 +825,12 @@ func (c *Client) connOpen() error {
c.nconn = nconn c.nconn = nconn
bc := bytecounter.New(c.nconn, c.BytesReceived, c.BytesSent) bc := bytecounter.New(c.nconn, c.BytesReceived, c.BytesSent)
c.conn = conn.NewConn(bc) c.conn = conn.NewConn(bc)
c.connCloser = newClientConnCloser(c.ctx, c.nconn) c.reader = newClientReader(c)
return nil return nil
} }
func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*base.Response, error) { func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error) {
if c.nconn == nil {
err := c.connOpen()
if err != nil {
return nil, err
}
}
if !c.optionsSent && req.Method != base.Options { if !c.optionsSent && req.Method != base.Options {
_, err := c.doOptions(req.URL) _, err := c.doOptions(req.URL)
if err != nil { if err != nil {
@@ -814,18 +867,9 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
return nil, nil return nil, nil
} }
c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) res, err := c.waitResponse()
var res *base.Response
if allowFrames {
// read the response and ignore interleaved frames in between;
// interleaved frames are sent in two cases:
// * when the server is v4lrtspserver, before the PLAY response
// * when the stream is already playing
res, err = c.conn.ReadResponseIgnoreFrames()
} else {
res, err = c.conn.ReadResponse()
}
if err != nil { if err != nil {
c.ctxCancel()
return nil, err return nil, err
} }
@@ -856,7 +900,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
} }
c.sender = sender c.sender = sender
return c.do(req, skipResponse, allowFrames) return c.do(req, skipResponse)
} }
return res, nil return res, nil
@@ -899,7 +943,7 @@ func (c *Client) isInTCPTimeout() bool {
return now.Sub(lft) >= c.ReadTimeout return now.Sub(lft) >= c.ReadTimeout
} }
func (c *Client) checkTimeout() error { func (c *Client) doCheckTimeout() error {
if *c.effectiveTransport == TransportUDP || if *c.effectiveTransport == TransportUDP ||
*c.effectiveTransport == TransportUDPMulticast { *c.effectiveTransport == TransportUDPMulticast {
if c.checkTimeoutInitial { if c.checkTimeoutInitial {
@@ -921,7 +965,7 @@ func (c *Client) checkTimeout() error {
return nil return nil
} }
func (c *Client) doKeepalive() error { func (c *Client) doKeepAlive() error {
_, err := c.do(&base.Request{ _, err := c.do(&base.Request{
Method: func() base.Method { Method: func() base.Method {
// the VLC integrated rtsp server requires GET_PARAMETER // the VLC integrated rtsp server requires GET_PARAMETER
@@ -932,7 +976,7 @@ func (c *Client) doKeepalive() error {
}(), }(),
// use the stream base URL, otherwise some cameras do not reply // use the stream base URL, otherwise some cameras do not reply
URL: c.baseURL, URL: c.baseURL,
}, true, false) }, false)
return err return err
} }
@@ -946,10 +990,15 @@ func (c *Client) doOptions(u *url.URL) (*base.Response, error) {
return nil, err return nil, err
} }
err = c.connOpen()
if err != nil {
return nil, err
}
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Options, Method: base.Options,
URL: u, URL: u,
}, false, false) }, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -973,12 +1022,12 @@ func (c *Client) doOptions(u *url.URL) (*base.Response, error) {
func (c *Client) Options(u *url.URL) (*base.Response, error) { func (c *Client) Options(u *url.URL) (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.options <- optionsReq{url: u, res: cres}: case c.chOptions <- optionsReq{url: u, res: cres}:
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -992,13 +1041,18 @@ func (c *Client) doDescribe(u *url.URL) (*description.Session, *base.Response, e
return nil, nil, err return nil, nil, err
} }
err = c.connOpen()
if err != nil {
return nil, nil, err
}
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Describe, Method: base.Describe,
URL: u, URL: u,
Header: base.Header{ Header: base.Header{
"Accept": base.HeaderValue{"application/sdp"}, "Accept": base.HeaderValue{"application/sdp"},
}, },
}, false, false) }, false)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -1069,12 +1123,12 @@ func (c *Client) doDescribe(u *url.URL) (*description.Session, *base.Response, e
func (c *Client) Describe(u *url.URL) (*description.Session, *base.Response, error) { func (c *Client) Describe(u *url.URL) (*description.Session, *base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.describe <- describeReq{url: u, res: cres}: case c.chDescribe <- describeReq{url: u, res: cres}:
res := <-cres res := <-cres
return res.sd, res.res, res.err return res.sd, res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, nil, liberrors.ErrClientTerminated{} return nil, nil, c.closeError
} }
} }
@@ -1086,6 +1140,11 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon
return nil, err return nil, err
} }
err = c.connOpen()
if err != nil {
return nil, err
}
prepareForAnnounce(desc) prepareForAnnounce(desc)
byts, err := desc.Marshal(false) byts, err := desc.Marshal(false)
@@ -1100,7 +1159,7 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon
"Content-Type": base.HeaderValue{"application/sdp"}, "Content-Type": base.HeaderValue{"application/sdp"},
}, },
Body: byts, Body: byts,
}, false, false) }, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1121,12 +1180,12 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon
func (c *Client) Announce(u *url.URL, desc *description.Session) (*base.Response, error) { func (c *Client) Announce(u *url.URL, desc *description.Session) (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.announce <- announceReq{url: u, desc: desc, res: cres}: case c.chAnnounce <- announceReq{url: u, desc: desc, res: cres}:
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -1145,6 +1204,11 @@ func (c *Client) doSetup(
return nil, err return nil, err
} }
err = c.connOpen()
if err != nil {
return nil, err
}
if c.baseURL != nil && *baseURL != *c.baseURL { if c.baseURL != nil && *baseURL != *c.baseURL {
return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{} return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{}
} }
@@ -1229,7 +1293,7 @@ func (c *Client) doSetup(
Header: base.Header{ Header: base.Header{
"Transport": th.Marshal(), "Transport": th.Marshal(),
}, },
}, false, false) }, false)
if err != nil { if err != nil {
cm.close() cm.close()
return nil, err return nil, err
@@ -1428,7 +1492,7 @@ func (c *Client) Setup(
) (*base.Response, error) { ) (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.setup <- setupReq{ case c.chSetup <- setupReq{
baseURL: baseURL, baseURL: baseURL,
media: media, media: media,
rtpPort: rtpPort, rtpPort: rtpPort,
@@ -1438,8 +1502,8 @@ func (c *Client) Setup(
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -1462,19 +1526,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
return nil, err return nil, err
} }
// open the firewall by sending empty packets to the counterpart. c.state = clientStatePlay
// do this before sending the request. c.startReadRoutines()
// don't do this with multicast, otherwise the RTP packet is going to be broadcasted
// to all listeners, including us, messing up the stream.
if *c.effectiveTransport == TransportUDP {
for _, ct := range c.medias {
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
ct.udpRTPListener.write(byts) //nolint:errcheck
byts, _ = (&rtcp.ReceiverReport{}).Marshal()
ct.udpRTCPListener.write(byts) //nolint:errcheck
}
}
// Range is mandatory in Parrot Streaming Server // Range is mandatory in Parrot Streaming Server
if ra == nil { if ra == nil {
@@ -1491,20 +1544,37 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: base.Header{ Header: base.Header{
"Range": ra.Marshal(), "Range": ra.Marshal(),
}, },
}, false, *c.effectiveTransport == TransportTCP) }, false)
if err != nil { if err != nil {
c.stopReadRoutines()
c.state = clientStatePrePlay
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
// open the firewall by sending empty packets to the counterpart.
// do this before sending the request.
// don't do this with multicast, otherwise the RTP packet is going to be broadcasted
// to all listeners, including us, messing up the stream.
if *c.effectiveTransport == TransportUDP {
for _, cm := range c.medias {
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
cm.udpRTPListener.write(byts) //nolint:errcheck
byts, _ = (&rtcp.ReceiverReport{}).Marshal()
cm.udpRTCPListener.write(byts) //nolint:errcheck
}
}
c.startWriter()
c.lastRange = ra c.lastRange = ra
c.state = clientStatePlay
c.playRecordStart()
return res, nil return res, nil
} }
@@ -1514,12 +1584,12 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
func (c *Client) Play(ra *headers.Range) (*base.Response, error) { func (c *Client) Play(ra *headers.Range) (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.play <- playReq{ra: ra, res: cres}: case c.chPlay <- playReq{ra: ra, res: cres}:
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -1531,22 +1601,28 @@ func (c *Client) doRecord() (*base.Response, error) {
return nil, err return nil, err
} }
c.state = clientStateRecord
c.startReadRoutines()
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Record, Method: base.Record,
URL: c.baseURL, URL: c.baseURL,
}, false, false) }, false)
if err != nil { if err != nil {
c.stopReadRoutines()
c.state = clientStatePreRecord
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
c.state = clientStateRecord c.startWriter()
c.playRecordStart()
return nil, nil return nil, nil
} }
@@ -1556,12 +1632,12 @@ func (c *Client) doRecord() (*base.Response, error) {
func (c *Client) Record() (*base.Response, error) { func (c *Client) Record() (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.record <- recordReq{res: cres}: case c.chRecord <- recordReq{res: cres}:
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -1574,9 +1650,26 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err return nil, err
} }
c.playRecordStop(false) c.stopWriter()
res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.startWriter()
return nil, err
}
if res.StatusCode != base.StatusOK {
c.startWriter()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
c.stopReadRoutines()
// change state regardless of the response
switch c.state { switch c.state {
case clientStatePlay: case clientStatePlay:
c.state = clientStatePrePlay c.state = clientStatePrePlay
@@ -1584,20 +1677,6 @@ func (c *Client) doPause() (*base.Response, error) {
c.state = clientStatePreRecord c.state = clientStatePreRecord
} }
res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false, *c.effectiveTransport == TransportTCP)
if err != nil {
return nil, err
}
if res.StatusCode != base.StatusOK {
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
return res, nil return res, nil
} }
@@ -1606,12 +1685,12 @@ func (c *Client) doPause() (*base.Response, error) {
func (c *Client) Pause() (*base.Response, error) { func (c *Client) Pause() (*base.Response, error) {
cres := make(chan clientRes) cres := make(chan clientRes)
select { select {
case c.pause <- pauseReq{res: cres}: case c.chPause <- pauseReq{res: cres}:
res := <-cres res := <-cres
return res.res, res.err return res.res, res.err
case <-c.ctx.Done(): case <-c.done:
return nil, liberrors.ErrClientTerminated{} return nil, c.closeError
} }
} }
@@ -1720,3 +1799,15 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType] ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp) return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
} }
func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}
func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}
func (c *Client) readError(err error) {
c.chReadError <- err
}

View File

@@ -1,43 +0,0 @@
package gortsplib
import (
"context"
"net"
)
type clientConnCloser struct {
ctx context.Context
nconn net.Conn
terminate chan struct{}
done chan struct{}
}
func newClientConnCloser(ctx context.Context, nconn net.Conn) *clientConnCloser {
cc := &clientConnCloser{
ctx: ctx,
nconn: nconn,
terminate: make(chan struct{}),
done: make(chan struct{}),
}
go cc.run()
return cc
}
func (cc *clientConnCloser) close() {
close(cc.terminate)
<-cc.done
}
func (cc *clientConnCloser) run() {
defer close(cc.done)
select {
case <-cc.ctx.Done():
cc.nconn.Close()
case <-cc.terminate:
}
}

View File

@@ -1,66 +1,70 @@
package gortsplib package gortsplib
import ( import (
"time" "sync/atomic"
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
type clientReader struct { type clientReader struct {
c *Client c *Client
closeErr chan error allowInterleavedFrames atomic.Bool
} }
func newClientReader(c *Client) *clientReader { func newClientReader(c *Client) *clientReader {
r := &clientReader{ r := &clientReader{
c: c, c: c,
closeErr: make(chan error),
} }
// for some reason, SetReadDeadline() must always be called in the same
// goroutine, otherwise Read() freezes.
// therefore, we disable the deadline and perform a check with a ticker.
r.c.nconn.SetReadDeadline(time.Time{})
go r.run() go r.run()
return r return r
} }
func (r *clientReader) close() { func (r *clientReader) setAllowInterleavedFrames(v bool) {
r.c.nconn.SetReadDeadline(time.Now()) r.allowInterleavedFrames.Store(v)
}
func (r *clientReader) wait() {
for {
select {
case <-r.c.chReadError:
return
case <-r.c.chReadResponse:
case <-r.c.chReadRequest:
}
}
} }
func (r *clientReader) run() { func (r *clientReader) run() {
r.c.readError <- r.runInner() err := r.runInner()
r.c.readError(err)
} }
func (r *clientReader) runInner() error { func (r *clientReader) runInner() error {
if *r.c.effectiveTransport == TransportUDP || *r.c.effectiveTransport == TransportUDPMulticast {
for { for {
res, err := r.c.conn.ReadResponse() what, err := r.c.conn.Read()
if err != nil {
return err
}
r.c.OnResponse(res)
}
} else {
for {
what, err := r.c.conn.ReadInterleavedFrameOrResponse()
if err != nil { if err != nil {
return err return err
} }
switch what := what.(type) { switch what := what.(type) {
case *base.Response: case *base.Response:
r.c.OnResponse(what) r.c.readResponse(what)
case *base.Request:
r.c.readRequest(what)
case *base.InterleavedFrame: case *base.InterleavedFrame:
if !r.allowInterleavedFrames.Load() {
return liberrors.ErrClientUnexpectedFrame{}
}
if cb, ok := r.c.tcpCallbackByChannel[what.Channel]; ok { if cb, ok := r.c.tcpCallbackByChannel[what.Channel]; ok {
cb(what.Payload) cb(what.Payload)
} }
} }
} }
} }
}

View File

@@ -3,6 +3,7 @@ package gortsplib
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@@ -39,7 +40,7 @@ var testRTPPacket = rtp.Packet{
CSRC: []uint32{}, CSRC: []uint32{},
SSRC: 0x38F27A2F, SSRC: 0x38F27A2F,
}, },
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{1, 2, 3, 4},
} }
var testRTPPacketMarshaled = mustMarshalPacketRTP(&testRTPPacket) var testRTPPacketMarshaled = mustMarshalPacketRTP(&testRTPPacket)
@@ -101,6 +102,23 @@ func record(c *Client, ur string, medias []*description.Media, cb func(*descript
return nil return nil
} }
func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
for {
what, err := c.Read()
if err != nil {
return nil, err
}
switch what := what.(type) {
case *base.InterleavedFrame:
case *base.Request:
return what, nil
case *base.Response:
return nil, fmt.Errorf("unexpected response")
}
}
}
func TestClientRecordSerial(t *testing.T) { func TestClientRecordSerial(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
@@ -412,7 +430,7 @@ func TestClientRecordParallel(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
req, err = conn.ReadRequestIgnoreFrames() req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
@@ -552,7 +570,7 @@ func TestClientRecordPauseSerial(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
req, err = conn.ReadRequestIgnoreFrames() req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Pause, req.Method) require.Equal(t, base.Pause, req.Method)
@@ -570,7 +588,7 @@ func TestClientRecordPauseSerial(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
req, err = conn.ReadRequestIgnoreFrames() req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
@@ -700,7 +718,7 @@ func TestClientRecordPauseParallel(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
req, err = conn.ReadRequestIgnoreFrames() req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Pause, req.Method) require.Equal(t, base.Pause, req.Method)

View File

@@ -375,3 +375,92 @@ func TestClientCloseDuringRequest(t *testing.T) {
<-optionsDone <-optionsDone
close(releaseConn) close(releaseConn)
} }
func TestClientReplyToServerRequest(t *testing.T) {
for _, ca := range []string{"after response", "before response"} {
t.Run(ca, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
nconn, err := l.Accept()
require.NoError(t, err)
conn := conn.NewConn(nconn)
defer nconn.Close()
req, err := conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
if ca == "after response" {
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
}, ", ")},
},
})
require.NoError(t, err)
err = conn.WriteRequest(&base.Request{
Method: base.Options,
URL: nil,
Header: base.Header{
"CSeq": base.HeaderValue{"4"},
},
})
require.NoError(t, err)
res, err := conn.ReadResponse()
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
require.Equal(t, "4", res.Header["CSeq"][0])
} else {
err = conn.WriteRequest(&base.Request{
Method: base.Options,
URL: nil,
Header: base.Header{
"CSeq": base.HeaderValue{"4"},
},
})
require.NoError(t, err)
res, err := conn.ReadResponse()
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
require.Equal(t, "4", res.Header["CSeq"][0])
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
}, ", ")},
},
})
require.NoError(t, err)
}
}()
u, err := url.Parse("rtsp://localhost:8554/stream")
require.NoError(t, err)
c := Client{}
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
_, err = c.Options(u)
require.NoError(t, err)
<-serverDone
})
}
}

View File

@@ -66,11 +66,15 @@ func (req *Request) Unmarshal(br *bufio.Reader) error {
} }
rawURL := string(byts[:len(byts)-1]) rawURL := string(byts[:len(byts)-1])
if rawURL != "*" {
ur, err := url.Parse(rawURL) ur, err := url.Parse(rawURL)
if err != nil { if err != nil {
return fmt.Errorf("invalid URL (%v)", rawURL) return fmt.Errorf("invalid URL (%v)", rawURL)
} }
req.URL = ur req.URL = ur
} else {
req.URL = nil
}
byts, err = readBytesLimited(br, '\r', requestMaxProtocolLength) byts, err = readBytesLimited(br, '\r', requestMaxProtocolLength)
if err != nil { if err != nil {
@@ -102,10 +106,15 @@ func (req *Request) Unmarshal(br *bufio.Reader) error {
// MarshalSize returns the size of a Request. // MarshalSize returns the size of a Request.
func (req Request) MarshalSize() int { func (req Request) MarshalSize() int {
n := 0 n := len(req.Method) + 1
urStr := req.URL.CloneWithoutCredentials().String() if req.URL != nil {
n += len([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n")) n += len(req.URL.CloneWithoutCredentials().String())
} else {
n++
}
n += 1 + len(rtspProtocol10) + 2
if len(req.Body) != 0 { if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)} req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}
@@ -122,8 +131,23 @@ func (req Request) MarshalSize() int {
func (req Request) MarshalTo(buf []byte) (int, error) { func (req Request) MarshalTo(buf []byte) (int, error) {
pos := 0 pos := 0
urStr := req.URL.CloneWithoutCredentials().String() pos += copy(buf[pos:], []byte(req.Method))
pos += copy(buf[pos:], []byte(string(req.Method)+" "+urStr+" "+rtspProtocol10+"\r\n")) buf[pos] = ' '
pos++
if req.URL != nil {
pos += copy(buf[pos:], []byte(req.URL.CloneWithoutCredentials().String()))
} else {
pos += copy(buf[pos:], []byte("*"))
}
buf[pos] = ' '
pos++
pos += copy(buf[pos:], rtspProtocol10)
buf[pos] = '\r'
pos++
buf[pos] = '\n'
pos++
if len(req.Body) != 0 { if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)} req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}

View File

@@ -138,6 +138,21 @@ var casesRequest = []struct {
), ),
}, },
}, },
{
"server-side announce",
[]byte("OPTIONS * RTSP/1.0\r\n" +
"CSeq: 1\r\n" +
"User-Agent: RDIPCamera\r\n" +
"\r\n"),
Request{
Method: "OPTIONS",
URL: nil,
Header: Header{
"CSeq": HeaderValue{"1"},
"User-Agent": HeaderValue{"RDIPCamera"},
},
},
},
} }
func TestRequestUnmarshal(t *testing.T) { func TestRequestUnmarshal(t *testing.T) {

View File

@@ -194,9 +194,7 @@ func (res Response) MarshalSize() int {
} }
} }
n += len([]byte(rtspProtocol10 + " " + n += len(rtspProtocol10) + 1 + len(strconv.FormatInt(int64(res.StatusCode), 10)) + 1 + len(res.StatusMessage) + 2
strconv.FormatInt(int64(res.StatusCode), 10) + " " +
res.StatusMessage + "\r\n"))
if len(res.Body) != 0 { if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)} res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}
@@ -219,9 +217,17 @@ func (res Response) MarshalTo(buf []byte) (int, error) {
pos := 0 pos := 0
pos += copy(buf[pos:], []byte(rtspProtocol10+" "+ pos += copy(buf[pos:], []byte(rtspProtocol10))
strconv.FormatInt(int64(res.StatusCode), 10)+" "+ buf[pos] = ' '
res.StatusMessage+"\r\n")) pos++
pos += copy(buf[pos:], []byte(strconv.FormatInt(int64(res.StatusCode), 10)))
buf[pos] = ' '
pos++
pos += copy(buf[pos:], []byte(res.StatusMessage))
buf[pos] = '\r'
pos++
buf[pos] = '\n'
pos++
if len(res.Body) != 0 { if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)} res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}

View File

@@ -16,8 +16,8 @@ const (
type Conn struct { type Conn struct {
w io.Writer w io.Writer
br *bufio.Reader br *bufio.Reader
req base.Request
res base.Response // reuse interleaved frames. they should never be passed to secondary routines
fr base.InterleavedFrame fr base.InterleavedFrame
} }
@@ -29,16 +29,36 @@ func NewConn(rw io.ReadWriter) *Conn {
} }
} }
// Read reads a Request, a Response or an Interleaved frame.
func (c *Conn) Read() (interface{}, error) {
byts, err := c.br.Peek(2)
if err != nil {
return nil, err
}
if byts[0] == base.InterleavedFrameMagicByte {
return c.ReadInterleavedFrame()
}
if byts[0] == 'R' && byts[1] == 'T' {
return c.ReadResponse()
}
return c.ReadRequest()
}
// ReadRequest reads a Request. // ReadRequest reads a Request.
func (c *Conn) ReadRequest() (*base.Request, error) { func (c *Conn) ReadRequest() (*base.Request, error) {
err := c.req.Unmarshal(c.br) var req base.Request
return &c.req, err err := req.Unmarshal(c.br)
return &req, err
} }
// ReadResponse reads a Response. // ReadResponse reads a Response.
func (c *Conn) ReadResponse() (*base.Response, error) { func (c *Conn) ReadResponse() (*base.Response, error) {
err := c.res.Unmarshal(c.br) var res base.Response
return &c.res, err err := res.Unmarshal(c.br)
return &res, err
} }
// ReadInterleavedFrame reads a InterleavedFrame. // ReadInterleavedFrame reads a InterleavedFrame.
@@ -47,64 +67,6 @@ func (c *Conn) ReadInterleavedFrame() (*base.InterleavedFrame, error) {
return &c.fr, err return &c.fr, err
} }
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Request.
func (c *Conn) ReadInterleavedFrameOrRequest() (interface{}, error) {
b, err := c.br.ReadByte()
if err != nil {
return nil, err
}
c.br.UnreadByte() //nolint:errcheck
if b == base.InterleavedFrameMagicByte {
return c.ReadInterleavedFrame()
}
return c.ReadRequest()
}
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func (c *Conn) ReadInterleavedFrameOrResponse() (interface{}, error) {
b, err := c.br.ReadByte()
if err != nil {
return nil, err
}
c.br.UnreadByte() //nolint:errcheck
if b == base.InterleavedFrameMagicByte {
return c.ReadInterleavedFrame()
}
return c.ReadResponse()
}
// ReadRequestIgnoreFrames reads a Request and ignores frames in between.
func (c *Conn) ReadRequestIgnoreFrames() (*base.Request, error) {
for {
recv, err := c.ReadInterleavedFrameOrRequest()
if err != nil {
return nil, err
}
if req, ok := recv.(*base.Request); ok {
return req, nil
}
}
}
// ReadResponseIgnoreFrames reads a Response and ignores frames in between.
func (c *Conn) ReadResponseIgnoreFrames() (*base.Response, error) {
for {
recv, err := c.ReadInterleavedFrameOrResponse()
if err != nil {
return nil, err
}
if res, ok := recv.(*base.Response); ok {
return res, nil
}
}
}
// WriteRequest writes a request. // WriteRequest writes a request.
func (c *Conn) WriteRequest(req *base.Request) error { func (c *Conn) WriteRequest(req *base.Request) error {
buf, _ := req.Marshal() buf, _ := req.Marshal()

View File

@@ -18,18 +18,19 @@ func mustParseURL(s string) *url.URL {
return u return u
} }
func TestReadInterleavedFrameOrRequest(t *testing.T) { func TestRead(t *testing.T) {
byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + for _, ca := range []struct {
name string
enc []byte
dec interface{}
}{
{
"request",
[]byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"Accept: application/sdp\r\n" + "Accept: application/sdp\r\n" +
"CSeq: 2\r\n" + "CSeq: 2\r\n" +
"\r\n") "\r\n"),
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) &base.Request{
conn := NewConn(bytes.NewBuffer(byts))
out, err := conn.ReadInterleavedFrameOrRequest()
require.NoError(t, err)
require.Equal(t, &base.Request{
Method: base.Describe, Method: base.Describe,
URL: &url.URL{ URL: &url.URL{
Scheme: "rtsp", Scheme: "rtsp",
@@ -40,143 +41,47 @@ func TestReadInterleavedFrameOrRequest(t *testing.T) {
"Accept": base.HeaderValue{"application/sdp"}, "Accept": base.HeaderValue{"application/sdp"},
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
}, },
}, out) },
out, err = conn.ReadInterleavedFrameOrRequest()
require.NoError(t, err)
require.Equal(t, &base.InterleavedFrame{
Channel: 6,
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}, out)
}
func TestReadInterleavedFrameOrRequestErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
}, },
{ {
"invalid frame", "response",
[]byte{0x24, 0x00}, []byte("RTSP/1.0 200 OK\r\n" +
"unexpected EOF",
},
{
"invalid request",
[]byte("DESCRIBE"),
"EOF",
},
} {
t.Run(ca.name, func(t *testing.T) {
conn := NewConn(bytes.NewBuffer(ca.byts))
_, err := conn.ReadInterleavedFrameOrRequest()
require.EqualError(t, err, ca.err)
})
}
}
func TestReadInterleavedFrameOrResponse(t *testing.T) {
byts := []byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 1\r\n" + "CSeq: 1\r\n" +
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
"\r\n") "\r\n"),
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) &base.Response{
conn := NewConn(bytes.NewBuffer(byts))
out, err := conn.ReadInterleavedFrameOrResponse()
require.NoError(t, err)
require.Equal(t, &base.Response{
StatusCode: 200, StatusCode: 200,
StatusMessage: "OK", StatusMessage: "OK",
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"1"}, "CSeq": base.HeaderValue{"1"},
"Public": base.HeaderValue{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"}, "Public": base.HeaderValue{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"},
}, },
}, out) },
},
out, err = conn.ReadInterleavedFrameOrResponse() {
require.NoError(t, err) "frame",
require.Equal(t, &base.InterleavedFrame{ []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4},
&base.InterleavedFrame{
Channel: 6, Channel: 6,
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{0x01, 0x02, 0x03, 0x04},
}, out)
}
func TestReadInterleavedFrameOrResponseErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
}, },
{
"invalid frame",
[]byte{0x24, 0x00},
"unexpected EOF",
},
{
"invalid response",
[]byte("RTSP/1.0"),
"EOF",
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
conn := NewConn(bytes.NewBuffer(ca.byts)) buf := bytes.NewBuffer(ca.enc)
_, err := conn.ReadInterleavedFrameOrResponse() conn := NewConn(buf)
require.EqualError(t, err, ca.err) dec, err := conn.Read()
require.NoError(t, err)
require.Equal(t, ca.dec, dec)
}) })
} }
} }
func TestReadRequestIgnoreFrames(t *testing.T) { func TestReadError(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} var buf bytes.Buffer
byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+ conn := NewConn(&buf)
"CSeq: 1\r\n"+ _, err := conn.Read()
"Proxy-Require: gzipped-messages\r\n"+ require.Error(t, err)
"Require: implicit-play\r\n"+
"\r\n")...)
conn := NewConn(bytes.NewBuffer(byts))
_, err := conn.ReadRequestIgnoreFrames()
require.NoError(t, err)
}
func TestReadRequestIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
conn := NewConn(bytes.NewBuffer(byts))
_, err := conn.ReadRequestIgnoreFrames()
require.EqualError(t, err, "EOF")
}
func TestReadResponseIgnoreFrames(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}
byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+
"CSeq: 1\r\n"+
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+
"\r\n")...)
conn := NewConn(bytes.NewBuffer(byts))
_, err := conn.ReadResponseIgnoreFrames()
require.NoError(t, err)
}
func TestReadResponseIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
conn := NewConn(bytes.NewBuffer(byts))
_, err := conn.ReadResponseIgnoreFrames()
require.EqualError(t, err, "EOF")
} }
func TestWriteRequest(t *testing.T) { func TestWriteRequest(t *testing.T) {

View File

@@ -196,3 +196,63 @@ type ErrClientRTPInfoInvalid struct {
func (e ErrClientRTPInfoInvalid) Error() string { func (e ErrClientRTPInfoInvalid) Error() string {
return fmt.Sprintf("invalid RTP-Info: %v", e.Err) return fmt.Sprintf("invalid RTP-Info: %v", e.Err)
} }
// ErrClientUnexpectedFrame is an error that can be returned by a client.
type ErrClientUnexpectedFrame struct{}
// Error implements the error interface.
func (e ErrClientUnexpectedFrame) Error() string {
return "received unexpected interleaved frame"
}
// ErrClientRequestTimedOut is an error that can be returned by a client.
type ErrClientRequestTimedOut struct{}
// Error implements the error interface.
func (e ErrClientRequestTimedOut) Error() string {
return "request timed out"
}
// ErrClientUnsupportedScheme is an error that can be returned by a client.
type ErrClientUnsupportedScheme struct {
Scheme string
}
// Error implements the error interface.
func (e ErrClientUnsupportedScheme) Error() string {
return fmt.Sprintf("unsupported scheme: %v", e.Scheme)
}
// ErrClientRTSPSTCP is an error that can be returned by a client.
type ErrClientRTSPSTCP struct{}
// Error implements the error interface.
func (e ErrClientRTSPSTCP) Error() string {
return "RTSPS can be used only with TCP"
}
// ErrClientUnexpectedResponse is an error that can be returned by a client.
type ErrClientUnexpectedResponse struct{}
// Error implements the error interface.
func (e ErrClientUnexpectedResponse) Error() string {
return "received unexpected response"
}
// ErrClientMissingCSeq is an error that can be returned by a client.
type ErrClientMissingCSeq struct{}
// Error implements the error interface.
func (e ErrClientMissingCSeq) Error() string {
return "CSeq is missing"
}
// ErrClientUnhandledMethod is an error that can be returned by a client.
type ErrClientUnhandledMethod struct {
Method base.Method
}
// Error implements the error interface.
func (e ErrClientUnhandledMethod) Error() string {
return fmt.Sprintf("unhandled method: %v", e.Method)
}

View File

@@ -251,3 +251,11 @@ type ErrServerUnexpectedFrame struct{}
func (e ErrServerUnexpectedFrame) Error() string { func (e ErrServerUnexpectedFrame) Error() string {
return "received unexpected interleaved frame" return "received unexpected interleaved frame"
} }
// ErrServerUnexpectedResponse is an error that can be returned by a client.
type ErrServerUnexpectedResponse struct{}
// Error implements the error interface.
func (e ErrServerUnexpectedResponse) Error() string {
return "received unexpected response"
}

View File

@@ -74,8 +74,8 @@ type ServerConn struct {
session *ServerSession session *ServerSession
// in // in
chHandleRequest chan readReq chReadRequest chan readReq
chReadErr chan error chReadError chan error
chRemoveSession chan *ServerSession chRemoveSession chan *ServerSession
// out // out
@@ -99,8 +99,8 @@ func newServerConn(
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
chHandleRequest: make(chan readReq), chReadRequest: make(chan readReq),
chReadErr: make(chan error), chReadError: make(chan error),
chRemoveSession: make(chan *ServerSession), chRemoveSession: make(chan *ServerSession),
done: make(chan struct{}), done: make(chan struct{}),
} }
@@ -187,10 +187,10 @@ func (sc *ServerConn) run() {
func (sc *ServerConn) runInner() error { func (sc *ServerConn) runInner() error {
for { for {
select { select {
case req := <-sc.chHandleRequest: case req := <-sc.chReadRequest:
req.res <- sc.handleRequestOuter(req.req) req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadErr: case err := <-sc.chReadError:
return err return err
case ss := <-sc.chRemoveSession: case ss := <-sc.chRemoveSession:
@@ -462,9 +462,9 @@ func (sc *ServerConn) removeSession(ss *ServerSession) {
} }
} }
func (sc *ServerConn) handleRequest(req readReq) error { func (sc *ServerConn) readRequest(req readReq) error {
select { select {
case sc.chHandleRequest <- req: case sc.chReadRequest <- req:
return <-req.res return <-req.res
case <-sc.ctx.Done(): case <-sc.ctx.Done():
@@ -472,9 +472,9 @@ func (sc *ServerConn) handleRequest(req readReq) error {
} }
} }
func (sc *ServerConn) readErr(err error) { func (sc *ServerConn) readError(err error) {
select { select {
case sc.chReadErr <- err: case sc.chReadError <- err:
case <-sc.ctx.Done(): case <-sc.ctx.Done():
} }
} }

View File

@@ -58,7 +58,7 @@ func (cr *serverConnReader) run() {
continue continue
} }
cr.sc.readErr(err) cr.sc.readError(err)
break break
} }
} }
@@ -68,7 +68,7 @@ func (cr *serverConnReader) readFuncStandard() error {
cr.sc.nconn.SetReadDeadline(time.Time{}) cr.sc.nconn.SetReadDeadline(time.Time{})
for { for {
what, err := cr.sc.conn.ReadInterleavedFrameOrRequest() what, err := cr.sc.conn.Read()
if err != nil { if err != nil {
return err return err
} }
@@ -77,12 +77,15 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.handleRequest(req) err := cr.sc.readRequest(req)
if err != nil { if err != nil {
return err return err
} }
default: case *base.Response:
return liberrors.ErrServerUnexpectedResponse{}
case *base.InterleavedFrame:
return liberrors.ErrServerUnexpectedFrame{} return liberrors.ErrServerUnexpectedFrame{}
} }
} }
@@ -99,26 +102,29 @@ func (cr *serverConnReader) readFuncTCP() error {
cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout)) cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout))
} }
what, err := cr.sc.conn.ReadInterleavedFrameOrRequest() what, err := cr.sc.conn.Read()
if err != nil { if err != nil {
return err return err
} }
switch twhat := what.(type) { switch what := what.(type) {
case *base.InterleavedFrame:
atomic.AddUint64(cr.sc.session.bytesReceived, uint64(len(twhat.Payload)))
if cb, ok := cr.sc.session.tcpCallbackByChannel[twhat.Channel]; ok {
cb(twhat.Payload)
}
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: twhat, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.handleRequest(req) err := cr.sc.readRequest(req)
if err != nil { if err != nil {
return err return err
} }
case *base.Response:
return liberrors.ErrServerUnexpectedResponse{}
case *base.InterleavedFrame:
atomic.AddUint64(cr.sc.session.bytesReceived, uint64(len(what.Payload)))
if cb, ok := cr.sc.session.tcpCallbackByChannel[what.Channel]; ok {
cb(what.Payload)
}
} }
} }
} }

View File

@@ -119,7 +119,7 @@ func doSetup(t *testing.T, conn *conn.Conn, u string,
return res, &th return res, &th
} }
func doPlay(t *testing.T, conn *conn.Conn, u string, session string) { func doPlay(t *testing.T, conn *conn.Conn, u string, session string) *base.Response {
res, err := writeReqReadRes(conn, base.Request{ res, err := writeReqReadRes(conn, base.Request{
Method: base.Play, Method: base.Play,
URL: mustParseURL(u), URL: mustParseURL(u),
@@ -130,6 +130,7 @@ func doPlay(t *testing.T, conn *conn.Conn, u string, session string) {
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
return res
} }
func doPause(t *testing.T, conn *conn.Conn, u string, session string) { func doPause(t *testing.T, conn *conn.Conn, u string, session string) {
@@ -1887,7 +1888,7 @@ func TestServerPlayAdditionalInfos(t *testing.T) {
ssrcs[1] = th.SSRC ssrcs[1] = th.SSRC
doPlay(t, conn, "rtsp://localhost:8554/teststream", session) res = doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
var ri headers.RTPInfo var ri headers.RTPInfo
err = ri.Unmarshal(res.Header["RTP-Info"]) err = ri.Unmarshal(res.Header["RTP-Info"])

View File

@@ -102,6 +102,35 @@ func findFirstSupportedTransportHeader(s *Server, tsh headers.Transports) *heade
return nil return nil
} }
func generateRTPInfo(
now time.Time,
setuppedMediasOrdered []*serverSessionMedia,
setuppedStream *ServerStream,
setuppedPath string,
u *url.URL,
) (headers.RTPInfo, bool) {
var ri headers.RTPInfo
for _, sm := range setuppedMediasOrdered {
entry := setuppedStream.rtpInfoEntry(sm.media, now)
if entry != nil {
entry.URL = (&url.URL{
Scheme: u.Scheme,
Host: u.Host,
Path: setuppedPath + "/trackID=" +
strconv.FormatInt(int64(setuppedStream.streamMedias[sm.media].trackID), 10),
}).String()
ri = append(ri, entry)
}
}
if len(ri) == 0 {
return nil, false
}
return ri, true
}
// ServerSessionState is a state of a ServerSession. // ServerSessionState is a state of a ServerSession.
type ServerSessionState int type ServerSessionState int
@@ -900,26 +929,18 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// writer.start() is called by ServerConn after the response has been sent // writer.start() is called by ServerConn after the response has been sent
} }
var ri headers.RTPInfo rtpInfo, ok := generateRTPInfo(
now := ss.s.timeNow() ss.s.timeNow(),
ss.setuppedMediasOrdered,
ss.setuppedStream,
*ss.setuppedPath,
req.URL)
for _, sm := range ss.setuppedMediasOrdered { if ok {
entry := ss.setuppedStream.rtpInfoEntry(sm.media, now)
if entry != nil {
entry.URL = (&url.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
Path: *ss.setuppedPath + "/trackID=" +
strconv.FormatInt(int64(ss.setuppedStream.streamMedias[sm.media].trackID), 10),
}).String()
ri = append(ri, entry)
}
}
if len(ri) > 0 {
if res.Header == nil { if res.Header == nil {
res.Header = make(base.Header) res.Header = make(base.Header)
} }
res.Header["RTP-Info"] = ri.Marshal() res.Header["RTP-Info"] = rtpInfo.Marshal()
} }
return res, err return res, err