diff --git a/client.go b/client.go index 0cc8f50a..8f0339ee 100644 --- a/client.go +++ b/client.go @@ -132,10 +132,16 @@ type Client struct { // // callbacks // - // callback called before every request. + // called before every request. OnRequest func(*base.Request) - // callback called after every response. + // called after every response. OnResponse func(*base.Response) + // called before sending a PLAY request. + OnPlay func(*Client) + // called when a RTP packet arrives. + OnPacketRTP func(*Client, int, []byte) + // called when a RTCP packet arrives. + OnPacketRTCP func(*Client, int, []byte) // // RTSP parameters @@ -210,12 +216,10 @@ type Client struct { lastRange *headers.Range backgroundRunning bool backgroundErr error - tcpFrameBuffer *multibuffer.MultiBuffer // tcp - tcpWriteMutex sync.Mutex // tcp - readCBMutex sync.RWMutex // read - readCB func(int, StreamType, []byte) // read - writeMutex sync.RWMutex // write - writeFrameAllowed bool // write + tcpFrameBuffer *multibuffer.MultiBuffer // tcp + tcpWriteMutex sync.Mutex // tcp + writeMutex sync.RWMutex // write + writeFrameAllowed bool // write // in options chan optionsReq @@ -230,12 +234,21 @@ type Client struct { // out backgroundInnerDone chan error backgroundDone chan struct{} - readCBSet chan struct{} done chan struct{} } // Dial connects to a server. func (c *Client) Dial(scheme string, host string) error { + // callbacks + if c.OnPacketRTP == nil { + c.OnPacketRTP = func(c *Client, trackID int, payload []byte) { + } + } + if c.OnPacketRTCP == nil { + c.OnPacketRTCP = func(c *Client, trackID int, payload []byte) { + } + } + // RTSP parameters if c.ReadTimeout == 0 { c.ReadTimeout = 10 * time.Second @@ -414,14 +427,14 @@ func (c *Client) DialPublishContext(ctx context.Context, address string, tracks return nil } -// Close closes the connection and waits for all its resources to exit. +// Close closes all the client resources and waits for them to exit. func (c *Client) Close() error { c.ctxCancel() <-c.done return nil } -// Tracks returns all the tracks that the connection is reading or publishing. +// Tracks returns all the tracks that the client is reading or publishing. func (c *Client) Tracks() Tracks { ids := make([]int, len(c.tracks)) pos := 0 @@ -534,10 +547,6 @@ func (c *Client) reset(isSwitchingProtocol bool) { c.tracks = nil c.tracksByChannel = nil c.tcpFrameBuffer = nil - - if !isSwitchingProtocol { - c.readCB = nil - } } func (c *Client) checkState(allowed map[clientState]struct{}) error { @@ -590,12 +599,6 @@ func (c *Client) switchProtocolIfTimeout(err error) error { return nil } -func (c *Client) pullReadCB() func(int, StreamType, []byte) { - c.readCBMutex.RLock() - defer c.readCBMutex.RUnlock() - return c.readCB -} - func (c *Client) backgroundStart(isSwitchingProtocol bool) { c.writeMutex.Lock() c.writeFrameAllowed = true @@ -791,10 +794,10 @@ func (c *Client) runBackgroundPlayTCP() error { } channel := frame.Channel - streamType := StreamTypeRTP + isRTP := true if (channel % 2) != 0 { channel-- - streamType = StreamTypeRTCP + isRTP = false } trackID, ok := c.tracksByChannel[channel] @@ -805,13 +808,13 @@ func (c *Client) runBackgroundPlayTCP() error { now := time.Now() atomic.StoreInt64(&lastFrameTime, now.Unix()) - if streamType == StreamTypeRTP { + if isRTP { c.tracks[trackID].rtcpReceiver.ProcessPacketRTP(now, frame.Payload) + c.OnPacketRTP(c, trackID, frame.Payload) } else { c.tracks[trackID].rtcpReceiver.ProcessPacketRTCP(now, frame.Payload) + c.OnPacketRTCP(c, trackID, frame.Payload) } - - c.pullReadCB()(trackID, streamType, frame.Payload) } }() @@ -923,10 +926,10 @@ func (c *Client) runBackgroundRecordTCP() error { } channel := frame.Channel - streamType := StreamTypeRTP + isRTP := true if (channel % 2) != 0 { channel-- - streamType = StreamTypeRTCP + isRTP = false } trackID, ok := c.tracksByChannel[channel] @@ -934,7 +937,9 @@ func (c *Client) runBackgroundRecordTCP() error { continue } - c.pullReadCB()(trackID, streamType, frame.Payload) + if !isRTP { + c.OnPacketRTCP(c, trackID, frame.Payload) + } } }() @@ -1677,6 +1682,10 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp } } + if c.OnPlay != nil { + c.OnPlay(c) + } + header := make(base.Header) // Range is mandatory in Parrot Streaming Server @@ -1707,21 +1716,6 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp c.state = clientStatePlay c.lastRange = ra - if !isSwitchingProtocol { - // use a temporary callback that is replaces as soon as - // the user calls ReadFrames() - c.readCBSet = make(chan struct{}) - copy := c.readCBSet - c.readCB = func(trackID int, streamType StreamType, payload []byte) { - select { - case <-copy: - case <-c.ctx.Done(): - return - } - c.pullReadCB()(trackID, streamType, payload) - } - } - c.backgroundStart(isSwitchingProtocol) return res, nil @@ -1765,11 +1759,6 @@ func (c *Client) doRecord() (*base.Response, error) { c.state = clientStateRecord - // when publishing, calling ReadFrames() is not mandatory - // use an empty callback - c.readCB = func(trackID int, streamType StreamType, payload []byte) { - } - c.backgroundStart(false) return nil, nil @@ -1849,17 +1838,7 @@ func (c *Client) Seek(ra *headers.Range) (*base.Response, error) { } // ReadFrames starts reading frames. -func (c *Client) ReadFrames(onFrame func(int, StreamType, []byte)) error { - c.readCBMutex.Lock() - c.readCB = onFrame - c.readCBMutex.Unlock() - - // replace temporary callback with final callback - if c.readCBSet != nil { - close(c.readCBSet) - c.readCBSet = nil - } - +func (c *Client) ReadFrames() error { <-c.backgroundDone return c.backgroundErr } diff --git a/client_publish_test.go b/client_publish_test.go index 7a48de77..1ab7696a 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -158,6 +158,8 @@ func TestClientPublishSerial(t *testing.T) { require.NoError(t, err) }() + recvDone := make(chan struct{}) + c := &Client{ Transport: func() *Transport { if transport == "udp" { @@ -167,6 +169,11 @@ func TestClientPublishSerial(t *testing.T) { v := TransportTCP return &v }(), + OnPacketRTCP: func(c *Client, trackID int, payload []byte) { + require.Equal(t, 0, trackID) + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload) + close(recvDone) + }, } track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) @@ -176,16 +183,10 @@ func TestClientPublishSerial(t *testing.T) { Tracks{track}) require.NoError(t, err) - recvDone := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { - require.Equal(t, 0, trackID) - require.Equal(t, StreamTypeRTCP, streamType) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload) - close(recvDone) - }) + c.ReadFrames() }() err = c.WritePacketRTP(0, diff --git a/client_read_test.go b/client_read_test.go index a48a3e09..88880d29 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -394,6 +394,8 @@ func TestClientRead(t *testing.T) { require.NoError(t, err) }() + counter := uint64(0) + c := &Client{ Transport: func() *Transport { switch transport { @@ -410,17 +412,8 @@ func TestClientRead(t *testing.T) { return &v } }(), - } - - err = c.DialRead(scheme + "://" + listenIP + ":8554/test/stream?param=value") - require.NoError(t, err) - - done := make(chan struct{}) - counter := uint64(0) - go func() { - defer close(done) - c.ReadFrames(func(id int, streamType StreamType, payload []byte) { - // skip multicast loopback + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + // ignore multicast loopback if transport == "multicast" { add := atomic.AddUint64(&counter, 1) if add >= 2 { @@ -428,21 +421,29 @@ func TestClientRead(t *testing.T) { } } - require.Equal(t, 0, id) - require.Equal(t, StreamTypeRTP, streamType) + require.Equal(t, 0, trackID) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) require.NoError(t, err) - }) + }, + } + + err = c.DialRead(scheme + "://" + listenIP + ":8554/test/stream?param=value") + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + defer close(done) + c.ReadFrames() }() <-frameRecv c.Close() <-done - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - }) + c.ReadFrames() }) } } @@ -665,11 +666,18 @@ func TestClientReadPartial(t *testing.T) { require.NoError(t, err) }() + frameRecv := make(chan struct{}) + c := &Client{ Transport: func() *Transport { v := TransportTCP return &v }(), + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + require.Equal(t, 0, trackID) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) + close(frameRecv) + }, } u, err := base.ParseURL("rtsp://" + listenIP + ":8554/teststream") @@ -689,15 +697,9 @@ func TestClientReadPartial(t *testing.T) { require.NoError(t, err) done := make(chan struct{}) - frameRecv := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, streamType StreamType, payload []byte) { - require.Equal(t, 0, id) - require.Equal(t, StreamTypeRTP, streamType) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -920,20 +922,22 @@ func TestClientReadAnyPort(t *testing.T) { }) }() + frameRecv := make(chan struct{}) + c := &Client{ AnyPortEnable: true, + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + close(frameRecv) + }, } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -1043,18 +1047,21 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.NoError(t, err) }() - c := Client{} + frameRecv := make(chan struct{}) + + c := Client{ + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + close(frameRecv) + }, + } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -1248,20 +1255,22 @@ func TestClientReadAutomaticProtocol(t *testing.T) { conn.Close() }() + frameRecv := make(chan struct{}) + c := &Client{ ReadTimeout: 1 * time.Second, + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + close(frameRecv) + }, } err = c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream") require.NoError(t, err) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -1374,24 +1383,26 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { require.NoError(t, err) }() + frameRecv := make(chan struct{}) + c := &Client{ Transport: func() *Transport { v := TransportTCP return &v }(), + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + require.Equal(t, 0, trackID) + close(frameRecv) + }, } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - require.Equal(t, 0, id) - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -1528,18 +1539,21 @@ func TestClientReadRedirect(t *testing.T) { }) }() - c := Client{} + frameRecv := make(chan struct{}) + + c := Client{ + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + close(frameRecv) + }, + } err = c.DialRead("rtsp://localhost:8554/path1") require.NoError(t, err) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + c.ReadFrames() }() <-frameRecv @@ -1723,6 +1737,9 @@ func TestClientReadPause(t *testing.T) { require.NoError(t, err) }() + firstFrame := int32(0) + frameRecv := make(chan struct{}) + c := &Client{ Transport: func() *Transport { if transport == "udp" { @@ -1732,21 +1749,20 @@ func TestClientReadPause(t *testing.T) { v := TransportTCP return &v }(), + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }, } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - firstFrame := int32(0) - frameRecv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - if atomic.SwapInt32(&firstFrame, 1) == 0 { - close(frameRecv) - } - }) + c.ReadFrames() }() <-frameRecv @@ -1754,22 +1770,18 @@ func TestClientReadPause(t *testing.T) { require.NoError(t, err) <-done - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - }) + c.ReadFrames() + + firstFrame = int32(0) + frameRecv = make(chan struct{}) _, err = c.Play(nil) require.NoError(t, err) - firstFrame = int32(0) - frameRecv = make(chan struct{}) done = make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - if atomic.SwapInt32(&firstFrame, 1) == 0 { - close(frameRecv) - } - }) + c.ReadFrames() }() <-frameRecv @@ -1917,28 +1929,36 @@ func TestClientReadRTCPReport(t *testing.T) { require.NoError(t, err) }() + recv := 0 + recvDone := make(chan struct{}) + c := &Client{ Transport: func() *Transport { v := TransportTCP return &v }(), + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + recv++ + if recv >= 3 { + close(recvDone) + } + }, + OnPacketRTCP: func(c *Client, trackID int, payload []byte) { + recv++ + if recv >= 3 { + close(recvDone) + } + }, receiverReportPeriod: 1 * time.Second, } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - recv := 0 - recvDone := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(id int, typ StreamType, payload []byte) { - recv++ - if recv >= 3 { - close(recvDone) - } - }) + c.ReadFrames() }() time.Sleep(1300 * time.Millisecond) @@ -2096,8 +2116,7 @@ func TestClientReadErrorTimeout(t *testing.T) { require.NoError(t, err) defer c.Close() - err = c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { - }) + err = c.ReadFrames() switch transport { case "udp", "auto": @@ -2216,23 +2235,25 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { require.NoError(t, err) }() + recv := make(chan struct{}) + c := &Client{ Transport: func() *Transport { v := TransportTCP return &v }(), + OnPacketRTP: func(c *Client, trackID int, payload []byte) { + close(recv) + }, } err = c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - recv := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { - close(recv) - }) + c.ReadFrames() }() <-recv diff --git a/clientudpl.go b/clientudpl.go index 81e49822..bff4f2c0 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -169,11 +169,11 @@ func (l *clientUDPListener) run() { if l.streamType == StreamTypeRTP { l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTP(now, buf[:n]) + l.c.OnPacketRTP(l.c, l.trackID, buf[:n]) } else { l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTCP(now, buf[:n]) + l.c.OnPacketRTCP(l.c, l.trackID, buf[:n]) } - - l.c.pullReadCB()(l.trackID, l.streamType, buf[:n]) } } else { // record for { @@ -191,7 +191,7 @@ func (l *clientUDPListener) run() { now := time.Now() atomic.StoreInt64(l.lastFrameTime, now.Unix()) - l.c.pullReadCB()(l.trackID, l.streamType, buf[:n]) + l.c.OnPacketRTCP(l.c, l.trackID, buf[:n]) } } } diff --git a/examples/client-read-h264-save-to-disk/main.go b/examples/client-read-h264-save-to-disk/main.go index 8c281cf7..6b737a61 100644 --- a/examples/client-read-h264-save-to-disk/main.go +++ b/examples/client-read-h264-save-to-disk/main.go @@ -25,31 +25,6 @@ const ( ) func main() { - c := gortsplib.Client{} - - // connect to the server and start reading all tracks - err := c.DialRead(inputStream) - if err != nil { - panic(err) - } - defer c.Close() - - // find the H264 track - var h264TrackID int = -1 - var h264Conf *gortsplib.TrackConfigH264 - for i, track := range c.Tracks() { - if track.IsH264() { - h264TrackID = i - h264Conf, err = track.ExtractConfigH264() - if err != nil { - panic(err) - } - } - } - if h264TrackID < 0 { - panic(fmt.Errorf("H264 track not found")) - } - // open output file f, err := os.Create(outputFile) if err != nil { @@ -73,100 +48,128 @@ func main() { }) mux.SetPCRPID(256) - // read packets - err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - if trackID != h264TrackID { - return - } + var h264TrackID int = -1 + var h264Conf *gortsplib.TrackConfigH264 - if streamType != gortsplib.StreamTypeRTP { - return - } - - // parse RTP packets - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - - // decode H264 NALUs from RTP packets - nalus, pts, err := dec.DecodeUntilMarker(&pkt) - if err != nil { - return - } - - if !firstPacketWritten { - firstPacketWritten = true - startPTS = pts - } - - // check whether there's an IDR - idrPresent := func() bool { - for _, nalu := range nalus { - typ := h264.NALUType(nalu[0] & 0x1F) - if typ == h264.NALUTypeIDR { - return true + c := gortsplib.Client{ + // called before sending a PLAY request + OnPlay: func(c *gortsplib.Client) { + // find the H264 track + for i, track := range c.Tracks() { + if track.IsH264() { + h264TrackID = i + var err error + h264Conf, err = track.ExtractConfigH264() + if err != nil { + panic(err) + } } } - return false - }() - - // prepend an AUD. This is required by some players - filteredNALUs := [][]byte{ - {byte(h264.NALUTypeAccessUnitDelimiter), 240}, - } - - for _, nalu := range nalus { - // remove existing SPS, PPS, AUD - typ := h264.NALUType(nalu[0] & 0x1F) - switch typ { - case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: - continue + if h264TrackID < 0 { + panic(fmt.Errorf("H264 track not found")) + } + }, + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + if trackID != h264TrackID { + return } - // add SPS and PPS before every IDR - if typ == h264.NALUTypeIDR { - filteredNALUs = append(filteredNALUs, h264Conf.SPS) - filteredNALUs = append(filteredNALUs, h264Conf.PPS) + // parse RTP packets + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + if err != nil { + return } - filteredNALUs = append(filteredNALUs, nalu) - } + // decode H264 NALUs from RTP packets + nalus, pts, err := dec.DecodeUntilMarker(&pkt) + if err != nil { + return + } - // encode into Annex-B - enc, err := h264.EncodeAnnexB(filteredNALUs) - if err != nil { - panic(err) - } + if !firstPacketWritten { + firstPacketWritten = true + startPTS = pts + } - dts := dtsEst.Feed(pts - startPTS) - pts = pts - startPTS + // check whether there's an IDR + idrPresent := func() bool { + for _, nalu := range nalus { + typ := h264.NALUType(nalu[0] & 0x1F) + if typ == h264.NALUTypeIDR { + return true + } + } + return false + }() - // write TS packet - _, err = mux.WriteData(&astits.MuxerData{ - PID: 256, - AdaptationField: &astits.PacketAdaptationField{ - RandomAccessIndicator: idrPresent, - }, - PES: &astits.PESData{ - Header: &astits.PESHeader{ - OptionalHeader: &astits.PESOptionalHeader{ - MarkerBits: 2, - PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, - DTS: &astits.ClockReference{Base: int64(dts.Seconds() * 90000)}, - PTS: &astits.ClockReference{Base: int64(pts.Seconds() * 90000)}, - }, - StreamID: 224, // video + // prepend an AUD. This is required by some players + filteredNALUs := [][]byte{ + {byte(h264.NALUTypeAccessUnitDelimiter), 240}, + } + + for _, nalu := range nalus { + // remove existing SPS, PPS, AUD + typ := h264.NALUType(nalu[0] & 0x1F) + switch typ { + case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: + continue + } + + // add SPS and PPS before every IDR + if typ == h264.NALUTypeIDR { + filteredNALUs = append(filteredNALUs, h264Conf.SPS) + filteredNALUs = append(filteredNALUs, h264Conf.PPS) + } + + filteredNALUs = append(filteredNALUs, nalu) + } + + // encode into Annex-B + enc, err := h264.EncodeAnnexB(filteredNALUs) + if err != nil { + panic(err) + } + + dts := dtsEst.Feed(pts - startPTS) + pts = pts - startPTS + + // write TS packet + _, err = mux.WriteData(&astits.MuxerData{ + PID: 256, + AdaptationField: &astits.PacketAdaptationField{ + RandomAccessIndicator: idrPresent, }, - Data: enc, - }, - }) - if err != nil { - panic(err) - } + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + DTS: &astits.ClockReference{Base: int64(dts.Seconds() * 90000)}, + PTS: &astits.ClockReference{Base: int64(pts.Seconds() * 90000)}, + }, + StreamID: 224, // video + }, + Data: enc, + }, + }) + if err != nil { + panic(err) + } - fmt.Println("wrote ts packet") - }) + fmt.Println("wrote ts packet") + }, + } + + // connect to the server and start reading all tracks + err = c.DialRead(inputStream) + if err != nil { + panic(err) + } + defer c.Close() + + // read packets + err = c.ReadFrames() panic(err) } diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index 3a48b847..9143bc65 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -14,7 +14,54 @@ import ( // 3. get H264 NALUs of that track func main() { - c := gortsplib.Client{} + var h264Track int + var dec *rtph264.Decoder + + c := gortsplib.Client{ + // called before sending a PLAY request + OnPlay: func(c *gortsplib.Client) { + // find the H264 track + h264Track = func() int { + for i, track := range c.Tracks() { + if track.IsH264() { + return i + } + } + return -1 + }() + if h264Track < 0 { + panic(fmt.Errorf("H264 track not found")) + } + fmt.Printf("H264 track is number %d\n", h264Track+1) + + // instantiate a RTP/H264 decoder + dec = rtph264.NewDecoder() + }, + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + if trackID != h264Track { + return + } + + // parse RTP packets + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + if err != nil { + return + } + + // decode H264 NALUs from RTP packets + nalus, _, err := dec.Decode(&pkt) + if err != nil { + return + } + + // print NALUs + for _, nalu := range nalus { + fmt.Printf("received H264 NALU of size %d\n", len(nalu)) + } + }, + } // connect to the server and start reading all tracks err := c.DialRead("rtsp://localhost:8554/mystream") @@ -23,50 +70,7 @@ func main() { } defer c.Close() - // find the H264 track - h264Track := func() int { - for i, track := range c.Tracks() { - if track.IsH264() { - return i - } - } - return -1 - }() - if h264Track < 0 { - panic(fmt.Errorf("H264 track not found")) - } - fmt.Printf("H264 track is number %d\n", h264Track+1) - - // instantiate a RTP/H264 decoder - dec := rtph264.NewDecoder() - // read packets - err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - if streamType != gortsplib.StreamTypeRTP { - return - } - - if trackID != h264Track { - return - } - - // parse RTP packets - var pkt rtp.Packet - err := pkt.Unmarshal(payload) - if err != nil { - return - } - - // decode H264 NALUs from RTP packets - nalus, _, err := dec.Decode(&pkt) - if err != nil { - return - } - - // print NALUs - for _, nalu := range nalus { - fmt.Printf("received H264 NALU of size %d\n", len(nalu)) - } - }) + err = c.ReadFrames() panic(err) } diff --git a/examples/client-read-options/main.go b/examples/client-read-options/main.go index a094d555..b71627aa 100644 --- a/examples/client-read-options/main.go +++ b/examples/client-read-options/main.go @@ -20,6 +20,14 @@ func main() { ReadTimeout: 10 * time.Second, // timeout of write operations WriteTimeout: 10 * time.Second, + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + }, + // called when a RTCP packet arrives + OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + }, } // connect to the server and start reading all tracks @@ -30,8 +38,6 @@ func main() { defer c.Close() // read packets - err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload)) - }) + err = c.ReadFrames() panic(err) } diff --git a/examples/client-read-partial/main.go b/examples/client-read-partial/main.go index a8cf7a95..7ce65338 100644 --- a/examples/client-read-partial/main.go +++ b/examples/client-read-partial/main.go @@ -19,7 +19,16 @@ func main() { panic(err) } - c := gortsplib.Client{} + c := gortsplib.Client{ + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + }, + // called when a RTCP packet arrives + OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + }, + } err = c.Dial(u.Scheme, u.Host) if err != nil { @@ -54,8 +63,6 @@ func main() { } // read packets - err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload)) - }) + err = c.ReadFrames() panic(err) } diff --git a/examples/client-read-pause/main.go b/examples/client-read-pause/main.go index c3e7f093..4895e3d9 100644 --- a/examples/client-read-pause/main.go +++ b/examples/client-read-pause/main.go @@ -14,7 +14,16 @@ import ( // 4. repeat func main() { - c := gortsplib.Client{} + c := gortsplib.Client{ + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + }, + // called when a RTCP packet arrives + OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + }, + } // connect to the server and start reading all tracks err := c.DialRead("rtsp://localhost:8554/mystream") @@ -28,9 +37,7 @@ func main() { done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload)) - }) + c.ReadFrames() }() // wait diff --git a/examples/client-read/main.go b/examples/client-read/main.go index 58dd16a1..4b43a6a8 100644 --- a/examples/client-read/main.go +++ b/examples/client-read/main.go @@ -10,7 +10,16 @@ import ( // 1. connect to a RTSP server and read all tracks on a path func main() { - c := gortsplib.Client{} + c := gortsplib.Client{ + // called when a RTP packet arrives + OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload)) + }, + // called when a RTCP packet arrives + OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) { + fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload)) + }, + } // connect to the server and start reading all tracks err := c.DialRead("rtsp://localhost:8554/mystream") @@ -20,8 +29,6 @@ func main() { defer c.Close() // read packets - err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload)) - }) + err = c.ReadFrames() panic(err) } diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index f8678d19..cd484f42 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -129,7 +129,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route packet to readers + // if we are the publisher, route the RTP packet to readers if ctx.Session == sh.publisher { sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) } @@ -140,7 +140,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route packet to readers + // if we are the publisher, route the RTCP packet to readers if ctx.Session == sh.publisher { sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) } diff --git a/examples/server/main.go b/examples/server/main.go index 1e8eee79..8867e6bc 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -128,7 +128,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route packet to readers + // if we are the publisher, route the RTP packet to readers if ctx.Session == sh.publisher { sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) } @@ -139,7 +139,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route packet to readers + // if we are the publisher, route the RTCP packet to readers if ctx.Session == sh.publisher { sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) }