diff --git a/client.go b/client.go index a131bcaf..1c4eecff 100644 --- a/client.go +++ b/client.go @@ -315,18 +315,19 @@ type Client struct { closeError error writer asyncProcessor reader *clientReader - connCloser *clientConnCloser timeDecoder *rtptime.GlobalDecoder // in - options chan optionsReq - describe chan describeReq - announce chan announceReq - setup chan setupReq - play chan playReq - record chan recordReq - pause chan pauseReq - readError chan error + chOptions chan optionsReq + chDescribe chan describeReq + chAnnounce chan announceReq + chSetup chan setupReq + chPlay chan playReq + chRecord chan recordReq + chPause chan pauseReq + chReadError chan error + chReadResponse chan *base.Response + chReadRequest chan *base.Request // out done chan struct{} @@ -425,14 +426,16 @@ func (c *Client) Start(scheme string, host string) error { c.ctxCancel = ctxCancel c.checkTimeoutTimer = emptyTimer() c.keepaliveTimer = emptyTimer() - c.options = make(chan optionsReq) - c.describe = make(chan describeReq) - c.announce = make(chan announceReq) - c.setup = make(chan setupReq) - c.play = make(chan playReq) - c.record = make(chan recordReq) - c.pause = make(chan pauseReq) - c.readError = make(chan error) + c.chOptions = make(chan optionsReq) + c.chDescribe = make(chan describeReq) + c.chAnnounce = make(chan announceReq) + c.chSetup = make(chan setupReq) + c.chPlay = make(chan playReq) + c.chRecord = make(chan recordReq) + c.chPause = make(chan pauseReq) + c.chReadError = make(chan error) + c.chReadResponse = make(chan *base.Response) + c.chReadRequest = make(chan *base.Request) c.done = make(chan struct{}) go c.run() @@ -499,76 +502,133 @@ func (c *Client) run() { func (c *Client) runInner() error { for { select { - case req := <-c.options: + case req := <-c.chOptions: res, err := c.doOptions(req.url) req.res <- clientRes{res: res, err: err} - case req := <-c.describe: + case req := <-c.chDescribe: sd, res, err := c.doDescribe(req.url) req.res <- clientRes{sd: sd, res: res, err: err} - case req := <-c.announce: + case req := <-c.chAnnounce: res, err := c.doAnnounce(req.url, req.desc) req.res <- clientRes{res: res, err: err} - case req := <-c.setup: + case req := <-c.chSetup: res, err := c.doSetup(req.baseURL, req.media, req.rtpPort, req.rtcpPort) req.res <- clientRes{res: res, err: err} - case req := <-c.play: + case req := <-c.chPlay: res, err := c.doPlay(req.ra) req.res <- clientRes{res: res, err: err} - case req := <-c.record: + case req := <-c.chRecord: res, err := c.doRecord() req.res <- clientRes{res: res, err: err} - case req := <-c.pause: + case req := <-c.chPause: res, err := c.doPause() req.res <- clientRes{res: res, err: err} case <-c.checkTimeoutTimer.C: - err := c.checkTimeout() + err := c.doCheckTimeout() if err != nil { return err } c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod) case <-c.keepaliveTimer.C: - err := c.doKeepalive() + err := c.doKeepAlive() if err != nil { return err } c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - case err := <-c.readError: + case err := <-c.chReadError: c.reader = nil return err + case <-c.chReadResponse: + return liberrors.ErrClientUnexpectedResponse{} + + case req := <-c.chReadRequest: + err := c.handleServerRequest(req) + if err != nil { + return err + } + case <-c.ctx.Done(): return liberrors.ErrClientTerminated{} } } } +func (c *Client) waitResponse() (*base.Response, error) { + for { + select { + case <-time.After(c.ReadTimeout): + return nil, liberrors.ErrClientRequestTimedOut{} + + case err := <-c.chReadError: + c.reader = nil + return nil, err + + case res := <-c.chReadResponse: + return res, nil + + case req := <-c.chReadRequest: + err := c.handleServerRequest(req) + if err != nil { + return nil, err + } + + case <-c.ctx.Done(): + return nil, liberrors.ErrClientTerminated{} + } + } +} + +func (c *Client) handleServerRequest(req *base.Request) error { + if req.Method != base.Options { + return liberrors.ErrClientUnhandledMethod{Method: req.Method} + } + + if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 { + return liberrors.ErrClientMissingCSeq{} + } + + res := &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "User-Agent": base.HeaderValue{c.UserAgent}, + "CSeq": req.Header["CSeq"], + }, + } + + c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + return c.conn.WriteResponse(res) +} + func (c *Client) doClose() { - if c.connCloser != nil { - c.connCloser.close() - c.connCloser = nil - } - if c.state == clientStatePlay || c.state == clientStateRecord { - c.playRecordStop(true) + c.stopWriter() + c.stopReadRoutines() } - if c.baseURL != nil { + if c.nconn != nil && c.baseURL != nil { c.do(&base.Request{ //nolint:errcheck Method: base.Teardown, URL: c.baseURL, - }, true, false) + }, true) } - if c.nconn != nil { + if c.reader != nil { + c.nconn.Close() + c.reader.wait() + c.reader = nil + c.nconn = nil + c.conn = nil + } else if c.nconn != nil { c.nconn.Close() c.nconn = nil c.conn = nil @@ -668,12 +728,8 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *url.URL return c.doSetup(baseURL, medi, 0, 0) } -func (c *Client) playRecordStart() { - c.connCloser.close() - c.connCloser = nil - - c.timeDecoder = rtptime.NewGlobalDecoder() - +func (c *Client) startReadRoutines() { + // allocate writer here because it's needed by RTCP receiver / sender if c.state == clientStatePlay { // when reading, buffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. @@ -683,7 +739,7 @@ func (c *Client) playRecordStart() { c.writer.allocateBuffer(c.WriteBufferCount) } - c.writer.start() + c.timeDecoder = rtptime.NewGlobalDecoder() for _, cm := range c.medias { cm.start() @@ -707,14 +763,14 @@ func (c *Client) playRecordStart() { } } - c.reader = newClientReader(c) + if *c.effectiveTransport == TransportTCP { + c.reader.setAllowInterleavedFrames(true) + } } -func (c *Client) playRecordStop(isClosing bool) { +func (c *Client) stopReadRoutines() { if c.reader != nil { - c.reader.close() - <-c.readError - c.reader = nil + c.reader.setAllowInterleavedFrames(false) } c.checkTimeoutTimer = emptyTimer() @@ -724,22 +780,28 @@ func (c *Client) playRecordStop(isClosing bool) { cm.stop() } - c.writer.stop() - c.timeDecoder = nil +} - if !isClosing { - c.connCloser = newClientConnCloser(c.ctx, c.nconn) - } +func (c *Client) startWriter() { + c.writer.start() +} + +func (c *Client) stopWriter() { + c.writer.stop() } func (c *Client) connOpen() error { + if c.nconn != nil { + return nil + } + if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" { - return fmt.Errorf("unsupported scheme '%s'", c.connURL.Scheme) + return liberrors.ErrClientUnsupportedScheme{Scheme: c.connURL.Scheme} } if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP { - return fmt.Errorf("RTSPS can be used only with TCP") + return liberrors.ErrClientRTSPSTCP{} } dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout) @@ -752,11 +814,9 @@ func (c *Client) connOpen() error { if c.connURL.Scheme == "rtsps" { tlsConfig := c.TLSConfig - if tlsConfig == nil { tlsConfig = &tls.Config{} } - tlsConfig.ServerName = c.connURL.Hostname() nconn = tls.Client(nconn, tlsConfig) @@ -765,19 +825,12 @@ func (c *Client) connOpen() error { c.nconn = nconn bc := bytecounter.New(c.nconn, c.BytesReceived, c.BytesSent) c.conn = conn.NewConn(bc) - c.connCloser = newClientConnCloser(c.ctx, c.nconn) + c.reader = newClientReader(c) return nil } -func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*base.Response, error) { - if c.nconn == nil { - err := c.connOpen() - if err != nil { - return nil, err - } - } - +func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error) { if !c.optionsSent && req.Method != base.Options { _, err := c.doOptions(req.URL) if err != nil { @@ -814,18 +867,9 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba return nil, nil } - c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) - var res *base.Response - if allowFrames { - // read the response and ignore interleaved frames in between; - // interleaved frames are sent in two cases: - // * when the server is v4lrtspserver, before the PLAY response - // * when the stream is already playing - res, err = c.conn.ReadResponseIgnoreFrames() - } else { - res, err = c.conn.ReadResponse() - } + res, err := c.waitResponse() if err != nil { + c.ctxCancel() return nil, err } @@ -856,7 +900,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba } c.sender = sender - return c.do(req, skipResponse, allowFrames) + return c.do(req, skipResponse) } return res, nil @@ -899,7 +943,7 @@ func (c *Client) isInTCPTimeout() bool { return now.Sub(lft) >= c.ReadTimeout } -func (c *Client) checkTimeout() error { +func (c *Client) doCheckTimeout() error { if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { if c.checkTimeoutInitial { @@ -921,7 +965,7 @@ func (c *Client) checkTimeout() error { return nil } -func (c *Client) doKeepalive() error { +func (c *Client) doKeepAlive() error { _, err := c.do(&base.Request{ Method: func() base.Method { // the VLC integrated rtsp server requires GET_PARAMETER @@ -932,7 +976,7 @@ func (c *Client) doKeepalive() error { }(), // use the stream base URL, otherwise some cameras do not reply URL: c.baseURL, - }, true, false) + }, false) return err } @@ -946,10 +990,15 @@ func (c *Client) doOptions(u *url.URL) (*base.Response, error) { return nil, err } + err = c.connOpen() + if err != nil { + return nil, err + } + res, err := c.do(&base.Request{ Method: base.Options, URL: u, - }, false, false) + }, false) if err != nil { return nil, err } @@ -973,12 +1022,12 @@ func (c *Client) doOptions(u *url.URL) (*base.Response, error) { func (c *Client) Options(u *url.URL) (*base.Response, error) { cres := make(chan clientRes) select { - case c.options <- optionsReq{url: u, res: cres}: + case c.chOptions <- optionsReq{url: u, res: cres}: res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -992,13 +1041,18 @@ func (c *Client) doDescribe(u *url.URL) (*description.Session, *base.Response, e return nil, nil, err } + err = c.connOpen() + if err != nil { + return nil, nil, err + } + res, err := c.do(&base.Request{ Method: base.Describe, URL: u, Header: base.Header{ "Accept": base.HeaderValue{"application/sdp"}, }, - }, false, false) + }, false) if err != nil { return nil, nil, err } @@ -1069,12 +1123,12 @@ func (c *Client) doDescribe(u *url.URL) (*description.Session, *base.Response, e func (c *Client) Describe(u *url.URL) (*description.Session, *base.Response, error) { cres := make(chan clientRes) select { - case c.describe <- describeReq{url: u, res: cres}: + case c.chDescribe <- describeReq{url: u, res: cres}: res := <-cres return res.sd, res.res, res.err - case <-c.ctx.Done(): - return nil, nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, nil, c.closeError } } @@ -1086,6 +1140,11 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon return nil, err } + err = c.connOpen() + if err != nil { + return nil, err + } + prepareForAnnounce(desc) byts, err := desc.Marshal(false) @@ -1100,7 +1159,7 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon "Content-Type": base.HeaderValue{"application/sdp"}, }, Body: byts, - }, false, false) + }, false) if err != nil { return nil, err } @@ -1121,12 +1180,12 @@ func (c *Client) doAnnounce(u *url.URL, desc *description.Session) (*base.Respon func (c *Client) Announce(u *url.URL, desc *description.Session) (*base.Response, error) { cres := make(chan clientRes) select { - case c.announce <- announceReq{url: u, desc: desc, res: cres}: + case c.chAnnounce <- announceReq{url: u, desc: desc, res: cres}: res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -1145,6 +1204,11 @@ func (c *Client) doSetup( return nil, err } + err = c.connOpen() + if err != nil { + return nil, err + } + if c.baseURL != nil && *baseURL != *c.baseURL { return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{} } @@ -1229,7 +1293,7 @@ func (c *Client) doSetup( Header: base.Header{ "Transport": th.Marshal(), }, - }, false, false) + }, false) if err != nil { cm.close() return nil, err @@ -1428,7 +1492,7 @@ func (c *Client) Setup( ) (*base.Response, error) { cres := make(chan clientRes) select { - case c.setup <- setupReq{ + case c.chSetup <- setupReq{ baseURL: baseURL, media: media, rtpPort: rtpPort, @@ -1438,8 +1502,8 @@ func (c *Client) Setup( res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -1462,19 +1526,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { return nil, err } - // open the firewall by sending empty packets to the counterpart. - // do this before sending the request. - // don't do this with multicast, otherwise the RTP packet is going to be broadcasted - // to all listeners, including us, messing up the stream. - if *c.effectiveTransport == TransportUDP { - for _, ct := range c.medias { - byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() - ct.udpRTPListener.write(byts) //nolint:errcheck - - byts, _ = (&rtcp.ReceiverReport{}).Marshal() - ct.udpRTCPListener.write(byts) //nolint:errcheck - } - } + c.state = clientStatePlay + c.startReadRoutines() // Range is mandatory in Parrot Streaming Server if ra == nil { @@ -1491,20 +1544,37 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { Header: base.Header{ "Range": ra.Marshal(), }, - }, false, *c.effectiveTransport == TransportTCP) + }, false) if err != nil { + c.stopReadRoutines() + c.state = clientStatePrePlay return nil, err } if res.StatusCode != base.StatusOK { + c.stopReadRoutines() + c.state = clientStatePrePlay return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } + // open the firewall by sending empty packets to the counterpart. + // do this before sending the request. + // don't do this with multicast, otherwise the RTP packet is going to be broadcasted + // to all listeners, including us, messing up the stream. + if *c.effectiveTransport == TransportUDP { + for _, cm := range c.medias { + byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() + cm.udpRTPListener.write(byts) //nolint:errcheck + + byts, _ = (&rtcp.ReceiverReport{}).Marshal() + cm.udpRTCPListener.write(byts) //nolint:errcheck + } + } + + c.startWriter() c.lastRange = ra - c.state = clientStatePlay - c.playRecordStart() return res, nil } @@ -1514,12 +1584,12 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { func (c *Client) Play(ra *headers.Range) (*base.Response, error) { cres := make(chan clientRes) select { - case c.play <- playReq{ra: ra, res: cres}: + case c.chPlay <- playReq{ra: ra, res: cres}: res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -1531,22 +1601,28 @@ func (c *Client) doRecord() (*base.Response, error) { return nil, err } + c.state = clientStateRecord + c.startReadRoutines() + res, err := c.do(&base.Request{ Method: base.Record, URL: c.baseURL, - }, false, false) + }, false) if err != nil { + c.stopReadRoutines() + c.state = clientStatePreRecord return nil, err } if res.StatusCode != base.StatusOK { + c.stopReadRoutines() + c.state = clientStatePreRecord return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } - c.state = clientStateRecord - c.playRecordStart() + c.startWriter() return nil, nil } @@ -1556,12 +1632,12 @@ func (c *Client) doRecord() (*base.Response, error) { func (c *Client) Record() (*base.Response, error) { cres := make(chan clientRes) select { - case c.record <- recordReq{res: cres}: + case c.chRecord <- recordReq{res: cres}: res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -1574,9 +1650,26 @@ func (c *Client) doPause() (*base.Response, error) { return nil, err } - c.playRecordStop(false) + c.stopWriter() + + res, err := c.do(&base.Request{ + Method: base.Pause, + URL: c.baseURL, + }, false) + if err != nil { + c.startWriter() + return nil, err + } + + if res.StatusCode != base.StatusOK { + c.startWriter() + return nil, liberrors.ErrClientBadStatusCode{ + Code: res.StatusCode, Message: res.StatusMessage, + } + } + + c.stopReadRoutines() - // change state regardless of the response switch c.state { case clientStatePlay: c.state = clientStatePrePlay @@ -1584,20 +1677,6 @@ func (c *Client) doPause() (*base.Response, error) { c.state = clientStatePreRecord } - res, err := c.do(&base.Request{ - Method: base.Pause, - URL: c.baseURL, - }, false, *c.effectiveTransport == TransportTCP) - if err != nil { - return nil, err - } - - if res.StatusCode != base.StatusOK { - return nil, liberrors.ErrClientBadStatusCode{ - Code: res.StatusCode, Message: res.StatusMessage, - } - } - return res, nil } @@ -1606,12 +1685,12 @@ func (c *Client) doPause() (*base.Response, error) { func (c *Client) Pause() (*base.Response, error) { cres := make(chan clientRes) select { - case c.pause <- pauseReq{res: cres}: + case c.chPause <- pauseReq{res: cres}: res := <-cres return res.res, res.err - case <-c.ctx.Done(): - return nil, liberrors.ErrClientTerminated{} + case <-c.done: + return nil, c.closeError } } @@ -1720,3 +1799,15 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, ct := cm.formats[pkt.PayloadType] return ct.rtcpReceiver.PacketNTP(pkt.Timestamp) } + +func (c *Client) readResponse(res *base.Response) { + c.chReadResponse <- res +} + +func (c *Client) readRequest(req *base.Request) { + c.chReadRequest <- req +} + +func (c *Client) readError(err error) { + c.chReadError <- err +} diff --git a/client_conn_closer.go b/client_conn_closer.go deleted file mode 100644 index 3623109b..00000000 --- a/client_conn_closer.go +++ /dev/null @@ -1,43 +0,0 @@ -package gortsplib - -import ( - "context" - "net" -) - -type clientConnCloser struct { - ctx context.Context - nconn net.Conn - - terminate chan struct{} - done chan struct{} -} - -func newClientConnCloser(ctx context.Context, nconn net.Conn) *clientConnCloser { - cc := &clientConnCloser{ - ctx: ctx, - nconn: nconn, - terminate: make(chan struct{}), - done: make(chan struct{}), - } - - go cc.run() - - return cc -} - -func (cc *clientConnCloser) close() { - close(cc.terminate) - <-cc.done -} - -func (cc *clientConnCloser) run() { - defer close(cc.done) - - select { - case <-cc.ctx.Done(): - cc.nconn.Close() - - case <-cc.terminate: - } -} diff --git a/client_reader.go b/client_reader.go index e11965d1..ee1aec6f 100644 --- a/client_reader.go +++ b/client_reader.go @@ -1,65 +1,69 @@ package gortsplib import ( - "time" + "sync/atomic" "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) type clientReader struct { - c *Client - closeErr chan error + c *Client + allowInterleavedFrames atomic.Bool } func newClientReader(c *Client) *clientReader { r := &clientReader{ - c: c, - closeErr: make(chan error), + c: c, } - // for some reason, SetReadDeadline() must always be called in the same - // goroutine, otherwise Read() freezes. - // therefore, we disable the deadline and perform a check with a ticker. - r.c.nconn.SetReadDeadline(time.Time{}) - go r.run() return r } -func (r *clientReader) close() { - r.c.nconn.SetReadDeadline(time.Now()) +func (r *clientReader) setAllowInterleavedFrames(v bool) { + r.allowInterleavedFrames.Store(v) +} + +func (r *clientReader) wait() { + for { + select { + case <-r.c.chReadError: + return + + case <-r.c.chReadResponse: + case <-r.c.chReadRequest: + } + } } func (r *clientReader) run() { - r.c.readError <- r.runInner() + err := r.runInner() + r.c.readError(err) } func (r *clientReader) runInner() error { - if *r.c.effectiveTransport == TransportUDP || *r.c.effectiveTransport == TransportUDPMulticast { - for { - res, err := r.c.conn.ReadResponse() - if err != nil { - return err - } - - r.c.OnResponse(res) + for { + what, err := r.c.conn.Read() + if err != nil { + return err } - } else { - for { - what, err := r.c.conn.ReadInterleavedFrameOrResponse() - if err != nil { - return err + + switch what := what.(type) { + case *base.Response: + r.c.readResponse(what) + + case *base.Request: + r.c.readRequest(what) + + case *base.InterleavedFrame: + if !r.allowInterleavedFrames.Load() { + return liberrors.ErrClientUnexpectedFrame{} } - switch what := what.(type) { - case *base.Response: - r.c.OnResponse(what) - - case *base.InterleavedFrame: - if cb, ok := r.c.tcpCallbackByChannel[what.Channel]; ok { - cb(what.Payload) - } + if cb, ok := r.c.tcpCallbackByChannel[what.Channel]; ok { + cb(what.Payload) } } } diff --git a/client_record_test.go b/client_record_test.go index 34015d8e..e3a04e85 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -3,6 +3,7 @@ package gortsplib import ( "bytes" "crypto/tls" + "fmt" "net" "strings" "sync" @@ -39,7 +40,7 @@ var testRTPPacket = rtp.Packet{ CSRC: []uint32{}, SSRC: 0x38F27A2F, }, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, + Payload: []byte{1, 2, 3, 4}, } var testRTPPacketMarshaled = mustMarshalPacketRTP(&testRTPPacket) @@ -101,6 +102,23 @@ func record(c *Client, ur string, medias []*description.Media, cb func(*descript return nil } +func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) { + for { + what, err := c.Read() + if err != nil { + return nil, err + } + + switch what := what.(type) { + case *base.InterleavedFrame: + case *base.Request: + return what, nil + case *base.Response: + return nil, fmt.Errorf("unexpected response") + } + } +} + func TestClientRecordSerial(t *testing.T) { for _, transport := range []string{ "udp", @@ -412,7 +430,7 @@ func TestClientRecordParallel(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequestIgnoreFrames() + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -552,7 +570,7 @@ func TestClientRecordPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequestIgnoreFrames() + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) @@ -570,7 +588,7 @@ func TestClientRecordPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequestIgnoreFrames() + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -700,7 +718,7 @@ func TestClientRecordPauseParallel(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequestIgnoreFrames() + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) diff --git a/client_test.go b/client_test.go index 24c98bea..e1733d71 100644 --- a/client_test.go +++ b/client_test.go @@ -375,3 +375,92 @@ func TestClientCloseDuringRequest(t *testing.T) { <-optionsDone close(releaseConn) } + +func TestClientReplyToServerRequest(t *testing.T) { + for _, ca := range []string{"after response", "before response"} { + t.Run(ca, func(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + defer close(serverDone) + + nconn, err := l.Accept() + require.NoError(t, err) + conn := conn.NewConn(nconn) + defer nconn.Close() + + req, err := conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + if ca == "after response" { + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + }, ", ")}, + }, + }) + require.NoError(t, err) + + err = conn.WriteRequest(&base.Request{ + Method: base.Options, + URL: nil, + Header: base.Header{ + "CSeq": base.HeaderValue{"4"}, + }, + }) + require.NoError(t, err) + + res, err := conn.ReadResponse() + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + require.Equal(t, "4", res.Header["CSeq"][0]) + } else { + err = conn.WriteRequest(&base.Request{ + Method: base.Options, + URL: nil, + Header: base.Header{ + "CSeq": base.HeaderValue{"4"}, + }, + }) + require.NoError(t, err) + + res, err := conn.ReadResponse() + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + require.Equal(t, "4", res.Header["CSeq"][0]) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + }, ", ")}, + }, + }) + require.NoError(t, err) + } + }() + + u, err := url.Parse("rtsp://localhost:8554/stream") + require.NoError(t, err) + + c := Client{} + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + _, err = c.Options(u) + require.NoError(t, err) + + <-serverDone + }) + } +} diff --git a/pkg/base/request.go b/pkg/base/request.go index d3104e7e..d859f19b 100644 --- a/pkg/base/request.go +++ b/pkg/base/request.go @@ -66,11 +66,15 @@ func (req *Request) Unmarshal(br *bufio.Reader) error { } rawURL := string(byts[:len(byts)-1]) - ur, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid URL (%v)", rawURL) + if rawURL != "*" { + ur, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL (%v)", rawURL) + } + req.URL = ur + } else { + req.URL = nil } - req.URL = ur byts, err = readBytesLimited(br, '\r', requestMaxProtocolLength) if err != nil { @@ -102,10 +106,15 @@ func (req *Request) Unmarshal(br *bufio.Reader) error { // MarshalSize returns the size of a Request. func (req Request) MarshalSize() int { - n := 0 + n := len(req.Method) + 1 - urStr := req.URL.CloneWithoutCredentials().String() - n += len([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n")) + if req.URL != nil { + n += len(req.URL.CloneWithoutCredentials().String()) + } else { + n++ + } + + n += 1 + len(rtspProtocol10) + 2 if len(req.Body) != 0 { req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)} @@ -122,8 +131,23 @@ func (req Request) MarshalSize() int { func (req Request) MarshalTo(buf []byte) (int, error) { pos := 0 - urStr := req.URL.CloneWithoutCredentials().String() - pos += copy(buf[pos:], []byte(string(req.Method)+" "+urStr+" "+rtspProtocol10+"\r\n")) + pos += copy(buf[pos:], []byte(req.Method)) + buf[pos] = ' ' + pos++ + + if req.URL != nil { + pos += copy(buf[pos:], []byte(req.URL.CloneWithoutCredentials().String())) + } else { + pos += copy(buf[pos:], []byte("*")) + } + + buf[pos] = ' ' + pos++ + pos += copy(buf[pos:], rtspProtocol10) + buf[pos] = '\r' + pos++ + buf[pos] = '\n' + pos++ if len(req.Body) != 0 { req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)} diff --git a/pkg/base/request_test.go b/pkg/base/request_test.go index c09661c1..d39cdf11 100644 --- a/pkg/base/request_test.go +++ b/pkg/base/request_test.go @@ -138,6 +138,21 @@ var casesRequest = []struct { ), }, }, + { + "server-side announce", + []byte("OPTIONS * RTSP/1.0\r\n" + + "CSeq: 1\r\n" + + "User-Agent: RDIPCamera\r\n" + + "\r\n"), + Request{ + Method: "OPTIONS", + URL: nil, + Header: Header{ + "CSeq": HeaderValue{"1"}, + "User-Agent": HeaderValue{"RDIPCamera"}, + }, + }, + }, } func TestRequestUnmarshal(t *testing.T) { diff --git a/pkg/base/response.go b/pkg/base/response.go index 20336495..8f6b4ced 100644 --- a/pkg/base/response.go +++ b/pkg/base/response.go @@ -194,9 +194,7 @@ func (res Response) MarshalSize() int { } } - n += len([]byte(rtspProtocol10 + " " + - strconv.FormatInt(int64(res.StatusCode), 10) + " " + - res.StatusMessage + "\r\n")) + n += len(rtspProtocol10) + 1 + len(strconv.FormatInt(int64(res.StatusCode), 10)) + 1 + len(res.StatusMessage) + 2 if len(res.Body) != 0 { res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)} @@ -219,9 +217,17 @@ func (res Response) MarshalTo(buf []byte) (int, error) { pos := 0 - pos += copy(buf[pos:], []byte(rtspProtocol10+" "+ - strconv.FormatInt(int64(res.StatusCode), 10)+" "+ - res.StatusMessage+"\r\n")) + pos += copy(buf[pos:], []byte(rtspProtocol10)) + buf[pos] = ' ' + pos++ + pos += copy(buf[pos:], []byte(strconv.FormatInt(int64(res.StatusCode), 10))) + buf[pos] = ' ' + pos++ + pos += copy(buf[pos:], []byte(res.StatusMessage)) + buf[pos] = '\r' + pos++ + buf[pos] = '\n' + pos++ if len(res.Body) != 0 { res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)} diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go index 66a63206..59d63978 100644 --- a/pkg/conn/conn.go +++ b/pkg/conn/conn.go @@ -14,11 +14,11 @@ const ( // Conn is a RTSP connection. type Conn struct { - w io.Writer - br *bufio.Reader - req base.Request - res base.Response - fr base.InterleavedFrame + w io.Writer + br *bufio.Reader + + // reuse interleaved frames. they should never be passed to secondary routines + fr base.InterleavedFrame } // NewConn allocates a Conn. @@ -29,16 +29,36 @@ func NewConn(rw io.ReadWriter) *Conn { } } +// Read reads a Request, a Response or an Interleaved frame. +func (c *Conn) Read() (interface{}, error) { + byts, err := c.br.Peek(2) + if err != nil { + return nil, err + } + + if byts[0] == base.InterleavedFrameMagicByte { + return c.ReadInterleavedFrame() + } + + if byts[0] == 'R' && byts[1] == 'T' { + return c.ReadResponse() + } + + return c.ReadRequest() +} + // ReadRequest reads a Request. func (c *Conn) ReadRequest() (*base.Request, error) { - err := c.req.Unmarshal(c.br) - return &c.req, err + var req base.Request + err := req.Unmarshal(c.br) + return &req, err } // ReadResponse reads a Response. func (c *Conn) ReadResponse() (*base.Response, error) { - err := c.res.Unmarshal(c.br) - return &c.res, err + var res base.Response + err := res.Unmarshal(c.br) + return &res, err } // ReadInterleavedFrame reads a InterleavedFrame. @@ -47,64 +67,6 @@ func (c *Conn) ReadInterleavedFrame() (*base.InterleavedFrame, error) { return &c.fr, err } -// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Request. -func (c *Conn) ReadInterleavedFrameOrRequest() (interface{}, error) { - b, err := c.br.ReadByte() - if err != nil { - return nil, err - } - c.br.UnreadByte() //nolint:errcheck - - if b == base.InterleavedFrameMagicByte { - return c.ReadInterleavedFrame() - } - - return c.ReadRequest() -} - -// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. -func (c *Conn) ReadInterleavedFrameOrResponse() (interface{}, error) { - b, err := c.br.ReadByte() - if err != nil { - return nil, err - } - c.br.UnreadByte() //nolint:errcheck - - if b == base.InterleavedFrameMagicByte { - return c.ReadInterleavedFrame() - } - - return c.ReadResponse() -} - -// ReadRequestIgnoreFrames reads a Request and ignores frames in between. -func (c *Conn) ReadRequestIgnoreFrames() (*base.Request, error) { - for { - recv, err := c.ReadInterleavedFrameOrRequest() - if err != nil { - return nil, err - } - - if req, ok := recv.(*base.Request); ok { - return req, nil - } - } -} - -// ReadResponseIgnoreFrames reads a Response and ignores frames in between. -func (c *Conn) ReadResponseIgnoreFrames() (*base.Response, error) { - for { - recv, err := c.ReadInterleavedFrameOrResponse() - if err != nil { - return nil, err - } - - if res, ok := recv.(*base.Response); ok { - return res, nil - } - } -} - // WriteRequest writes a request. func (c *Conn) WriteRequest(req *base.Request) error { buf, _ := req.Marshal() diff --git a/pkg/conn/conn_test.go b/pkg/conn/conn_test.go index 37be5f9e..c9b499cb 100644 --- a/pkg/conn/conn_test.go +++ b/pkg/conn/conn_test.go @@ -18,165 +18,70 @@ func mustParseURL(s string) *url.URL { return u } -func TestReadInterleavedFrameOrRequest(t *testing.T) { - byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + - "Accept: application/sdp\r\n" + - "CSeq: 2\r\n" + - "\r\n") - byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - - conn := NewConn(bytes.NewBuffer(byts)) - - out, err := conn.ReadInterleavedFrameOrRequest() - require.NoError(t, err) - require.Equal(t, &base.Request{ - Method: base.Describe, - URL: &url.URL{ - Scheme: "rtsp", - Host: "example.com", - Path: "/media.mp4", - }, - Header: base.Header{ - "Accept": base.HeaderValue{"application/sdp"}, - "CSeq": base.HeaderValue{"2"}, - }, - }, out) - - out, err = conn.ReadInterleavedFrameOrRequest() - require.NoError(t, err) - require.Equal(t, &base.InterleavedFrame{ - Channel: 6, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }, out) -} - -func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { +func TestRead(t *testing.T) { for _, ca := range []struct { name string - byts []byte - err string + enc []byte + dec interface{} }{ { - "empty", - []byte{}, - "EOF", + "request", + []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + + "Accept: application/sdp\r\n" + + "CSeq: 2\r\n" + + "\r\n"), + &base.Request{ + Method: base.Describe, + URL: &url.URL{ + Scheme: "rtsp", + Host: "example.com", + Path: "/media.mp4", + }, + Header: base.Header{ + "Accept": base.HeaderValue{"application/sdp"}, + "CSeq": base.HeaderValue{"2"}, + }, + }, }, { - "invalid frame", - []byte{0x24, 0x00}, - "unexpected EOF", + "response", + []byte("RTSP/1.0 200 OK\r\n" + + "CSeq: 1\r\n" + + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + + "\r\n"), + &base.Response{ + StatusCode: 200, + StatusMessage: "OK", + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Public": base.HeaderValue{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"}, + }, + }, }, { - "invalid request", - []byte("DESCRIBE"), - "EOF", + "frame", + []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}, + &base.InterleavedFrame{ + Channel: 6, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, }, } { t.Run(ca.name, func(t *testing.T) { - conn := NewConn(bytes.NewBuffer(ca.byts)) - _, err := conn.ReadInterleavedFrameOrRequest() - require.EqualError(t, err, ca.err) + buf := bytes.NewBuffer(ca.enc) + conn := NewConn(buf) + dec, err := conn.Read() + require.NoError(t, err) + require.Equal(t, ca.dec, dec) }) } } -func TestReadInterleavedFrameOrResponse(t *testing.T) { - byts := []byte("RTSP/1.0 200 OK\r\n" + - "CSeq: 1\r\n" + - "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + - "\r\n") - byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - - conn := NewConn(bytes.NewBuffer(byts)) - - out, err := conn.ReadInterleavedFrameOrResponse() - require.NoError(t, err) - require.Equal(t, &base.Response{ - StatusCode: 200, - StatusMessage: "OK", - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Public": base.HeaderValue{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"}, - }, - }, out) - - out, err = conn.ReadInterleavedFrameOrResponse() - require.NoError(t, err) - require.Equal(t, &base.InterleavedFrame{ - Channel: 6, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }, out) -} - -func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { - for _, ca := range []struct { - name string - byts []byte - err string - }{ - { - "empty", - []byte{}, - "EOF", - }, - { - "invalid frame", - []byte{0x24, 0x00}, - "unexpected EOF", - }, - { - "invalid response", - []byte("RTSP/1.0"), - "EOF", - }, - } { - t.Run(ca.name, func(t *testing.T) { - conn := NewConn(bytes.NewBuffer(ca.byts)) - _, err := conn.ReadInterleavedFrameOrResponse() - require.EqualError(t, err, ca.err) - }) - } -} - -func TestReadRequestIgnoreFrames(t *testing.T) { - byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} - byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+ - "CSeq: 1\r\n"+ - "Proxy-Require: gzipped-messages\r\n"+ - "Require: implicit-play\r\n"+ - "\r\n")...) - - conn := NewConn(bytes.NewBuffer(byts)) - _, err := conn.ReadRequestIgnoreFrames() - require.NoError(t, err) -} - -func TestReadRequestIgnoreFramesErrors(t *testing.T) { - byts := []byte{0x25} - - conn := NewConn(bytes.NewBuffer(byts)) - _, err := conn.ReadRequestIgnoreFrames() - require.EqualError(t, err, "EOF") -} - -func TestReadResponseIgnoreFrames(t *testing.T) { - byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} - byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+ - "CSeq: 1\r\n"+ - "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+ - "\r\n")...) - - conn := NewConn(bytes.NewBuffer(byts)) - _, err := conn.ReadResponseIgnoreFrames() - require.NoError(t, err) -} - -func TestReadResponseIgnoreFramesErrors(t *testing.T) { - byts := []byte{0x25} - - conn := NewConn(bytes.NewBuffer(byts)) - _, err := conn.ReadResponseIgnoreFrames() - require.EqualError(t, err, "EOF") +func TestReadError(t *testing.T) { + var buf bytes.Buffer + conn := NewConn(&buf) + _, err := conn.Read() + require.Error(t, err) } func TestWriteRequest(t *testing.T) { diff --git a/pkg/liberrors/client.go b/pkg/liberrors/client.go index 39aa689c..90d94dbd 100644 --- a/pkg/liberrors/client.go +++ b/pkg/liberrors/client.go @@ -196,3 +196,63 @@ type ErrClientRTPInfoInvalid struct { func (e ErrClientRTPInfoInvalid) Error() string { return fmt.Sprintf("invalid RTP-Info: %v", e.Err) } + +// ErrClientUnexpectedFrame is an error that can be returned by a client. +type ErrClientUnexpectedFrame struct{} + +// Error implements the error interface. +func (e ErrClientUnexpectedFrame) Error() string { + return "received unexpected interleaved frame" +} + +// ErrClientRequestTimedOut is an error that can be returned by a client. +type ErrClientRequestTimedOut struct{} + +// Error implements the error interface. +func (e ErrClientRequestTimedOut) Error() string { + return "request timed out" +} + +// ErrClientUnsupportedScheme is an error that can be returned by a client. +type ErrClientUnsupportedScheme struct { + Scheme string +} + +// Error implements the error interface. +func (e ErrClientUnsupportedScheme) Error() string { + return fmt.Sprintf("unsupported scheme: %v", e.Scheme) +} + +// ErrClientRTSPSTCP is an error that can be returned by a client. +type ErrClientRTSPSTCP struct{} + +// Error implements the error interface. +func (e ErrClientRTSPSTCP) Error() string { + return "RTSPS can be used only with TCP" +} + +// ErrClientUnexpectedResponse is an error that can be returned by a client. +type ErrClientUnexpectedResponse struct{} + +// Error implements the error interface. +func (e ErrClientUnexpectedResponse) Error() string { + return "received unexpected response" +} + +// ErrClientMissingCSeq is an error that can be returned by a client. +type ErrClientMissingCSeq struct{} + +// Error implements the error interface. +func (e ErrClientMissingCSeq) Error() string { + return "CSeq is missing" +} + +// ErrClientUnhandledMethod is an error that can be returned by a client. +type ErrClientUnhandledMethod struct { + Method base.Method +} + +// Error implements the error interface. +func (e ErrClientUnhandledMethod) Error() string { + return fmt.Sprintf("unhandled method: %v", e.Method) +} diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index f56b3fb6..274cc371 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -251,3 +251,11 @@ type ErrServerUnexpectedFrame struct{} func (e ErrServerUnexpectedFrame) Error() string { return "received unexpected interleaved frame" } + +// ErrServerUnexpectedResponse is an error that can be returned by a client. +type ErrServerUnexpectedResponse struct{} + +// Error implements the error interface. +func (e ErrServerUnexpectedResponse) Error() string { + return "received unexpected response" +} diff --git a/server_conn.go b/server_conn.go index 883bf25b..eac113bf 100644 --- a/server_conn.go +++ b/server_conn.go @@ -74,8 +74,8 @@ type ServerConn struct { session *ServerSession // in - chHandleRequest chan readReq - chReadErr chan error + chReadRequest chan readReq + chReadError chan error chRemoveSession chan *ServerSession // out @@ -99,8 +99,8 @@ func newServerConn( ctx: ctx, ctxCancel: ctxCancel, remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), - chHandleRequest: make(chan readReq), - chReadErr: make(chan error), + chReadRequest: make(chan readReq), + chReadError: make(chan error), chRemoveSession: make(chan *ServerSession), done: make(chan struct{}), } @@ -187,10 +187,10 @@ func (sc *ServerConn) run() { func (sc *ServerConn) runInner() error { for { select { - case req := <-sc.chHandleRequest: + case req := <-sc.chReadRequest: req.res <- sc.handleRequestOuter(req.req) - case err := <-sc.chReadErr: + case err := <-sc.chReadError: return err case ss := <-sc.chRemoveSession: @@ -462,9 +462,9 @@ func (sc *ServerConn) removeSession(ss *ServerSession) { } } -func (sc *ServerConn) handleRequest(req readReq) error { +func (sc *ServerConn) readRequest(req readReq) error { select { - case sc.chHandleRequest <- req: + case sc.chReadRequest <- req: return <-req.res case <-sc.ctx.Done(): @@ -472,9 +472,9 @@ func (sc *ServerConn) handleRequest(req readReq) error { } } -func (sc *ServerConn) readErr(err error) { +func (sc *ServerConn) readError(err error) { select { - case sc.chReadErr <- err: + case sc.chReadError <- err: case <-sc.ctx.Done(): } } diff --git a/server_conn_reader.go b/server_conn_reader.go index aedf5a60..43d6660b 100644 --- a/server_conn_reader.go +++ b/server_conn_reader.go @@ -58,7 +58,7 @@ func (cr *serverConnReader) run() { continue } - cr.sc.readErr(err) + cr.sc.readError(err) break } } @@ -68,7 +68,7 @@ func (cr *serverConnReader) readFuncStandard() error { cr.sc.nconn.SetReadDeadline(time.Time{}) for { - what, err := cr.sc.conn.ReadInterleavedFrameOrRequest() + what, err := cr.sc.conn.Read() if err != nil { return err } @@ -77,12 +77,15 @@ func (cr *serverConnReader) readFuncStandard() error { case *base.Request: cres := make(chan error) req := readReq{req: what, res: cres} - err := cr.sc.handleRequest(req) + err := cr.sc.readRequest(req) if err != nil { return err } - default: + case *base.Response: + return liberrors.ErrServerUnexpectedResponse{} + + case *base.InterleavedFrame: return liberrors.ErrServerUnexpectedFrame{} } } @@ -99,26 +102,29 @@ func (cr *serverConnReader) readFuncTCP() error { cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout)) } - what, err := cr.sc.conn.ReadInterleavedFrameOrRequest() + what, err := cr.sc.conn.Read() if err != nil { return err } - switch twhat := what.(type) { - case *base.InterleavedFrame: - atomic.AddUint64(cr.sc.session.bytesReceived, uint64(len(twhat.Payload))) - - if cb, ok := cr.sc.session.tcpCallbackByChannel[twhat.Channel]; ok { - cb(twhat.Payload) - } - + switch what := what.(type) { case *base.Request: cres := make(chan error) - req := readReq{req: twhat, res: cres} - err := cr.sc.handleRequest(req) + req := readReq{req: what, res: cres} + err := cr.sc.readRequest(req) if err != nil { return err } + + case *base.Response: + return liberrors.ErrServerUnexpectedResponse{} + + case *base.InterleavedFrame: + atomic.AddUint64(cr.sc.session.bytesReceived, uint64(len(what.Payload))) + + if cb, ok := cr.sc.session.tcpCallbackByChannel[what.Channel]; ok { + cb(what.Payload) + } } } } diff --git a/server_play_test.go b/server_play_test.go index fdcc0589..7d41d707 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -119,7 +119,7 @@ func doSetup(t *testing.T, conn *conn.Conn, u string, return res, &th } -func doPlay(t *testing.T, conn *conn.Conn, u string, session string) { +func doPlay(t *testing.T, conn *conn.Conn, u string, session string) *base.Response { res, err := writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL(u), @@ -130,6 +130,7 @@ func doPlay(t *testing.T, conn *conn.Conn, u string, session string) { }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + return res } func doPause(t *testing.T, conn *conn.Conn, u string, session string) { @@ -1887,7 +1888,7 @@ func TestServerPlayAdditionalInfos(t *testing.T) { ssrcs[1] = th.SSRC - doPlay(t, conn, "rtsp://localhost:8554/teststream", session) + res = doPlay(t, conn, "rtsp://localhost:8554/teststream", session) var ri headers.RTPInfo err = ri.Unmarshal(res.Header["RTP-Info"]) diff --git a/server_session.go b/server_session.go index a2efc8e2..f39f5f80 100644 --- a/server_session.go +++ b/server_session.go @@ -102,6 +102,35 @@ func findFirstSupportedTransportHeader(s *Server, tsh headers.Transports) *heade return nil } +func generateRTPInfo( + now time.Time, + setuppedMediasOrdered []*serverSessionMedia, + setuppedStream *ServerStream, + setuppedPath string, + u *url.URL, +) (headers.RTPInfo, bool) { + var ri headers.RTPInfo + + for _, sm := range setuppedMediasOrdered { + entry := setuppedStream.rtpInfoEntry(sm.media, now) + if entry != nil { + entry.URL = (&url.URL{ + Scheme: u.Scheme, + Host: u.Host, + Path: setuppedPath + "/trackID=" + + strconv.FormatInt(int64(setuppedStream.streamMedias[sm.media].trackID), 10), + }).String() + ri = append(ri, entry) + } + } + + if len(ri) == 0 { + return nil, false + } + + return ri, true +} + // ServerSessionState is a state of a ServerSession. type ServerSessionState int @@ -900,26 +929,18 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( // writer.start() is called by ServerConn after the response has been sent } - var ri headers.RTPInfo - now := ss.s.timeNow() + rtpInfo, ok := generateRTPInfo( + ss.s.timeNow(), + ss.setuppedMediasOrdered, + ss.setuppedStream, + *ss.setuppedPath, + req.URL) - for _, sm := range ss.setuppedMediasOrdered { - entry := ss.setuppedStream.rtpInfoEntry(sm.media, now) - if entry != nil { - entry.URL = (&url.URL{ - Scheme: req.URL.Scheme, - Host: req.URL.Host, - Path: *ss.setuppedPath + "/trackID=" + - strconv.FormatInt(int64(ss.setuppedStream.streamMedias[sm.media].trackID), 10), - }).String() - ri = append(ri, entry) - } - } - if len(ri) > 0 { + if ok { if res.Header == nil { res.Header = make(base.Header) } - res.Header["RTP-Info"] = ri.Marshal() + res.Header["RTP-Info"] = rtpInfo.Marshal() } return res, err