hls client: move RTP packet generation outside client

This commit is contained in:
aler9
2022-02-19 12:25:23 +01:00
parent 4b0d33e309
commit fe32022edf
5 changed files with 97 additions and 102 deletions

View File

@@ -6,7 +6,8 @@ import (
"time" "time"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/aler9/rtsp-simple-server/internal/hls" "github.com/aler9/rtsp-simple-server/internal/hls"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
@@ -92,6 +93,8 @@ func (s *hlsSource) runInner() bool {
var rtcpSenders *rtcpsenderset.RTCPSenderSet var rtcpSenders *rtcpsenderset.RTCPSenderSet
var videoTrackID int var videoTrackID int
var audioTrackID int var audioTrackID int
var videoEnc *rtph264.Encoder
var audioEnc *rtpaac.Encoder
defer func() { defer func() {
if stream != nil { if stream != nil {
@@ -105,11 +108,13 @@ func (s *hlsSource) runInner() bool {
if videoTrack != nil { if videoTrack != nil {
videoTrackID = len(tracks) videoTrackID = len(tracks)
videoEnc = rtph264.NewEncoder(96, nil, nil, nil)
tracks = append(tracks, videoTrack) tracks = append(tracks, videoTrack)
} }
if audioTrack != nil { if audioTrack != nil {
audioTrackID = len(tracks) audioTrackID = len(tracks)
audioEnc = rtpaac.NewEncoder(97, audioTrack.ClockRate(), nil, nil, nil)
tracks = append(tracks, audioTrack) tracks = append(tracks, audioTrack)
} }
@@ -129,17 +134,35 @@ func (s *hlsSource) runInner() bool {
return nil return nil
} }
onPacket := func(isVideo bool, pkt *rtp.Packet) { onVideoData := func(pts time.Duration, nalus [][]byte) {
var trackID int if stream == nil {
if isVideo { return
trackID = videoTrackID
} else {
trackID = audioTrackID
} }
if stream != nil { pkts, err := videoEnc.Encode(nalus, pts)
rtcpSenders.OnPacketRTP(trackID, pkt) if err != nil {
stream.onPacketRTP(trackID, pkt) return
}
for _, pkt := range pkts {
rtcpSenders.OnPacketRTP(videoTrackID, pkt)
stream.onPacketRTP(videoTrackID, pkt)
}
}
onAudioData := func(pts time.Duration, aus [][]byte) {
if stream == nil {
return
}
pkts, err := audioEnc.Encode(aus, pts)
if err != nil {
return
}
for _, pkt := range pkts {
rtcpSenders.OnPacketRTP(audioTrackID, pkt)
stream.onPacketRTP(audioTrackID, pkt)
} }
} }
@@ -147,7 +170,8 @@ func (s *hlsSource) runInner() bool {
s.ur, s.ur,
s.fingerprint, s.fingerprint,
onTracks, onTracks,
onPacket, onVideoData,
onAudioData,
s, s,
) )
if err != nil { if err != nil {

View File

@@ -18,7 +18,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/asticode/go-astits" "github.com/asticode/go-astits"
"github.com/grafov/m3u8" "github.com/grafov/m3u8"
"github.com/pion/rtp"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
) )
@@ -58,9 +57,10 @@ type ClientParent interface {
// Client is a HLS client. // Client is a HLS client.
type Client struct { type Client struct {
onTracks func(gortsplib.Track, gortsplib.Track) error onTracks func(gortsplib.Track, gortsplib.Track) error
onPacket func(bool, *rtp.Packet) onVideoData func(time.Duration, [][]byte)
parent ClientParent onAudioData func(time.Duration, [][]byte)
parent ClientParent
ctx context.Context ctx context.Context
ctxCancel func() ctxCancel func()
@@ -95,7 +95,8 @@ func NewClient(
primaryPlaylistURLStr string, primaryPlaylistURLStr string,
fingerprint string, fingerprint string,
onTracks func(gortsplib.Track, gortsplib.Track) error, onTracks func(gortsplib.Track, gortsplib.Track) error,
onPacket func(bool, *rtp.Packet), onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, [][]byte),
parent ClientParent, parent ClientParent,
) (*Client, error) { ) (*Client, error) {
primaryPlaylistURL, err := url.Parse(primaryPlaylistURLStr) primaryPlaylistURL, err := url.Parse(primaryPlaylistURLStr)
@@ -128,7 +129,8 @@ func NewClient(
c := &Client{ c := &Client{
onTracks: onTracks, onTracks: onTracks,
onPacket: onPacket, onVideoData: onVideoData,
onAudioData: onAudioData,
parent: parent, parent: parent,
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
@@ -181,7 +183,11 @@ func (c *Client) runInner() error {
c.videoProc = newClientVideoProcessor( c.videoProc = newClientVideoProcessor(
innerCtx, innerCtx,
c.onVideoTrack, c.onVideoTrack,
c.onVideoPacket) func(pts time.Duration, nalus [][]byte) {
c.tracksMutex.RLock()
defer c.tracksMutex.RUnlock()
c.onVideoData(pts, nalus)
})
go func() { errChan <- c.videoProc.run() }() go func() { errChan <- c.videoProc.run() }()
} }
@@ -190,7 +196,11 @@ func (c *Client) runInner() error {
c.audioProc = newClientAudioProcessor( c.audioProc = newClientAudioProcessor(
innerCtx, innerCtx,
c.onAudioTrack, c.onAudioTrack,
c.onAudioPacket) func(pts time.Duration, aus [][]byte) {
c.tracksMutex.RLock()
defer c.tracksMutex.RUnlock()
c.onAudioData(pts, aus)
})
go func() { errChan <- c.audioProc.run() }() go func() { errChan <- c.audioProc.run() }()
} }
@@ -536,7 +546,7 @@ func (c *Client) onVideoTrack(track gortsplib.Track) error {
c.videoTrack = track c.videoTrack = track
if c.audioPID == nil || c.audioTrack != nil { if c.audioPID == nil || c.audioTrack != nil {
return c.initializeEncoders() return c.onTracks(c.videoTrack, c.audioTrack)
} }
return nil return nil
@@ -549,26 +559,8 @@ func (c *Client) onAudioTrack(track gortsplib.Track) error {
c.audioTrack = track c.audioTrack = track
if c.videoPID == nil || c.videoTrack != nil { if c.videoPID == nil || c.videoTrack != nil {
return c.initializeEncoders() return c.onTracks(c.videoTrack, c.audioTrack)
} }
return nil return nil
} }
func (c *Client) initializeEncoders() error {
return c.onTracks(c.videoTrack, c.audioTrack)
}
func (c *Client) onVideoPacket(pkt *rtp.Packet) {
c.tracksMutex.RLock()
defer c.tracksMutex.RUnlock()
c.onPacket(true, pkt)
}
func (c *Client) onAudioPacket(pkt *rtp.Packet) {
c.tracksMutex.RLock()
defer c.tracksMutex.RUnlock()
c.onPacket(false, pkt)
}

View File

@@ -7,8 +7,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/aac" "github.com/aler9/gortsplib/pkg/aac"
"github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/pion/rtp"
) )
type clientAudioProcessorData struct { type clientAudioProcessorData struct {
@@ -17,25 +15,25 @@ type clientAudioProcessorData struct {
} }
type clientAudioProcessor struct { type clientAudioProcessor struct {
ctx context.Context ctx context.Context
onTrack func(gortsplib.Track) error onTrack func(gortsplib.Track) error
onPacket func(*rtp.Packet) onData func(time.Duration, [][]byte)
queue chan clientAudioProcessorData trackInitialized bool
encoder *rtpaac.Encoder queue chan clientAudioProcessorData
clockStartRTC time.Time clockStartRTC time.Time
} }
func newClientAudioProcessor( func newClientAudioProcessor(
ctx context.Context, ctx context.Context,
onTrack func(gortsplib.Track) error, onTrack func(gortsplib.Track) error,
onPacket func(*rtp.Packet), onData func(time.Duration, [][]byte),
) *clientAudioProcessor { ) *clientAudioProcessor {
p := &clientAudioProcessor{ p := &clientAudioProcessor{
ctx: ctx, ctx: ctx,
onTrack: onTrack, onTrack: onTrack,
onPacket: onPacket, onData: onData,
queue: make(chan clientAudioProcessorData, clientQueueSize), queue: make(chan clientAudioProcessorData, clientQueueSize),
} }
return p return p
@@ -65,9 +63,7 @@ func (p *clientAudioProcessor) doProcess(
} }
aus := make([][]byte, 0, len(adtsPkts)) aus := make([][]byte, 0, len(adtsPkts))
pktPts := pts pktPts := pts
now := time.Now() now := time.Now()
for _, pkt := range adtsPkts { for _, pkt := range adtsPkts {
@@ -81,14 +77,14 @@ func (p *clientAudioProcessor) doProcess(
} }
} }
if p.encoder == nil { if !p.trackInitialized {
p.trackInitialized = true
track, err := gortsplib.NewTrackAAC(97, pkt.Type, pkt.SampleRate, pkt.ChannelCount, nil) track, err := gortsplib.NewTrackAAC(97, pkt.Type, pkt.SampleRate, pkt.ChannelCount, nil)
if err != nil { if err != nil {
return err return err
} }
p.encoder = rtpaac.NewEncoder(97, track.ClockRate(), nil, nil, nil)
err = p.onTrack(track) err = p.onTrack(track)
if err != nil { if err != nil {
return err return err
@@ -99,15 +95,7 @@ func (p *clientAudioProcessor) doProcess(
pktPts += 1000 * time.Second / time.Duration(pkt.SampleRate) pktPts += 1000 * time.Second / time.Duration(pkt.SampleRate)
} }
pkts, err := p.encoder.Encode(aus, pts) p.onData(pts, aus)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
for _, pkt := range pkts {
p.onPacket(pkt)
}
return nil return nil
} }

View File

@@ -10,12 +10,12 @@ import (
"net/http" "net/http"
"os" "os"
"testing" "testing"
"time"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/asticode/go-astits" "github.com/asticode/go-astits"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
@@ -207,11 +207,12 @@ func TestClient(t *testing.T) {
func(gortsplib.Track, gortsplib.Track) error { func(gortsplib.Track, gortsplib.Track) error {
return nil return nil
}, },
func(isVideo bool, pkt *rtp.Packet) { func(pts time.Duration, nalus [][]byte) {
require.Equal(t, true, isVideo) require.Equal(t, [][]byte{{0x05}}, nalus)
require.Equal(t, []byte{0x05}, pkt.Payload)
close(packetRecv) close(packetRecv)
}, },
func(pts time.Duration, aus [][]byte) {
},
testLogger{}, testLogger{},
) )
require.NoError(t, err) require.NoError(t, err)

View File

@@ -7,8 +7,6 @@ import (
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/pion/rtp"
) )
type clientVideoProcessorData struct { type clientVideoProcessorData struct {
@@ -18,27 +16,27 @@ type clientVideoProcessorData struct {
} }
type clientVideoProcessor struct { type clientVideoProcessor struct {
ctx context.Context ctx context.Context
onTrack func(gortsplib.Track) error onTrack func(gortsplib.Track) error
onPacket func(*rtp.Packet) onData func(time.Duration, [][]byte)
queue chan clientVideoProcessorData trackInitialized bool
sps []byte queue chan clientVideoProcessorData
pps []byte sps []byte
encoder *rtph264.Encoder pps []byte
clockStartRTC time.Time clockStartRTC time.Time
} }
func newClientVideoProcessor( func newClientVideoProcessor(
ctx context.Context, ctx context.Context,
onTrack func(gortsplib.Track) error, onTrack func(gortsplib.Track) error,
onPacket func(*rtp.Packet), onData func(time.Duration, [][]byte),
) *clientVideoProcessor { ) *clientVideoProcessor {
p := &clientVideoProcessor{ p := &clientVideoProcessor{
ctx: ctx, ctx: ctx,
onTrack: onTrack, onTrack: onTrack,
onPacket: onPacket, onData: onData,
queue: make(chan clientVideoProcessorData, clientQueueSize), queue: make(chan clientVideoProcessorData, clientQueueSize),
} }
return p return p
@@ -87,8 +85,9 @@ func (p *clientVideoProcessor) doProcess(
if p.sps == nil { if p.sps == nil {
p.sps = append([]byte(nil), nalu...) p.sps = append([]byte(nil), nalu...)
if p.encoder == nil && p.pps != nil { if !p.trackInitialized && p.pps != nil {
err := p.initializeEncoder() p.trackInitialized = true
err := p.initializeTrack()
if err != nil { if err != nil {
return err return err
} }
@@ -102,8 +101,9 @@ func (p *clientVideoProcessor) doProcess(
if p.pps == nil { if p.pps == nil {
p.pps = append([]byte(nil), nalu...) p.pps = append([]byte(nil), nalu...)
if p.encoder == nil && p.sps != nil { if !p.trackInitialized && p.sps != nil {
err := p.initializeEncoder() p.trackInitialized = true
err := p.initializeTrack()
if err != nil { if err != nil {
return err return err
} }
@@ -125,19 +125,11 @@ func (p *clientVideoProcessor) doProcess(
return nil return nil
} }
if p.encoder == nil { if !p.trackInitialized {
return nil return nil
} }
pkts, err := p.encoder.Encode(outNALUs, pts) p.onData(pts, outNALUs)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
for _, pkt := range pkts {
p.onPacket(pkt)
}
return nil return nil
} }
@@ -148,13 +140,11 @@ func (p *clientVideoProcessor) process(
p.queue <- clientVideoProcessorData{data, pts, dts} p.queue <- clientVideoProcessorData{data, pts, dts}
} }
func (p *clientVideoProcessor) initializeEncoder() error { func (p *clientVideoProcessor) initializeTrack() error {
track, err := gortsplib.NewTrackH264(96, p.sps, p.pps, nil) track, err := gortsplib.NewTrackH264(96, p.sps, p.pps, nil)
if err != nil { if err != nil {
return err return err
} }
p.encoder = rtph264.NewEncoder(96, nil, nil, nil)
return p.onTrack(track) return p.onTrack(track)
} }