client: turn ReadFrames into OnPacketRTP, OnPacketRTCP

This commit is contained in:
aler9
2021-11-04 11:05:51 +01:00
committed by Alessandro Ros
parent b4aadd8e4c
commit a22116e66e
12 changed files with 353 additions and 318 deletions

View File

@@ -132,10 +132,16 @@ type Client struct {
// //
// callbacks // callbacks
// //
// callback called before every request. // called before every request.
OnRequest func(*base.Request) OnRequest func(*base.Request)
// callback called after every response. // called after every response.
OnResponse func(*base.Response) OnResponse func(*base.Response)
// called before sending a PLAY request.
OnPlay func(*Client)
// called when a RTP packet arrives.
OnPacketRTP func(*Client, int, []byte)
// called when a RTCP packet arrives.
OnPacketRTCP func(*Client, int, []byte)
// //
// RTSP parameters // RTSP parameters
@@ -210,12 +216,10 @@ type Client struct {
lastRange *headers.Range lastRange *headers.Range
backgroundRunning bool backgroundRunning bool
backgroundErr error backgroundErr error
tcpFrameBuffer *multibuffer.MultiBuffer // tcp tcpFrameBuffer *multibuffer.MultiBuffer // tcp
tcpWriteMutex sync.Mutex // tcp tcpWriteMutex sync.Mutex // tcp
readCBMutex sync.RWMutex // read writeMutex sync.RWMutex // write
readCB func(int, StreamType, []byte) // read writeFrameAllowed bool // write
writeMutex sync.RWMutex // write
writeFrameAllowed bool // write
// in // in
options chan optionsReq options chan optionsReq
@@ -230,12 +234,21 @@ type Client struct {
// out // out
backgroundInnerDone chan error backgroundInnerDone chan error
backgroundDone chan struct{} backgroundDone chan struct{}
readCBSet chan struct{}
done chan struct{} done chan struct{}
} }
// Dial connects to a server. // Dial connects to a server.
func (c *Client) Dial(scheme string, host string) error { func (c *Client) Dial(scheme string, host string) error {
// callbacks
if c.OnPacketRTP == nil {
c.OnPacketRTP = func(c *Client, trackID int, payload []byte) {
}
}
if c.OnPacketRTCP == nil {
c.OnPacketRTCP = func(c *Client, trackID int, payload []byte) {
}
}
// RTSP parameters // RTSP parameters
if c.ReadTimeout == 0 { if c.ReadTimeout == 0 {
c.ReadTimeout = 10 * time.Second c.ReadTimeout = 10 * time.Second
@@ -414,14 +427,14 @@ func (c *Client) DialPublishContext(ctx context.Context, address string, tracks
return nil return nil
} }
// Close closes the connection and waits for all its resources to exit. // Close closes all the client resources and waits for them to exit.
func (c *Client) Close() error { func (c *Client) Close() error {
c.ctxCancel() c.ctxCancel()
<-c.done <-c.done
return nil return nil
} }
// Tracks returns all the tracks that the connection is reading or publishing. // Tracks returns all the tracks that the client is reading or publishing.
func (c *Client) Tracks() Tracks { func (c *Client) Tracks() Tracks {
ids := make([]int, len(c.tracks)) ids := make([]int, len(c.tracks))
pos := 0 pos := 0
@@ -534,10 +547,6 @@ func (c *Client) reset(isSwitchingProtocol bool) {
c.tracks = nil c.tracks = nil
c.tracksByChannel = nil c.tracksByChannel = nil
c.tcpFrameBuffer = nil c.tcpFrameBuffer = nil
if !isSwitchingProtocol {
c.readCB = nil
}
} }
func (c *Client) checkState(allowed map[clientState]struct{}) error { func (c *Client) checkState(allowed map[clientState]struct{}) error {
@@ -590,12 +599,6 @@ func (c *Client) switchProtocolIfTimeout(err error) error {
return nil return nil
} }
func (c *Client) pullReadCB() func(int, StreamType, []byte) {
c.readCBMutex.RLock()
defer c.readCBMutex.RUnlock()
return c.readCB
}
func (c *Client) backgroundStart(isSwitchingProtocol bool) { func (c *Client) backgroundStart(isSwitchingProtocol bool) {
c.writeMutex.Lock() c.writeMutex.Lock()
c.writeFrameAllowed = true c.writeFrameAllowed = true
@@ -791,10 +794,10 @@ func (c *Client) runBackgroundPlayTCP() error {
} }
channel := frame.Channel channel := frame.Channel
streamType := StreamTypeRTP isRTP := true
if (channel % 2) != 0 { if (channel % 2) != 0 {
channel-- channel--
streamType = StreamTypeRTCP isRTP = false
} }
trackID, ok := c.tracksByChannel[channel] trackID, ok := c.tracksByChannel[channel]
@@ -805,13 +808,13 @@ func (c *Client) runBackgroundPlayTCP() error {
now := time.Now() now := time.Now()
atomic.StoreInt64(&lastFrameTime, now.Unix()) atomic.StoreInt64(&lastFrameTime, now.Unix())
if streamType == StreamTypeRTP { if isRTP {
c.tracks[trackID].rtcpReceiver.ProcessPacketRTP(now, frame.Payload) c.tracks[trackID].rtcpReceiver.ProcessPacketRTP(now, frame.Payload)
c.OnPacketRTP(c, trackID, frame.Payload)
} else { } else {
c.tracks[trackID].rtcpReceiver.ProcessPacketRTCP(now, frame.Payload) c.tracks[trackID].rtcpReceiver.ProcessPacketRTCP(now, frame.Payload)
c.OnPacketRTCP(c, trackID, frame.Payload)
} }
c.pullReadCB()(trackID, streamType, frame.Payload)
} }
}() }()
@@ -923,10 +926,10 @@ func (c *Client) runBackgroundRecordTCP() error {
} }
channel := frame.Channel channel := frame.Channel
streamType := StreamTypeRTP isRTP := true
if (channel % 2) != 0 { if (channel % 2) != 0 {
channel-- channel--
streamType = StreamTypeRTCP isRTP = false
} }
trackID, ok := c.tracksByChannel[channel] trackID, ok := c.tracksByChannel[channel]
@@ -934,7 +937,9 @@ func (c *Client) runBackgroundRecordTCP() error {
continue continue
} }
c.pullReadCB()(trackID, streamType, frame.Payload) if !isRTP {
c.OnPacketRTCP(c, trackID, frame.Payload)
}
} }
}() }()
@@ -1677,6 +1682,10 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
} }
} }
if c.OnPlay != nil {
c.OnPlay(c)
}
header := make(base.Header) header := make(base.Header)
// Range is mandatory in Parrot Streaming Server // Range is mandatory in Parrot Streaming Server
@@ -1707,21 +1716,6 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
c.state = clientStatePlay c.state = clientStatePlay
c.lastRange = ra c.lastRange = ra
if !isSwitchingProtocol {
// use a temporary callback that is replaces as soon as
// the user calls ReadFrames()
c.readCBSet = make(chan struct{})
copy := c.readCBSet
c.readCB = func(trackID int, streamType StreamType, payload []byte) {
select {
case <-copy:
case <-c.ctx.Done():
return
}
c.pullReadCB()(trackID, streamType, payload)
}
}
c.backgroundStart(isSwitchingProtocol) c.backgroundStart(isSwitchingProtocol)
return res, nil return res, nil
@@ -1765,11 +1759,6 @@ func (c *Client) doRecord() (*base.Response, error) {
c.state = clientStateRecord c.state = clientStateRecord
// when publishing, calling ReadFrames() is not mandatory
// use an empty callback
c.readCB = func(trackID int, streamType StreamType, payload []byte) {
}
c.backgroundStart(false) c.backgroundStart(false)
return nil, nil return nil, nil
@@ -1849,17 +1838,7 @@ func (c *Client) Seek(ra *headers.Range) (*base.Response, error) {
} }
// ReadFrames starts reading frames. // ReadFrames starts reading frames.
func (c *Client) ReadFrames(onFrame func(int, StreamType, []byte)) error { func (c *Client) ReadFrames() error {
c.readCBMutex.Lock()
c.readCB = onFrame
c.readCBMutex.Unlock()
// replace temporary callback with final callback
if c.readCBSet != nil {
close(c.readCBSet)
c.readCBSet = nil
}
<-c.backgroundDone <-c.backgroundDone
return c.backgroundErr return c.backgroundErr
} }

View File

@@ -158,6 +158,8 @@ func TestClientPublishSerial(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
recvDone := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
if transport == "udp" { if transport == "udp" {
@@ -167,6 +169,11 @@ func TestClientPublishSerial(t *testing.T) {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTCP: func(c *Client, trackID int, payload []byte) {
require.Equal(t, 0, trackID)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload)
close(recvDone)
},
} }
track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}})
@@ -176,16 +183,10 @@ func TestClientPublishSerial(t *testing.T) {
Tracks{track}) Tracks{track})
require.NoError(t, err) require.NoError(t, err)
recvDone := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { c.ReadFrames()
require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTCP, streamType)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload)
close(recvDone)
})
}() }()
err = c.WritePacketRTP(0, err = c.WritePacketRTP(0,

View File

@@ -394,6 +394,8 @@ func TestClientRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
counter := uint64(0)
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
switch transport { switch transport {
@@ -410,17 +412,8 @@ func TestClientRead(t *testing.T) {
return &v return &v
} }
}(), }(),
} OnPacketRTP: func(c *Client, trackID int, payload []byte) {
// ignore multicast loopback
err = c.DialRead(scheme + "://" + listenIP + ":8554/test/stream?param=value")
require.NoError(t, err)
done := make(chan struct{})
counter := uint64(0)
go func() {
defer close(done)
c.ReadFrames(func(id int, streamType StreamType, payload []byte) {
// skip multicast loopback
if transport == "multicast" { if transport == "multicast" {
add := atomic.AddUint64(&counter, 1) add := atomic.AddUint64(&counter, 1)
if add >= 2 { if add >= 2 {
@@ -428,21 +421,29 @@ func TestClientRead(t *testing.T) {
} }
} }
require.Equal(t, 0, id) require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTP, streamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08})
require.NoError(t, err) require.NoError(t, err)
}) },
}
err = c.DialRead(scheme + "://" + listenIP + ":8554/test/stream?param=value")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
c.ReadFrames()
}() }()
<-frameRecv <-frameRecv
c.Close() c.Close()
<-done <-done
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
})
}) })
} }
} }
@@ -665,11 +666,18 @@ func TestClientReadPartial(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
frameRecv := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
require.Equal(t, 0, trackID)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
close(frameRecv)
},
} }
u, err := base.ParseURL("rtsp://" + listenIP + ":8554/teststream") u, err := base.ParseURL("rtsp://" + listenIP + ":8554/teststream")
@@ -689,15 +697,9 @@ func TestClientReadPartial(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
frameRecv := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, streamType StreamType, payload []byte) { c.ReadFrames()
require.Equal(t, 0, id)
require.Equal(t, StreamTypeRTP, streamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -920,20 +922,22 @@ func TestClientReadAnyPort(t *testing.T) {
}) })
}() }()
frameRecv := make(chan struct{})
c := &Client{ c := &Client{
AnyPortEnable: true, AnyPortEnable: true,
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
close(frameRecv)
},
} }
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -1043,18 +1047,21 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
c := Client{} frameRecv := make(chan struct{})
c := Client{
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
close(frameRecv)
},
}
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -1248,20 +1255,22 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
conn.Close() conn.Close()
}() }()
frameRecv := make(chan struct{})
c := &Client{ c := &Client{
ReadTimeout: 1 * time.Second, ReadTimeout: 1 * time.Second,
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
close(frameRecv)
},
} }
err = c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream") err = c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -1374,24 +1383,26 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
frameRecv := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
require.Equal(t, 0, trackID)
close(frameRecv)
},
} }
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
require.Equal(t, 0, id)
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -1528,18 +1539,21 @@ func TestClientReadRedirect(t *testing.T) {
}) })
}() }()
c := Client{} frameRecv := make(chan struct{})
c := Client{
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
close(frameRecv)
},
}
err = c.DialRead("rtsp://localhost:8554/path1") err = c.DialRead("rtsp://localhost:8554/path1")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
close(frameRecv)
})
}() }()
<-frameRecv <-frameRecv
@@ -1723,6 +1737,9 @@ func TestClientReadPause(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
firstFrame := int32(0)
frameRecv := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
if transport == "udp" { if transport == "udp" {
@@ -1732,21 +1749,20 @@ func TestClientReadPause(t *testing.T) {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
if atomic.SwapInt32(&firstFrame, 1) == 0 {
close(frameRecv)
}
},
} }
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
firstFrame := int32(0)
frameRecv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
if atomic.SwapInt32(&firstFrame, 1) == 0 {
close(frameRecv)
}
})
}() }()
<-frameRecv <-frameRecv
@@ -1754,22 +1770,18 @@ func TestClientReadPause(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
<-done <-done
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
})
firstFrame = int32(0)
frameRecv = make(chan struct{})
_, err = c.Play(nil) _, err = c.Play(nil)
require.NoError(t, err) require.NoError(t, err)
firstFrame = int32(0)
frameRecv = make(chan struct{})
done = make(chan struct{}) done = make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
if atomic.SwapInt32(&firstFrame, 1) == 0 {
close(frameRecv)
}
})
}() }()
<-frameRecv <-frameRecv
@@ -1917,28 +1929,36 @@ func TestClientReadRTCPReport(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
recv := 0
recvDone := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
recv++
if recv >= 3 {
close(recvDone)
}
},
OnPacketRTCP: func(c *Client, trackID int, payload []byte) {
recv++
if recv >= 3 {
close(recvDone)
}
},
receiverReportPeriod: 1 * time.Second, receiverReportPeriod: 1 * time.Second,
} }
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
recv := 0
recvDone := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(id int, typ StreamType, payload []byte) { c.ReadFrames()
recv++
if recv >= 3 {
close(recvDone)
}
})
}() }()
time.Sleep(1300 * time.Millisecond) time.Sleep(1300 * time.Millisecond)
@@ -2096,8 +2116,7 @@ func TestClientReadErrorTimeout(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
err = c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { err = c.ReadFrames()
})
switch transport { switch transport {
case "udp", "auto": case "udp", "auto":
@@ -2216,23 +2235,25 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
recv := make(chan struct{})
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
close(recv)
},
} }
err = c.DialRead("rtsp://localhost:8554/teststream") err = c.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
recv := make(chan struct{})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { c.ReadFrames()
close(recv)
})
}() }()
<-recv <-recv

View File

@@ -169,11 +169,11 @@ func (l *clientUDPListener) run() {
if l.streamType == StreamTypeRTP { if l.streamType == StreamTypeRTP {
l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTP(now, buf[:n]) l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTP(now, buf[:n])
l.c.OnPacketRTP(l.c, l.trackID, buf[:n])
} else { } else {
l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTCP(now, buf[:n]) l.c.tracks[l.trackID].rtcpReceiver.ProcessPacketRTCP(now, buf[:n])
l.c.OnPacketRTCP(l.c, l.trackID, buf[:n])
} }
l.c.pullReadCB()(l.trackID, l.streamType, buf[:n])
} }
} else { // record } else { // record
for { for {
@@ -191,7 +191,7 @@ func (l *clientUDPListener) run() {
now := time.Now() now := time.Now()
atomic.StoreInt64(l.lastFrameTime, now.Unix()) atomic.StoreInt64(l.lastFrameTime, now.Unix())
l.c.pullReadCB()(l.trackID, l.streamType, buf[:n]) l.c.OnPacketRTCP(l.c, l.trackID, buf[:n])
} }
} }
} }

View File

@@ -25,31 +25,6 @@ const (
) )
func main() { func main() {
c := gortsplib.Client{}
// connect to the server and start reading all tracks
err := c.DialRead(inputStream)
if err != nil {
panic(err)
}
defer c.Close()
// find the H264 track
var h264TrackID int = -1
var h264Conf *gortsplib.TrackConfigH264
for i, track := range c.Tracks() {
if track.IsH264() {
h264TrackID = i
h264Conf, err = track.ExtractConfigH264()
if err != nil {
panic(err)
}
}
}
if h264TrackID < 0 {
panic(fmt.Errorf("H264 track not found"))
}
// open output file // open output file
f, err := os.Create(outputFile) f, err := os.Create(outputFile)
if err != nil { if err != nil {
@@ -73,100 +48,128 @@ func main() {
}) })
mux.SetPCRPID(256) mux.SetPCRPID(256)
// read packets var h264TrackID int = -1
err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { var h264Conf *gortsplib.TrackConfigH264
if trackID != h264TrackID {
return
}
if streamType != gortsplib.StreamTypeRTP { c := gortsplib.Client{
return // called before sending a PLAY request
} OnPlay: func(c *gortsplib.Client) {
// find the H264 track
// parse RTP packets for i, track := range c.Tracks() {
var pkt rtp.Packet if track.IsH264() {
err := pkt.Unmarshal(payload) h264TrackID = i
if err != nil { var err error
return h264Conf, err = track.ExtractConfigH264()
} if err != nil {
panic(err)
// decode H264 NALUs from RTP packets }
nalus, pts, err := dec.DecodeUntilMarker(&pkt)
if err != nil {
return
}
if !firstPacketWritten {
firstPacketWritten = true
startPTS = pts
}
// check whether there's an IDR
idrPresent := func() bool {
for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F)
if typ == h264.NALUTypeIDR {
return true
} }
} }
return false if h264TrackID < 0 {
}() panic(fmt.Errorf("H264 track not found"))
}
// prepend an AUD. This is required by some players },
filteredNALUs := [][]byte{ // called when a RTP packet arrives
{byte(h264.NALUTypeAccessUnitDelimiter), 240}, OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
} if trackID != h264TrackID {
return
for _, nalu := range nalus {
// remove existing SPS, PPS, AUD
typ := h264.NALUType(nalu[0] & 0x1F)
switch typ {
case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter:
continue
} }
// add SPS and PPS before every IDR // parse RTP packets
if typ == h264.NALUTypeIDR { var pkt rtp.Packet
filteredNALUs = append(filteredNALUs, h264Conf.SPS) err := pkt.Unmarshal(payload)
filteredNALUs = append(filteredNALUs, h264Conf.PPS) if err != nil {
return
} }
filteredNALUs = append(filteredNALUs, nalu) // decode H264 NALUs from RTP packets
} nalus, pts, err := dec.DecodeUntilMarker(&pkt)
if err != nil {
return
}
// encode into Annex-B if !firstPacketWritten {
enc, err := h264.EncodeAnnexB(filteredNALUs) firstPacketWritten = true
if err != nil { startPTS = pts
panic(err) }
}
dts := dtsEst.Feed(pts - startPTS) // check whether there's an IDR
pts = pts - startPTS idrPresent := func() bool {
for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F)
if typ == h264.NALUTypeIDR {
return true
}
}
return false
}()
// write TS packet // prepend an AUD. This is required by some players
_, err = mux.WriteData(&astits.MuxerData{ filteredNALUs := [][]byte{
PID: 256, {byte(h264.NALUTypeAccessUnitDelimiter), 240},
AdaptationField: &astits.PacketAdaptationField{ }
RandomAccessIndicator: idrPresent,
}, for _, nalu := range nalus {
PES: &astits.PESData{ // remove existing SPS, PPS, AUD
Header: &astits.PESHeader{ typ := h264.NALUType(nalu[0] & 0x1F)
OptionalHeader: &astits.PESOptionalHeader{ switch typ {
MarkerBits: 2, case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter:
PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, continue
DTS: &astits.ClockReference{Base: int64(dts.Seconds() * 90000)}, }
PTS: &astits.ClockReference{Base: int64(pts.Seconds() * 90000)},
}, // add SPS and PPS before every IDR
StreamID: 224, // video if typ == h264.NALUTypeIDR {
filteredNALUs = append(filteredNALUs, h264Conf.SPS)
filteredNALUs = append(filteredNALUs, h264Conf.PPS)
}
filteredNALUs = append(filteredNALUs, nalu)
}
// encode into Annex-B
enc, err := h264.EncodeAnnexB(filteredNALUs)
if err != nil {
panic(err)
}
dts := dtsEst.Feed(pts - startPTS)
pts = pts - startPTS
// write TS packet
_, err = mux.WriteData(&astits.MuxerData{
PID: 256,
AdaptationField: &astits.PacketAdaptationField{
RandomAccessIndicator: idrPresent,
}, },
Data: enc, PES: &astits.PESData{
}, Header: &astits.PESHeader{
}) OptionalHeader: &astits.PESOptionalHeader{
if err != nil { MarkerBits: 2,
panic(err) PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent,
} DTS: &astits.ClockReference{Base: int64(dts.Seconds() * 90000)},
PTS: &astits.ClockReference{Base: int64(pts.Seconds() * 90000)},
},
StreamID: 224, // video
},
Data: enc,
},
})
if err != nil {
panic(err)
}
fmt.Println("wrote ts packet") fmt.Println("wrote ts packet")
}) },
}
// connect to the server and start reading all tracks
err = c.DialRead(inputStream)
if err != nil {
panic(err)
}
defer c.Close()
// read packets
err = c.ReadFrames()
panic(err) panic(err)
} }

View File

@@ -14,7 +14,54 @@ import (
// 3. get H264 NALUs of that track // 3. get H264 NALUs of that track
func main() { func main() {
c := gortsplib.Client{} var h264Track int
var dec *rtph264.Decoder
c := gortsplib.Client{
// called before sending a PLAY request
OnPlay: func(c *gortsplib.Client) {
// find the H264 track
h264Track = func() int {
for i, track := range c.Tracks() {
if track.IsH264() {
return i
}
}
return -1
}()
if h264Track < 0 {
panic(fmt.Errorf("H264 track not found"))
}
fmt.Printf("H264 track is number %d\n", h264Track+1)
// instantiate a RTP/H264 decoder
dec = rtph264.NewDecoder()
},
// called when a RTP packet arrives
OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
if trackID != h264Track {
return
}
// parse RTP packets
var pkt rtp.Packet
err := pkt.Unmarshal(payload)
if err != nil {
return
}
// decode H264 NALUs from RTP packets
nalus, _, err := dec.Decode(&pkt)
if err != nil {
return
}
// print NALUs
for _, nalu := range nalus {
fmt.Printf("received H264 NALU of size %d\n", len(nalu))
}
},
}
// connect to the server and start reading all tracks // connect to the server and start reading all tracks
err := c.DialRead("rtsp://localhost:8554/mystream") err := c.DialRead("rtsp://localhost:8554/mystream")
@@ -23,50 +70,7 @@ func main() {
} }
defer c.Close() defer c.Close()
// find the H264 track
h264Track := func() int {
for i, track := range c.Tracks() {
if track.IsH264() {
return i
}
}
return -1
}()
if h264Track < 0 {
panic(fmt.Errorf("H264 track not found"))
}
fmt.Printf("H264 track is number %d\n", h264Track+1)
// instantiate a RTP/H264 decoder
dec := rtph264.NewDecoder()
// read packets // read packets
err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { err = c.ReadFrames()
if streamType != gortsplib.StreamTypeRTP {
return
}
if trackID != h264Track {
return
}
// parse RTP packets
var pkt rtp.Packet
err := pkt.Unmarshal(payload)
if err != nil {
return
}
// decode H264 NALUs from RTP packets
nalus, _, err := dec.Decode(&pkt)
if err != nil {
return
}
// print NALUs
for _, nalu := range nalus {
fmt.Printf("received H264 NALU of size %d\n", len(nalu))
}
})
panic(err) panic(err)
} }

View File

@@ -20,6 +20,14 @@ func main() {
ReadTimeout: 10 * time.Second, ReadTimeout: 10 * time.Second,
// timeout of write operations // timeout of write operations
WriteTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second,
// called when a RTP packet arrives
OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload))
},
// called when a RTCP packet arrives
OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload))
},
} }
// connect to the server and start reading all tracks // connect to the server and start reading all tracks
@@ -30,8 +38,6 @@ func main() {
defer c.Close() defer c.Close()
// read packets // read packets
err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { err = c.ReadFrames()
fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload))
})
panic(err) panic(err)
} }

View File

@@ -19,7 +19,16 @@ func main() {
panic(err) panic(err)
} }
c := gortsplib.Client{} c := gortsplib.Client{
// called when a RTP packet arrives
OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload))
},
// called when a RTCP packet arrives
OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload))
},
}
err = c.Dial(u.Scheme, u.Host) err = c.Dial(u.Scheme, u.Host)
if err != nil { if err != nil {
@@ -54,8 +63,6 @@ func main() {
} }
// read packets // read packets
err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { err = c.ReadFrames()
fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload))
})
panic(err) panic(err)
} }

View File

@@ -14,7 +14,16 @@ import (
// 4. repeat // 4. repeat
func main() { func main() {
c := gortsplib.Client{} c := gortsplib.Client{
// called when a RTP packet arrives
OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload))
},
// called when a RTCP packet arrives
OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload))
},
}
// connect to the server and start reading all tracks // connect to the server and start reading all tracks
err := c.DialRead("rtsp://localhost:8554/mystream") err := c.DialRead("rtsp://localhost:8554/mystream")
@@ -28,9 +37,7 @@ func main() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { c.ReadFrames()
fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload))
})
}() }()
// wait // wait

View File

@@ -10,7 +10,16 @@ import (
// 1. connect to a RTSP server and read all tracks on a path // 1. connect to a RTSP server and read all tracks on a path
func main() { func main() {
c := gortsplib.Client{} c := gortsplib.Client{
// called when a RTP packet arrives
OnPacketRTP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTP packet from track %d, size %d\n", trackID, len(payload))
},
// called when a RTCP packet arrives
OnPacketRTCP: func(c *gortsplib.Client, trackID int, payload []byte) {
fmt.Printf("RTCP packet from track %d, size %d\n", trackID, len(payload))
},
}
// connect to the server and start reading all tracks // connect to the server and start reading all tracks
err := c.DialRead("rtsp://localhost:8554/mystream") err := c.DialRead("rtsp://localhost:8554/mystream")
@@ -20,8 +29,6 @@ func main() {
defer c.Close() defer c.Close()
// read packets // read packets
err = c.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { err = c.ReadFrames()
fmt.Printf("packet from track %d, type %v, size %d\n", trackID, streamType, len(payload))
})
panic(err) panic(err)
} }

View File

@@ -129,7 +129,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx)
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers // if we are the publisher, route the RTP packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload)
} }
@@ -140,7 +140,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers // if we are the publisher, route the RTCP packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload)
} }

View File

@@ -128,7 +128,7 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx)
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers // if we are the publisher, route the RTP packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload) sh.stream.WritePacketRTP(ctx.TrackID, ctx.Payload)
} }
@@ -139,7 +139,7 @@ func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers // if we are the publisher, route the RTCP packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload) sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Payload)
} }