propagate H264 packets throughout the server

This commit is contained in:
aler9
2022-03-02 22:52:05 +01:00
committed by Alessandro Ros
parent a59ddf7176
commit d929197b21
15 changed files with 160 additions and 255 deletions

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.17
require ( require (
code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5 code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5
github.com/aler9/gortsplib v0.0.0-20220401091943-cec5326ccfed github.com/aler9/gortsplib v0.0.0-20220408160915-2d2e62f55bae
github.com/asticode/go-astits v1.10.0 github.com/asticode/go-astits v1.10.0
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.7.2 github.com/gin-gonic/gin v1.7.2

4
go.sum
View File

@@ -4,8 +4,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20220401091943-cec5326ccfed h1:lA/dMUwmQcBCeuFRYkPr3qnGPKgrFdGC0RZzB8tRKXw= github.com/aler9/gortsplib v0.0.0-20220408160915-2d2e62f55bae h1:BGe90r+y1BRvSz1b1OIbee0q9c2MdI2GUhnzVm0XoSU=
github.com/aler9/gortsplib v0.0.0-20220401091943-cec5326ccfed/go.mod h1:4mWq8mM6v8KrSQG4sEdnvM6+ZVKPPgKtf75TYR+jsKQ= github.com/aler9/gortsplib v0.0.0-20220408160915-2d2e62f55bae/go.mod h1:Mezkz7Jb5zrIWP6MxJ2uBgt5xwywZkcdmuQZ2QrFYsM=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8= github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=

14
internal/core/data.go Normal file
View File

@@ -0,0 +1,14 @@
package core
import (
"time"
"github.com/pion/rtp"
)
type data struct {
rtp *rtp.Packet
ptsEqualsDTS bool
h264NALUs [][]byte
h264PTS time.Duration
}

View File

@@ -16,8 +16,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/pion/rtp"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/hls" "github.com/aler9/rtsp-simple-server/internal/hls"
@@ -107,9 +105,9 @@ type hlsMuxerRequest struct {
res chan hlsMuxerResponse res chan hlsMuxerResponse
} }
type hlsMuxerTrackIDPayloadPair struct { type hlsMuxerTrackIDDataPair struct {
trackID int trackID int
packet *rtp.Packet data *data
} }
type hlsMuxerPathManager interface { type hlsMuxerPathManager interface {
@@ -282,7 +280,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})
var videoTrack *gortsplib.TrackH264 var videoTrack *gortsplib.TrackH264
videoTrackID := -1 videoTrackID := -1
var h264Decoder *rtph264.Decoder
var audioTrack *gortsplib.TrackAAC var audioTrack *gortsplib.TrackAAC
audioTrackID := -1 audioTrackID := -1
var aacDecoder *rtpaac.Decoder var aacDecoder *rtpaac.Decoder
@@ -296,8 +293,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})
videoTrack = tt videoTrack = tt
videoTrackID = i videoTrackID = i
h264Decoder = &rtph264.Decoder{}
h264Decoder.Init()
case *gortsplib.TrackAAC: case *gortsplib.TrackAAC:
if audioTrack != nil { if audioTrack != nil {
@@ -342,25 +337,20 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})
if !ok { if !ok {
return fmt.Errorf("terminated") return fmt.Errorf("terminated")
} }
pair := data.(hlsMuxerTrackIDPayloadPair) pair := data.(hlsMuxerTrackIDDataPair)
if videoTrack != nil && pair.trackID == videoTrackID { if videoTrack != nil && pair.trackID == videoTrackID {
nalus, pts, err := h264Decoder.DecodeUntilMarker(pair.packet) if pair.data.h264NALUs == nil {
if err != nil {
if err != rtph264.ErrMorePacketsNeeded &&
err != rtph264.ErrNonStartingPacketAndNoPrevious {
m.log(logger.Warn, "unable to decode video track: %v", err)
}
continue continue
} }
err = m.muxer.WriteH264(pts, nalus) err = m.muxer.WriteH264(pair.data.h264PTS, pair.data.h264NALUs)
if err != nil { if err != nil {
m.log(logger.Warn, "unable to write segment: %v", err) m.log(logger.Warn, "unable to write segment: %v", err)
continue continue
} }
} else if audioTrack != nil && pair.trackID == audioTrackID { } else if audioTrack != nil && pair.trackID == audioTrackID {
aus, pts, err := aacDecoder.Decode(pair.packet) aus, pts, err := aacDecoder.Decode(pair.data.rtp)
if err != nil { if err != nil {
if err != rtpaac.ErrMorePacketsNeeded { if err != rtpaac.ErrMorePacketsNeeded {
m.log(logger.Warn, "unable to decode audio track: %v", err) m.log(logger.Warn, "unable to decode audio track: %v", err)
@@ -536,9 +526,9 @@ func (m *hlsMuxer) onReaderAccepted() {
m.log(logger.Info, "is converting into HLS") m.log(logger.Info, "is converting into HLS")
} }
// onReaderPacketRTP implements reader. // onReaderData implements reader.
func (m *hlsMuxer) onReaderPacketRTP(trackID int, pkt *rtp.Packet) { func (m *hlsMuxer) onReaderData(trackID int, data *data) {
m.ringBuffer.Push(hlsMuxerTrackIDPayloadPair{trackID, pkt}) m.ringBuffer.Push(hlsMuxerTrackIDDataPair{trackID, data})
} }
// onReaderAPIDescribe implements reader. // onReaderAPIDescribe implements reader.

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtph264"
@@ -146,8 +147,21 @@ func (s *hlsSource) runInner() bool {
return return
} }
for _, pkt := range pkts { lastPkt := len(pkts) - 1
stream.writePacketRTP(videoTrackID, pkt) for i, pkt := range pkts {
if i != lastPkt {
stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(nalus),
h264NALUs: nalus,
h264PTS: pts,
})
}
} }
} }
@@ -162,7 +176,10 @@ func (s *hlsSource) runInner() bool {
} }
for _, pkt := range pkts { for _, pkt := range pkts {
stream.writePacketRTP(audioTrackID, pkt) stream.writeData(audioTrackID, &data{
rtp: pkt,
ptsEqualsDTS: true,
})
} }
} }

View File

@@ -13,7 +13,6 @@ import (
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/asticode/go-astits" "github.com/asticode/go-astits"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -134,8 +133,8 @@ func TestHLSSource(t *testing.T) {
frameRecv := make(chan struct{}) frameRecv := make(chan struct{})
c := gortsplib.Client{ c := gortsplib.Client{
OnPacketRTP: func(trackID int, pkt *rtp.Packet) { OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) {
require.Equal(t, []byte{0x05}, pkt.Payload) require.Equal(t, []byte{0x05}, ctx.Packet.Payload)
close(frameRecv) close(frameRecv)
}, },
} }

View File

@@ -1,13 +1,9 @@
package core package core
import (
"github.com/pion/rtp"
)
// reader is an entity that can read a stream. // reader is an entity that can read a stream.
type reader interface { type reader interface {
close() close()
onReaderAccepted() onReaderAccepted()
onReaderPacketRTP(int, *rtp.Packet) onReaderData(int, *data)
onReaderAPIDescribe() interface{} onReaderAPIDescribe() interface{}
} }

View File

@@ -16,7 +16,6 @@ import (
"github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av" "github.com/notedit/rtmp/av"
"github.com/pion/rtp"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd" "github.com/aler9/rtsp-simple-server/internal/externalcmd"
@@ -44,9 +43,9 @@ const (
rtmpConnStatePublish rtmpConnStatePublish
) )
type rtmpConnTrackIDPayloadPair struct { type rtmpConnTrackIDDataPair struct {
trackID int trackID int
packet *rtp.Packet data *data
} }
type rtmpConnPathManager interface { type rtmpConnPathManager interface {
@@ -260,7 +259,6 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
var videoTrack *gortsplib.TrackH264 var videoTrack *gortsplib.TrackH264
videoTrackID := -1 videoTrackID := -1
var h264Decoder *rtph264.Decoder
var audioTrack *gortsplib.TrackAAC var audioTrack *gortsplib.TrackAAC
audioTrackID := -1 audioTrackID := -1
var aacDecoder *rtpaac.Decoder var aacDecoder *rtpaac.Decoder
@@ -274,8 +272,6 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
videoTrack = tt videoTrack = tt
videoTrackID = i videoTrackID = i
h264Decoder = &rtph264.Decoder{}
h264Decoder.Init()
case *gortsplib.TrackAAC: case *gortsplib.TrackAAC:
if audioTrack != nil { if audioTrack != nil {
@@ -338,20 +334,16 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
if !ok { if !ok {
return fmt.Errorf("terminated") return fmt.Errorf("terminated")
} }
pair := data.(rtmpConnTrackIDPayloadPair) pair := data.(rtmpConnTrackIDDataPair)
if videoTrack != nil && pair.trackID == videoTrackID { if videoTrack != nil && pair.trackID == videoTrackID {
nalus, pts, err := h264Decoder.DecodeUntilMarker(pair.packet) if pair.data.h264NALUs == nil {
if err != nil {
if err != rtph264.ErrMorePacketsNeeded && err != rtph264.ErrNonStartingPacketAndNoPrevious {
c.log(logger.Warn, "unable to decode video track: %v", err)
}
continue continue
} }
var nalusFiltered [][]byte var nalusFiltered [][]byte
for _, nalu := range nalus { for _, nalu := range pair.data.h264NALUs {
// remove SPS, PPS and AUD, not needed by RTMP // remove SPS, PPS and AUD, not needed by RTMP
typ := h264.NALUType(nalu[0] & 0x1F) typ := h264.NALUType(nalu[0] & 0x1F)
switch typ { switch typ {
@@ -362,24 +354,14 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
nalusFiltered = append(nalusFiltered, nalu) nalusFiltered = append(nalusFiltered, nalu)
} }
idrPresent := func() bool {
for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F)
if typ == h264.NALUTypeIDR {
return true
}
}
return false
}()
// wait until we receive an IDR // wait until we receive an IDR
if !videoFirstIDRFound { if !videoFirstIDRFound {
if !idrPresent { if !h264.IDRPresent(nalusFiltered) {
continue continue
} }
videoFirstIDRFound = true videoFirstIDRFound = true
videoStartPTS = pts videoStartPTS = pair.data.h264PTS
videoDTSEst = h264.NewDTSEstimator() videoDTSEst = h264.NewDTSEstimator()
} }
@@ -388,7 +370,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
return err return err
} }
pts -= videoStartPTS pts := pair.data.h264PTS - videoStartPTS
dts := videoDTSEst.Feed(pts) dts := videoDTSEst.Feed(pts)
c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
@@ -402,7 +384,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
return err return err
} }
} else if audioTrack != nil && pair.trackID == audioTrackID { } else if audioTrack != nil && pair.trackID == audioTrackID {
aus, pts, err := aacDecoder.Decode(pair.packet) aus, pts, err := aacDecoder.Decode(pair.data.rtp)
if err != nil { if err != nil {
if err != rtpaac.ErrMorePacketsNeeded { if err != rtpaac.ErrMorePacketsNeeded {
c.log(logger.Warn, "unable to decode audio track: %v", err) c.log(logger.Warn, "unable to decode audio track: %v", err)
@@ -545,13 +527,28 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
continue continue
} }
pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime) pts := pkt.Time + pkt.CTime
pkts, err := h264Encoder.Encode(outNALUs, pts)
if err != nil { if err != nil {
return fmt.Errorf("error while encoding H264: %v", err) return fmt.Errorf("error while encoding H264: %v", err)
} }
for _, pkt := range pkts { lastPkt := len(pkts) - 1
rres.stream.writePacketRTP(videoTrackID, pkt) for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(outNALUs),
h264NALUs: outNALUs,
h264PTS: pts,
})
}
} }
case av.AAC: case av.AAC:
@@ -565,7 +562,10 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
} }
for _, pkt := range pkts { for _, pkt := range pkts {
rres.stream.writePacketRTP(audioTrackID, pkt) rres.stream.writeData(audioTrackID, &data{
rtp: pkt,
ptsEqualsDTS: true,
})
} }
} }
} }
@@ -622,9 +622,9 @@ func (c *rtmpConn) onReaderAccepted() {
c.log(logger.Info, "is reading from path '%s'", c.path.Name()) c.log(logger.Info, "is reading from path '%s'", c.path.Name())
} }
// onReaderPacketRTP implements reader. // onReaderData implements reader.
func (c *rtmpConn) onReaderPacketRTP(trackID int, pkt *rtp.Packet) { func (c *rtmpConn) onReaderData(trackID int, data *data) {
c.ringBuffer.Push(rtmpConnTrackIDPayloadPair{trackID, pkt}) c.ringBuffer.Push(rtmpConnTrackIDDataPair{trackID, data})
} }
// onReaderAPIDescribe implements reader. // onReaderAPIDescribe implements reader.

View File

@@ -165,6 +165,7 @@ func (s *rtmpSource) runInner() bool {
defer func() { defer func() {
s.parent.onSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{source: s}) s.parent.onSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{source: s})
}() }()
for { for {
conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
pkt, err := conn.ReadPacket() pkt, err := conn.ReadPacket()
@@ -195,13 +196,28 @@ func (s *rtmpSource) runInner() bool {
outNALUs = append(outNALUs, nalu) outNALUs = append(outNALUs, nalu)
} }
pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime) pts := pkt.Time + pkt.CTime
pkts, err := h264Encoder.Encode(outNALUs, pts)
if err != nil { if err != nil {
return fmt.Errorf("error while encoding H264: %v", err) return fmt.Errorf("error while encoding H264: %v", err)
} }
for _, pkt := range pkts { lastPkt := len(pkts) - 1
res.stream.writePacketRTP(videoTrackID, pkt) for i, pkt := range pkts {
if i != lastPkt {
res.stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
res.stream.writeData(videoTrackID, &data{
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(outNALUs),
h264NALUs: outNALUs,
h264PTS: pts,
})
}
} }
case av.AAC: case av.AAC:
@@ -215,7 +231,10 @@ func (s *rtmpSource) runInner() bool {
} }
for _, pkt := range pkts { for _, pkt := range pkts {
res.stream.writePacketRTP(audioTrackID, pkt) res.stream.writeData(audioTrackID, &data{
rtp: pkt,
ptsEqualsDTS: true,
})
} }
} }
} }

View File

@@ -459,11 +459,11 @@ func TestRTSPServerPublisherOverride(t *testing.T) {
frameRecv := make(chan struct{}) frameRecv := make(chan struct{})
c := gortsplib.Client{ c := gortsplib.Client{
OnPacketRTP: func(trackID int, pkt *rtp.Packet) { OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) {
if ca == "enabled" { if ca == "enabled" {
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, pkt.Payload) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Packet.Payload)
} else { } else {
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, pkt.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Packet.Payload)
} }
close(frameRecv) close(frameRecv)
}, },
@@ -483,7 +483,7 @@ func TestRTSPServerPublisherOverride(t *testing.T) {
Marker: true, Marker: true,
}, },
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{0x01, 0x02, 0x03, 0x04},
}) }, true)
if ca == "enabled" { if ca == "enabled" {
require.Error(t, err) require.Error(t, err)
} else { } else {
@@ -501,7 +501,7 @@ func TestRTSPServerPublisherOverride(t *testing.T) {
Marker: true, Marker: true,
}, },
Payload: []byte{0x05, 0x06, 0x07, 0x08}, Payload: []byte{0x05, 0x06, 0x07, 0x08},
}) }, true)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/pion/rtp"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd" "github.com/aler9/rtsp-simple-server/internal/externalcmd"
@@ -42,10 +41,9 @@ type rtspSession struct {
path *path path *path
state gortsplib.ServerSessionState state gortsplib.ServerSessionState
stateMutex sync.Mutex stateMutex sync.Mutex
setuppedTracks map[int]gortsplib.Track // read onReadCmd *externalcmd.Cmd // read
onReadCmd *externalcmd.Cmd // read announcedTracks gortsplib.Tracks // publish
announcedTracks gortsplib.Tracks // publish stream *stream // publish
stream *stream // publish
} }
func newRTSPSession( func newRTSPSession(
@@ -231,11 +229,6 @@ func (s *rtspSession) onSetup(c *rtspConn, ctx *gortsplib.ServerHandlerOnSetupCt
}, nil, fmt.Errorf("track %d does not exist", ctx.TrackID) }, nil, fmt.Errorf("track %d does not exist", ctx.TrackID)
} }
if s.setuppedTracks == nil {
s.setuppedTracks = make(map[int]gortsplib.Track)
}
s.setuppedTracks[ctx.TrackID] = res.stream.tracks()[ctx.TrackID]
s.stateMutex.Lock() s.stateMutex.Lock()
s.state = gortsplib.ServerSessionStatePrePlay s.state = gortsplib.ServerSessionStatePrePlay
s.stateMutex.Unlock() s.stateMutex.Unlock()
@@ -348,8 +341,8 @@ func (s *rtspSession) onReaderAccepted() {
s.ss.SetuppedTransport()) s.ss.SetuppedTransport())
} }
// onReaderPacketRTP implements reader. // onReaderData implements reader.
func (s *rtspSession) onReaderPacketRTP(trackID int, pkt *rtp.Packet) { func (s *rtspSession) onReaderData(trackID int, data *data) {
// packets are routed to the session by gortsplib.ServerStream. // packets are routed to the session by gortsplib.ServerStream.
} }
@@ -399,5 +392,17 @@ func (s *rtspSession) onPublisherAccepted(tracksLen int) {
// onPacketRTP is called by rtspServer. // onPacketRTP is called by rtspServer.
func (s *rtspSession) onPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { func (s *rtspSession) onPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
s.stream.writePacketRTP(ctx.TrackID, ctx.Packet) if ctx.H264NALUs != nil {
s.stream.writeData(ctx.TrackID, &data{
rtp: ctx.Packet,
ptsEqualsDTS: ctx.PTSEqualsDTS,
h264NALUs: append([][]byte(nil), ctx.H264NALUs...),
h264PTS: ctx.H264PTS,
})
} else {
s.stream.writeData(ctx.TrackID, &data{
rtp: ctx.Packet,
ptsEqualsDTS: ctx.PTSEqualsDTS,
})
}
} }

View File

@@ -12,9 +12,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/pion/rtp"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
@@ -186,11 +183,6 @@ func (s *rtspSource) runInner() bool {
} }
} }
err = s.handleMissingH264Params(c, tracks)
if err != nil {
return err
}
res := s.parent.onSourceStaticSetReady(pathSourceStaticSetReadyReq{ res := s.parent.onSourceStaticSetReady(pathSourceStaticSetReadyReq{
source: s, source: s,
tracks: c.Tracks(), tracks: c.Tracks(),
@@ -205,8 +197,20 @@ func (s *rtspSource) runInner() bool {
s.parent.onSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{source: s}) s.parent.onSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{source: s})
}() }()
c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { c.OnPacketRTP = func(ctx *gortsplib.ClientOnPacketRTPCtx) {
res.stream.writePacketRTP(trackID, pkt) if ctx.H264NALUs != nil {
res.stream.writeData(ctx.TrackID, &data{
rtp: ctx.Packet,
ptsEqualsDTS: ctx.PTSEqualsDTS,
h264NALUs: append([][]byte(nil), ctx.H264NALUs...),
h264PTS: ctx.H264PTS,
})
} else {
res.stream.writeData(ctx.TrackID, &data{
rtp: ctx.Packet,
ptsEqualsDTS: ctx.PTSEqualsDTS,
})
}
} }
_, err = c.Play(nil) _, err = c.Play(nil)
@@ -230,131 +234,6 @@ func (s *rtspSource) runInner() bool {
} }
} }
func (s *rtspSource) handleMissingH264Params(c *gortsplib.Client, tracks gortsplib.Tracks) error {
h264Track, h264TrackID := func() (*gortsplib.TrackH264, int) {
for i, t := range tracks {
if th264, ok := t.(*gortsplib.TrackH264); ok {
if th264.SPS() == nil {
return th264, i
}
}
}
return nil, -1
}()
if h264TrackID < 0 {
return nil
}
if h264Track.SPS() != nil && h264Track.PPS() != nil {
return nil
}
s.log(logger.Info, "source has not provided H264 parameters (SPS and PPS)"+
" inside the SDP; extracting them from the stream...")
var streamMutex sync.RWMutex
var stream *stream
decoder := &rtph264.Decoder{}
decoder.Init()
var sps []byte
var pps []byte
paramsReceived := make(chan struct{})
c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) {
streamMutex.RLock()
defer streamMutex.RUnlock()
if stream == nil {
if trackID != h264TrackID {
return
}
select {
case <-paramsReceived:
return
default:
}
nalus, _, err := decoder.Decode(pkt)
if err != nil {
return
}
for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F)
switch typ {
case h264.NALUTypeSPS:
sps = nalu
if sps != nil && pps != nil {
close(paramsReceived)
}
case h264.NALUTypePPS:
pps = nalu
if sps != nil && pps != nil {
close(paramsReceived)
}
}
}
} else {
stream.writePacketRTP(trackID, pkt)
}
}
_, err := c.Play(nil)
if err != nil {
return err
}
readErr := make(chan error)
go func() {
readErr <- c.Wait()
}()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
select {
case err := <-readErr:
return err
case <-timeout.C:
c.Close()
<-readErr
return fmt.Errorf("source did not send H264 parameters in time")
case <-paramsReceived:
s.log(logger.Info, "H264 parameters extracted")
h264Track.SetSPS(sps)
h264Track.SetPPS(pps)
res := s.parent.onSourceStaticSetReady(pathSourceStaticSetReadyReq{
source: s,
tracks: tracks,
})
if res.err != nil {
c.Close()
<-readErr
return res.err
}
func() {
streamMutex.Lock()
defer streamMutex.Unlock()
stream = res.stream
}()
s.log(logger.Info, "ready")
defer func() {
s.parent.onSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{source: s})
}()
return <-readErr
}
}
// onSourceAPIDescribe implements source. // onSourceAPIDescribe implements source.
func (*rtspSource) onSourceAPIDescribe() interface{} { func (*rtspSource) onSourceAPIDescribe() interface{} {
return struct { return struct {

View File

@@ -84,7 +84,7 @@ func TestRTSPSource(t *testing.T) {
Marker: true, Marker: true,
}, },
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{0x01, 0x02, 0x03, 0x04},
}) }, true)
}() }()
return &base.Response{ return &base.Response{
@@ -143,8 +143,8 @@ func TestRTSPSource(t *testing.T) {
received := make(chan struct{}) received := make(chan struct{})
c := gortsplib.Client{ c := gortsplib.Client{
OnPacketRTP: func(trackID int, pkt *rtp.Packet) { OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) {
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, pkt.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Packet.Payload)
close(received) close(received)
}, },
} }
@@ -243,25 +243,25 @@ func TestRTSPSourceMissingH264Params(t *testing.T) {
pkts, err := enc.Encode([][]byte{{5}}, 0) // IDR pkts, err := enc.Encode([][]byte{{5}}, 0) // IDR
require.NoError(t, err) require.NoError(t, err)
stream.WritePacketRTP(0, pkts[0]) stream.WritePacketRTP(0, pkts[0], true)
pkts, err = enc.Encode([][]byte{{7, 1, 2, 3}}, 0) // SPS pkts, err = enc.Encode([][]byte{{7, 1, 2, 3}}, 0) // SPS
require.NoError(t, err) require.NoError(t, err)
stream.WritePacketRTP(0, pkts[0]) stream.WritePacketRTP(0, pkts[0], true)
pkts, err = enc.Encode([][]byte{{8}}, 0) // PPS pkts, err = enc.Encode([][]byte{{8}}, 0) // PPS
require.NoError(t, err) require.NoError(t, err)
stream.WritePacketRTP(0, pkts[0]) stream.WritePacketRTP(0, pkts[0], true)
pkts, err = enc.Encode([][]byte{{5, 1}}, 0) // IDR pkts, err = enc.Encode([][]byte{{5, 1}}, 0) // IDR
require.NoError(t, err) require.NoError(t, err)
stream.WritePacketRTP(0, pkts[0]) stream.WritePacketRTP(0, pkts[0], true)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
pkts, err = enc.Encode([][]byte{{5, 2}}, 0) // IDR pkts, err = enc.Encode([][]byte{{5, 2}}, 0) // IDR
require.NoError(t, err) require.NoError(t, err)
stream.WritePacketRTP(0, pkts[0]) stream.WritePacketRTP(0, pkts[0], true)
}() }()
return &base.Response{ return &base.Response{
@@ -286,17 +286,14 @@ func TestRTSPSourceMissingH264Params(t *testing.T) {
defer p.close() defer p.close()
received := make(chan struct{}) received := make(chan struct{})
decoder := &rtph264.Decoder{}
decoder.Init()
c := gortsplib.Client{ c := gortsplib.Client{
OnPacketRTP: func(trackID int, pkt *rtp.Packet) { OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) {
nalus, _, err := decoder.Decode(pkt) if ctx.H264NALUs == nil {
if err != nil {
return return
} }
require.Equal(t, [][]byte{{0x05, 0x02}}, nalus) require.Equal(t, [][]byte{{0x05, 0x02}}, ctx.H264NALUs)
close(received) close(received)
}, },
} }

View File

@@ -4,7 +4,6 @@ import (
"sync" "sync"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/pion/rtp"
) )
type streamNonRTSPReadersMap struct { type streamNonRTSPReadersMap struct {
@@ -36,12 +35,12 @@ func (m *streamNonRTSPReadersMap) remove(r reader) {
delete(m.ma, r) delete(m.ma, r)
} }
func (m *streamNonRTSPReadersMap) forwardPacketRTP(trackID int, pkt *rtp.Packet) { func (m *streamNonRTSPReadersMap) forwardPacketRTP(trackID int, data *data) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
for c := range m.ma { for c := range m.ma {
c.onReaderPacketRTP(trackID, pkt) c.onReaderData(trackID, data)
} }
} }
@@ -79,10 +78,10 @@ func (s *stream) readerRemove(r reader) {
} }
} }
func (s *stream) writePacketRTP(trackID int, pkt *rtp.Packet) { func (s *stream) writeData(trackID int, data *data) {
// forward to RTSP readers // forward to RTSP readers
s.rtspStream.WritePacketRTP(trackID, pkt) s.rtspStream.WritePacketRTP(trackID, data.rtp, data.ptsEqualsDTS)
// forward to non-RTSP readers // forward to non-RTSP readers
s.nonRTSPReaders.forwardPacketRTP(trackID, pkt) s.nonRTSPReaders.forwardPacketRTP(trackID, data)
} }

View File

@@ -14,16 +14,6 @@ const (
segmentMinAUCount = 100 segmentMinAUCount = 100
) )
func idrPresent(nalus [][]byte) bool {
for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F)
if typ == h264.NALUTypeIDR {
return true
}
}
return false
}
type writerFunc func(p []byte) (int, error) type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) { func (f writerFunc) Write(p []byte) (int, error) {
@@ -93,7 +83,7 @@ func newMuxerTSGenerator(
func (m *muxerTSGenerator) writeH264(pts time.Duration, nalus [][]byte) error { func (m *muxerTSGenerator) writeH264(pts time.Duration, nalus [][]byte) error {
now := time.Now() now := time.Now()
idrPresent := idrPresent(nalus) idrPresent := h264.IDRPresent(nalus)
if m.currentSegment == nil { if m.currentSegment == nil {
// skip groups silently until we find one with a IDR // skip groups silently until we find one with a IDR