diff --git a/internal/core/hls_source.go b/internal/core/hls_source.go index 5f98bd65..7a8d33cf 100644 --- a/internal/core/hls_source.go +++ b/internal/core/hls_source.go @@ -6,7 +6,8 @@ import ( "time" "github.com/aler9/gortsplib" - "github.com/pion/rtp" + "github.com/aler9/gortsplib/pkg/rtpaac" + "github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/rtsp-simple-server/internal/hls" "github.com/aler9/rtsp-simple-server/internal/logger" @@ -92,6 +93,8 @@ func (s *hlsSource) runInner() bool { var rtcpSenders *rtcpsenderset.RTCPSenderSet var videoTrackID int var audioTrackID int + var videoEnc *rtph264.Encoder + var audioEnc *rtpaac.Encoder defer func() { if stream != nil { @@ -105,11 +108,13 @@ func (s *hlsSource) runInner() bool { if videoTrack != nil { videoTrackID = len(tracks) + videoEnc = rtph264.NewEncoder(96, nil, nil, nil) tracks = append(tracks, videoTrack) } if audioTrack != nil { audioTrackID = len(tracks) + audioEnc = rtpaac.NewEncoder(97, audioTrack.ClockRate(), nil, nil, nil) tracks = append(tracks, audioTrack) } @@ -129,17 +134,35 @@ func (s *hlsSource) runInner() bool { return nil } - onPacket := func(isVideo bool, pkt *rtp.Packet) { - var trackID int - if isVideo { - trackID = videoTrackID - } else { - trackID = audioTrackID + onVideoData := func(pts time.Duration, nalus [][]byte) { + if stream == nil { + return } - if stream != nil { - rtcpSenders.OnPacketRTP(trackID, pkt) - stream.onPacketRTP(trackID, pkt) + pkts, err := videoEnc.Encode(nalus, pts) + if err != nil { + return + } + + for _, pkt := range pkts { + rtcpSenders.OnPacketRTP(videoTrackID, pkt) + stream.onPacketRTP(videoTrackID, pkt) + } + } + + onAudioData := func(pts time.Duration, aus [][]byte) { + if stream == nil { + return + } + + pkts, err := audioEnc.Encode(aus, pts) + if err != nil { + return + } + + for _, pkt := range pkts { + rtcpSenders.OnPacketRTP(audioTrackID, pkt) + stream.onPacketRTP(audioTrackID, pkt) } } @@ -147,7 +170,8 @@ func (s *hlsSource) runInner() bool { s.ur, s.fingerprint, onTracks, - onPacket, + onVideoData, + onAudioData, s, ) if err != nil { diff --git a/internal/hls/client.go b/internal/hls/client.go index 25de2cf4..2156e7e5 100644 --- a/internal/hls/client.go +++ b/internal/hls/client.go @@ -18,7 +18,6 @@ import ( "github.com/aler9/gortsplib" "github.com/asticode/go-astits" "github.com/grafov/m3u8" - "github.com/pion/rtp" "github.com/aler9/rtsp-simple-server/internal/logger" ) @@ -58,9 +57,10 @@ type ClientParent interface { // Client is a HLS client. type Client struct { - onTracks func(gortsplib.Track, gortsplib.Track) error - onPacket func(bool, *rtp.Packet) - parent ClientParent + onTracks func(gortsplib.Track, gortsplib.Track) error + onVideoData func(time.Duration, [][]byte) + onAudioData func(time.Duration, [][]byte) + parent ClientParent ctx context.Context ctxCancel func() @@ -95,7 +95,8 @@ func NewClient( primaryPlaylistURLStr string, fingerprint string, onTracks func(gortsplib.Track, gortsplib.Track) error, - onPacket func(bool, *rtp.Packet), + onVideoData func(time.Duration, [][]byte), + onAudioData func(time.Duration, [][]byte), parent ClientParent, ) (*Client, error) { primaryPlaylistURL, err := url.Parse(primaryPlaylistURLStr) @@ -128,7 +129,8 @@ func NewClient( c := &Client{ onTracks: onTracks, - onPacket: onPacket, + onVideoData: onVideoData, + onAudioData: onAudioData, parent: parent, ctx: ctx, ctxCancel: ctxCancel, @@ -181,7 +183,11 @@ func (c *Client) runInner() error { c.videoProc = newClientVideoProcessor( innerCtx, c.onVideoTrack, - c.onVideoPacket) + func(pts time.Duration, nalus [][]byte) { + c.tracksMutex.RLock() + defer c.tracksMutex.RUnlock() + c.onVideoData(pts, nalus) + }) go func() { errChan <- c.videoProc.run() }() } @@ -190,7 +196,11 @@ func (c *Client) runInner() error { c.audioProc = newClientAudioProcessor( innerCtx, c.onAudioTrack, - c.onAudioPacket) + func(pts time.Duration, aus [][]byte) { + c.tracksMutex.RLock() + defer c.tracksMutex.RUnlock() + c.onAudioData(pts, aus) + }) go func() { errChan <- c.audioProc.run() }() } @@ -536,7 +546,7 @@ func (c *Client) onVideoTrack(track gortsplib.Track) error { c.videoTrack = track if c.audioPID == nil || c.audioTrack != nil { - return c.initializeEncoders() + return c.onTracks(c.videoTrack, c.audioTrack) } return nil @@ -549,26 +559,8 @@ func (c *Client) onAudioTrack(track gortsplib.Track) error { c.audioTrack = track if c.videoPID == nil || c.videoTrack != nil { - return c.initializeEncoders() + return c.onTracks(c.videoTrack, c.audioTrack) } return nil } - -func (c *Client) initializeEncoders() error { - return c.onTracks(c.videoTrack, c.audioTrack) -} - -func (c *Client) onVideoPacket(pkt *rtp.Packet) { - c.tracksMutex.RLock() - defer c.tracksMutex.RUnlock() - - c.onPacket(true, pkt) -} - -func (c *Client) onAudioPacket(pkt *rtp.Packet) { - c.tracksMutex.RLock() - defer c.tracksMutex.RUnlock() - - c.onPacket(false, pkt) -} diff --git a/internal/hls/client_audio_processor.go b/internal/hls/client_audio_processor.go index 28dea1f3..eac6f860 100644 --- a/internal/hls/client_audio_processor.go +++ b/internal/hls/client_audio_processor.go @@ -7,8 +7,6 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/aac" - "github.com/aler9/gortsplib/pkg/rtpaac" - "github.com/pion/rtp" ) type clientAudioProcessorData struct { @@ -17,25 +15,25 @@ type clientAudioProcessorData struct { } type clientAudioProcessor struct { - ctx context.Context - onTrack func(gortsplib.Track) error - onPacket func(*rtp.Packet) + ctx context.Context + onTrack func(gortsplib.Track) error + onData func(time.Duration, [][]byte) - queue chan clientAudioProcessorData - encoder *rtpaac.Encoder - clockStartRTC time.Time + trackInitialized bool + queue chan clientAudioProcessorData + clockStartRTC time.Time } func newClientAudioProcessor( ctx context.Context, onTrack func(gortsplib.Track) error, - onPacket func(*rtp.Packet), + onData func(time.Duration, [][]byte), ) *clientAudioProcessor { p := &clientAudioProcessor{ - ctx: ctx, - onTrack: onTrack, - onPacket: onPacket, - queue: make(chan clientAudioProcessorData, clientQueueSize), + ctx: ctx, + onTrack: onTrack, + onData: onData, + queue: make(chan clientAudioProcessorData, clientQueueSize), } return p @@ -65,9 +63,7 @@ func (p *clientAudioProcessor) doProcess( } aus := make([][]byte, 0, len(adtsPkts)) - pktPts := pts - now := time.Now() for _, pkt := range adtsPkts { @@ -81,14 +77,14 @@ func (p *clientAudioProcessor) doProcess( } } - if p.encoder == nil { + if !p.trackInitialized { + p.trackInitialized = true + track, err := gortsplib.NewTrackAAC(97, pkt.Type, pkt.SampleRate, pkt.ChannelCount, nil) if err != nil { return err } - p.encoder = rtpaac.NewEncoder(97, track.ClockRate(), nil, nil, nil) - err = p.onTrack(track) if err != nil { return err @@ -99,15 +95,7 @@ func (p *clientAudioProcessor) doProcess( pktPts += 1000 * time.Second / time.Duration(pkt.SampleRate) } - pkts, err := p.encoder.Encode(aus, pts) - if err != nil { - return fmt.Errorf("error while encoding AAC: %v", err) - } - - for _, pkt := range pkts { - p.onPacket(pkt) - } - + p.onData(pts, aus) return nil } diff --git a/internal/hls/client_test.go b/internal/hls/client_test.go index d22659c3..839af537 100644 --- a/internal/hls/client_test.go +++ b/internal/hls/client_test.go @@ -10,12 +10,12 @@ import ( "net/http" "os" "testing" + "time" "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/h264" "github.com/asticode/go-astits" "github.com/gin-gonic/gin" - "github.com/pion/rtp" "github.com/stretchr/testify/require" "github.com/aler9/rtsp-simple-server/internal/logger" @@ -207,11 +207,12 @@ func TestClient(t *testing.T) { func(gortsplib.Track, gortsplib.Track) error { return nil }, - func(isVideo bool, pkt *rtp.Packet) { - require.Equal(t, true, isVideo) - require.Equal(t, []byte{0x05}, pkt.Payload) + func(pts time.Duration, nalus [][]byte) { + require.Equal(t, [][]byte{{0x05}}, nalus) close(packetRecv) }, + func(pts time.Duration, aus [][]byte) { + }, testLogger{}, ) require.NoError(t, err) diff --git a/internal/hls/client_video_processor.go b/internal/hls/client_video_processor.go index 360dc2f0..2b3f9ad8 100644 --- a/internal/hls/client_video_processor.go +++ b/internal/hls/client_video_processor.go @@ -7,8 +7,6 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/h264" - "github.com/aler9/gortsplib/pkg/rtph264" - "github.com/pion/rtp" ) type clientVideoProcessorData struct { @@ -18,27 +16,27 @@ type clientVideoProcessorData struct { } type clientVideoProcessor struct { - ctx context.Context - onTrack func(gortsplib.Track) error - onPacket func(*rtp.Packet) + ctx context.Context + onTrack func(gortsplib.Track) error + onData func(time.Duration, [][]byte) - queue chan clientVideoProcessorData - sps []byte - pps []byte - encoder *rtph264.Encoder - clockStartRTC time.Time + trackInitialized bool + queue chan clientVideoProcessorData + sps []byte + pps []byte + clockStartRTC time.Time } func newClientVideoProcessor( ctx context.Context, onTrack func(gortsplib.Track) error, - onPacket func(*rtp.Packet), + onData func(time.Duration, [][]byte), ) *clientVideoProcessor { p := &clientVideoProcessor{ - ctx: ctx, - onTrack: onTrack, - onPacket: onPacket, - queue: make(chan clientVideoProcessorData, clientQueueSize), + ctx: ctx, + onTrack: onTrack, + onData: onData, + queue: make(chan clientVideoProcessorData, clientQueueSize), } return p @@ -87,8 +85,9 @@ func (p *clientVideoProcessor) doProcess( if p.sps == nil { p.sps = append([]byte(nil), nalu...) - if p.encoder == nil && p.pps != nil { - err := p.initializeEncoder() + if !p.trackInitialized && p.pps != nil { + p.trackInitialized = true + err := p.initializeTrack() if err != nil { return err } @@ -102,8 +101,9 @@ func (p *clientVideoProcessor) doProcess( if p.pps == nil { p.pps = append([]byte(nil), nalu...) - if p.encoder == nil && p.sps != nil { - err := p.initializeEncoder() + if !p.trackInitialized && p.sps != nil { + p.trackInitialized = true + err := p.initializeTrack() if err != nil { return err } @@ -125,19 +125,11 @@ func (p *clientVideoProcessor) doProcess( return nil } - if p.encoder == nil { + if !p.trackInitialized { return nil } - pkts, err := p.encoder.Encode(outNALUs, pts) - if err != nil { - return fmt.Errorf("error while encoding H264: %v", err) - } - - for _, pkt := range pkts { - p.onPacket(pkt) - } - + p.onData(pts, outNALUs) return nil } @@ -148,13 +140,11 @@ func (p *clientVideoProcessor) process( p.queue <- clientVideoProcessorData{data, pts, dts} } -func (p *clientVideoProcessor) initializeEncoder() error { +func (p *clientVideoProcessor) initializeTrack() error { track, err := gortsplib.NewTrackH264(96, p.sps, p.pps, nil) if err != nil { return err } - p.encoder = rtph264.NewEncoder(96, nil, nil, nil) - return p.onTrack(track) }