Files
rtsp-simple-server/internal/protocols/rtmp/reader.go
Alessandro Ros c21c969a8c revert #4583 (#4606)
This reverts commit 500d18b6c6.
2025-06-03 20:27:53 +02:00

804 lines
19 KiB
Go

package rtmp
import (
"errors"
"fmt"
"sort"
"time"
"github.com/abema/go-mp4"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/ac3"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/av1"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/h264"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/h265"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio"
"github.com/bluenviron/mediamtx/internal/protocols/rtmp/h264conf"
"github.com/bluenviron/mediamtx/internal/protocols/rtmp/message"
)
const (
	// analyzePeriod is the amount of stream time that readTracks() observes
	// in order to detect tracks before Initialize() returns.
	analyzePeriod = 2 * time.Second
)
// OnDataAV1Func is the prototype of the callback passed to OnDataAV1().
// pts is the presentation timestamp; tu is a decoded AV1 temporal unit.
type OnDataAV1Func func(pts time.Duration, tu [][]byte)

// OnDataVP9Func is the prototype of the callback passed to OnDataVP9().
// frame is a raw VP9 frame.
type OnDataVP9Func func(pts time.Duration, frame []byte)

// OnDataH26xFunc is the prototype of the callback passed to OnDataH26x().
// au is an access unit, i.e. a list of NAL units.
type OnDataH26xFunc func(pts time.Duration, au [][]byte)

// OnDataOpusFunc is the prototype of the callback passed to OnDataOpus().
// packet is a single Opus packet.
type OnDataOpusFunc func(pts time.Duration, packet []byte)

// OnDataMPEG4AudioFunc is the prototype of the callback passed to OnDataMPEG4Audio().
// au is a single MPEG-4 audio access unit.
type OnDataMPEG4AudioFunc func(pts time.Duration, au []byte)

// OnDataMPEG1AudioFunc is the prototype of the callback passed to OnDataMPEG1Audio().
// frame is a single MPEG-1 audio frame.
type OnDataMPEG1AudioFunc func(pts time.Duration, frame []byte)

// OnDataAC3Func is the prototype of the callback passed to OnDataAC3().
// frame is a single AC-3 frame.
type OnDataAC3Func func(pts time.Duration, frame []byte)

// OnDataG711Func is the prototype of the callback passed to OnDataG711().
// samples are raw G711 samples.
type OnDataG711Func func(pts time.Duration, samples []byte)

// OnDataLPCMFunc is the prototype of the callback passed to OnDataLPCM().
// samples are LPCM samples; 16-bit samples are delivered in big-endian order.
type OnDataLPCMFunc func(pts time.Duration, samples []byte)
// h265FindNALU scans a HEVC NALU array for the single NAL unit of the
// requested type and returns its payload, or nil when it is absent.
func h265FindNALU(array []mp4.HEVCNaluArray, typ h265.NALUType) []byte {
	for _, arr := range array {
		if arr.NaluType != byte(typ) || arr.NumNalus != 1 {
			continue
		}
		nalu := arr.Nalus[0].NALUnit
		// cross-check the type encoded in the NAL unit header itself.
		if h265.NALUType((nalu[0]>>1)&0b111111) != typ {
			continue
		}
		return nalu
	}
	return nil
}
// h264TrackFromConfig builds a H264 track from the payload of a legacy
// RTMP video config message.
func h264TrackFromConfig(data []byte) (*format.H264, error) {
	var conf h264conf.Conf
	if err := conf.Unmarshal(data); err != nil {
		return nil, fmt.Errorf("unable to parse H264 config: %w", err)
	}

	track := &format.H264{
		PayloadTyp:        96,
		SPS:               conf.SPS,
		PPS:               conf.PPS,
		PacketizationMode: 1,
	}
	return track, nil
}
// mpeg4AudioTrackFromConfig builds a MPEG-4 Audio track from the
// configuration payload of a legacy RTMP audio message.
func mpeg4AudioTrackFromConfig(data []byte) (*format.MPEG4Audio, error) {
	var audioConf mpeg4audio.Config
	if err := audioConf.Unmarshal(data); err != nil {
		return nil, err
	}

	track := &format.MPEG4Audio{
		PayloadTyp:       96,
		Config:           &audioConf,
		SizeLength:       13,
		IndexLength:      3,
		IndexDeltaLength: 3,
	}
	return track, nil
}
// audioTrackFromData builds an audio track from the first legacy audio
// message of the stream. It panics on codecs that are created through
// other paths (extended messages or AAC config messages).
func audioTrackFromData(msg *message.Audio) (format.Format, error) {
	channels := 1
	if msg.IsStereo {
		channels = 2
	}

	switch msg.Codec {
	case message.CodecMPEG1Audio:
		return &format.MPEG1Audio{}, nil

	case message.CodecPCMA:
		return &format.G711{
			PayloadTyp:   8,
			MULaw:        false,
			SampleRate:   8000,
			ChannelCount: channels,
		}, nil

	case message.CodecPCMU:
		return &format.G711{
			PayloadTyp:   0,
			MULaw:        true,
			SampleRate:   8000,
			ChannelCount: channels,
		}, nil

	case message.CodecLPCM:
		depth := 8
		if msg.Depth == message.Depth16 {
			depth = 16
		}
		return &format.LPCM{
			PayloadTyp:   96,
			BitDepth:     depth,
			SampleRate:   audioRateRTMPToInt(msg.Rate),
			ChannelCount: channels,
		}, nil

	default:
		panic("should not happen")
	}
}
// videoTrackFromSequenceStart builds a video track from an extended
// (enhanced RTMP) sequence-start message, dispatching on its FourCC.
// It panics on any other FourCC; presumably the message layer only
// produces these four values — TODO confirm.
func videoTrackFromSequenceStart(msg *message.VideoExSequenceStart) (format.Format, error) {
	switch msg.FourCC {
	case message.FourCCAV1:
		// parse sequence header and metadata contained in ConfigOBUs, but do not use them
		var tu av1.Bitstream
		err := tu.Unmarshal(msg.AV1Header.ConfigOBUs)
		if err != nil {
			return nil, fmt.Errorf("invalid AV1 configuration: %w", err)
		}
		return &format.AV1{
			PayloadTyp: 96,
		}, nil

	case message.FourCCVP9:
		return &format.VP9{
			PayloadTyp: 96,
		}, nil

	case message.FourCCHEVC:
		// extract VPS, SPS and PPS from the NALU arrays of the header;
		// all three are required to build the track.
		vps := h265FindNALU(msg.HEVCHeader.NaluArrays, h265.NALUType_VPS_NUT)
		sps := h265FindNALU(msg.HEVCHeader.NaluArrays, h265.NALUType_SPS_NUT)
		pps := h265FindNALU(msg.HEVCHeader.NaluArrays, h265.NALUType_PPS_NUT)
		if vps == nil || sps == nil || pps == nil {
			return nil, fmt.Errorf("H265 parameters are missing")
		}
		return &format.H265{
			PayloadTyp: 96,
			VPS:        vps,
			SPS:        sps,
			PPS:        pps,
		}, nil

	case message.FourCCAVC:
		// exactly one SPS and one PPS are expected in the header.
		if len(msg.AVCHeader.SequenceParameterSets) != 1 || len(msg.AVCHeader.PictureParameterSets) != 1 {
			return nil, fmt.Errorf("H264 parameters are missing")
		}
		return &format.H264{
			PayloadTyp:        96,
			SPS:               msg.AVCHeader.SequenceParameterSets[0].NALUnit,
			PPS:               msg.AVCHeader.PictureParameterSets[0].NALUnit,
			PacketizationMode: 1,
		}, nil

	default:
		panic("should not happen")
	}
}
// audioTrackFromExtendedMessages builds an audio track from an extended
// sequence-start message combined with the first coded-frames message.
// sequenceStart may be nil only when the codec is MP3, which requires
// no out-of-band configuration.
func audioTrackFromExtendedMessages(
	sequenceStart *message.AudioExSequenceStart,
	frames *message.AudioExCodedFrames,
) (format.Format, error) {
	// every codec except MP3 needs a matching sequence start first.
	if frames.FourCC != message.FourCCMP3 {
		if sequenceStart == nil {
			return nil, fmt.Errorf("sequence start not received")
		}
		if sequenceStart.FourCC != frames.FourCC {
			return nil, fmt.Errorf("AudioExSequenceStart FourCC and AudioExCodedFrames are different")
		}
	}

	switch frames.FourCC {
	case message.FourCCOpus:
		if len(frames.Payload) < 1 {
			return nil, fmt.Errorf("invalid Opus frame")
		}
		return &format.Opus{
			PayloadTyp:   96,
			ChannelCount: int(sequenceStart.OpusHeader.ChannelCount),
		}, nil

	case message.FourCCAC3:
		// the BSI is parsed starting at byte 5 — presumably the sync info
		// occupies the first 5 bytes; confirm against the AC-3 spec.
		if len(frames.Payload) < 6 {
			return nil, fmt.Errorf("invalid AC-3 frame")
		}
		var syncInfo ac3.SyncInfo
		err := syncInfo.Unmarshal(frames.Payload)
		if err != nil {
			return nil, fmt.Errorf("invalid AC-3 frame: %w", err)
		}
		var bsi ac3.BSI
		err = bsi.Unmarshal(frames.Payload[5:])
		if err != nil {
			return nil, fmt.Errorf("invalid AC-3 frame: %w", err)
		}
		return &format.AC3{
			PayloadTyp:   96,
			SampleRate:   syncInfo.SampleRate(),
			ChannelCount: bsi.ChannelCount(),
		}, nil

	case message.FourCCMP4A:
		return &format.MPEG4Audio{
			PayloadTyp:       96,
			Config:           sequenceStart.AACHeader,
			SizeLength:       13,
			IndexLength:      3,
			IndexDeltaLength: 3,
		}, nil

	case message.FourCCMP3:
		return &format.MPEG1Audio{}, nil

	default:
		panic("should not happen")
	}
}
// sortedKeys returns the keys of m in ascending order, so that callers
// can iterate the map deterministically.
func sortedKeys(m map[uint8]format.Format) []int {
	keys := make([]int, 0, len(m))
	for k := range m {
		keys = append(keys, int(k))
	}
	sort.Ints(keys)
	return keys
}
// Reader provides functions to read incoming data.
type Reader struct {
	// Conn is the connection messages are read from.
	// It must be set before calling Initialize.
	Conn Conn

	// tracks detected during the analysis phase, keyed by RTMP track ID.
	videoTracks map[uint8]format.Format
	audioTracks map[uint8]format.Format

	// per-track data callbacks registered through the OnData* methods.
	onVideoData map[uint8]func(message.Message) error
	onAudioData map[uint8]func(message.Message) error
}
// Initialize initializes Reader.
// It analyzes the beginning of the stream in order to detect tracks,
// then prepares the callback tables used by the OnData* methods.
func (r *Reader) Initialize() error {
	videoTracks, audioTracks, err := r.readTracks()
	r.videoTracks = videoTracks
	r.audioTracks = audioTracks
	if err != nil {
		return err
	}

	r.onVideoData = make(map[uint8]func(message.Message) error)
	r.onAudioData = make(map[uint8]func(message.Message) error)
	return nil
}
// readTracks reads messages until analyzePeriod of stream time has elapsed
// and returns the video and audio tracks detected in that window, keyed by
// RTMP track ID. Legacy (non-extended) messages are always mapped to track 0.
func (r *Reader) readTracks() (map[uint8]format.Format, map[uint8]format.Format, error) {
	// analysis clock, driven by the DTS of incoming messages.
	firstReceived := false
	var startTime time.Duration
	var curTime time.Duration
	videoTracks := make(map[uint8]format.Format)
	audioTracks := make(map[uint8]format.Format)

	// creates the video track of trackID from an extended sequence-start
	// message; a second sequence start for the same track is an error.
	handleVideoSequenceStart := func(trackID uint8, msg *message.VideoExSequenceStart) error {
		if videoTracks[trackID] != nil {
			return fmt.Errorf("video track %d already setupped", trackID)
		}
		var err error
		videoTracks[trackID], err = videoTrackFromSequenceStart(msg)
		if err != nil {
			return err
		}
		return nil
	}

	// coded-frame messages only advance the analysis clock here; their
	// payloads are handled after Initialize, through the OnData* callbacks.
	handleVideoExCodedFrames := func(_ uint8, msg *message.VideoExCodedFrames) error {
		if !firstReceived {
			firstReceived = true
			startTime = msg.DTS
		}
		curTime = msg.DTS
		return nil
	}
	handleVideoExFramesX := func(_ uint8, msg *message.VideoExFramesX) error {
		if !firstReceived {
			firstReceived = true
			startTime = msg.DTS
		}
		curTime = msg.DTS
		return nil
	}

	// audio sequence starts are stored and combined later with the first
	// coded-frames message (see audioTrackFromExtendedMessages).
	audioSequenceStarts := make(map[uint8]*message.AudioExSequenceStart)
	handleAudioSequenceStart := func(trackID uint8, msg *message.AudioExSequenceStart) error {
		if audioSequenceStarts[trackID] != nil {
			return fmt.Errorf("audio track %d already setupped", trackID)
		}
		audioSequenceStarts[trackID] = msg
		return nil
	}
	handleAudioCodedFrames := func(trackID uint8, msg *message.AudioExCodedFrames) error {
		if !firstReceived {
			firstReceived = true
			startTime = msg.DTS
		}
		curTime = msg.DTS
		// only the first coded-frames message creates the track.
		if audioTracks[trackID] != nil {
			return nil
		}
		var err error
		audioTracks[trackID], err = audioTrackFromExtendedMessages(audioSequenceStarts[trackID], msg)
		if err != nil {
			return err
		}
		return nil
	}

	for {
		msg, err := r.Conn.Read()
		if err != nil {
			return nil, nil, err
		}
		switch msg := msg.(type) {
		case *message.Video:
			if !firstReceived {
				firstReceived = true
				startTime = msg.DTS
			}
			curTime = msg.DTS
			// legacy H264: the track is created from the config message.
			if msg.Type == message.VideoTypeConfig && videoTracks[0] == nil {
				videoTracks[0], err = h264TrackFromConfig(msg.Payload)
				if err != nil {
					return nil, nil, err
				}
			}
		case *message.VideoExSequenceStart:
			err = handleVideoSequenceStart(0, msg)
			if err != nil {
				return nil, nil, err
			}
		case *message.VideoExCodedFrames:
			err = handleVideoExCodedFrames(0, msg)
			if err != nil {
				return nil, nil, err
			}
		case *message.VideoExFramesX:
			err = handleVideoExFramesX(0, msg)
			if err != nil {
				return nil, nil, err
			}
		case *message.VideoExMultitrack:
			// create a placeholder entry for the track ID before the track
			// exists; NOTE(review): a nil entry may remain in the returned
			// map if no sequence start arrives for this ID.
			if _, ok := videoTracks[msg.TrackID]; !ok {
				videoTracks[msg.TrackID] = nil
			}
			switch wmsg := msg.Wrapped.(type) {
			case *message.VideoExSequenceStart:
				err = handleVideoSequenceStart(msg.TrackID, wmsg)
				if err != nil {
					return nil, nil, err
				}
			case *message.VideoExCodedFrames:
				err = handleVideoExCodedFrames(msg.TrackID, wmsg)
				if err != nil {
					return nil, nil, err
				}
			case *message.VideoExFramesX:
				err = handleVideoExFramesX(msg.TrackID, wmsg)
				if err != nil {
					return nil, nil, err
				}
			}
		case *message.Audio:
			if !firstReceived {
				firstReceived = true
				startTime = msg.DTS
			}
			curTime = msg.DTS
			if audioTracks[0] == nil && len(msg.Payload) != 0 {
				if msg.Codec == message.CodecMPEG4Audio {
					// legacy AAC: the track is created from the config message only.
					if msg.AACType == message.AudioAACTypeConfig {
						audioTracks[0], err = mpeg4AudioTrackFromConfig(msg.Payload)
						if err != nil {
							return nil, nil, err
						}
					}
				} else {
					audioTracks[0], err = audioTrackFromData(msg)
					if err != nil {
						return nil, nil, err
					}
				}
			}
		case *message.AudioExSequenceStart:
			err := handleAudioSequenceStart(0, msg)
			if err != nil {
				return nil, nil, err
			}
		case *message.AudioExCodedFrames:
			err := handleAudioCodedFrames(0, msg)
			if err != nil {
				return nil, nil, err
			}
		case *message.AudioExMultitrack:
			// same placeholder logic as for video multitrack messages.
			if _, ok := audioTracks[msg.TrackID]; !ok {
				audioTracks[msg.TrackID] = nil
			}
			switch wmsg := msg.Wrapped.(type) {
			case *message.AudioExSequenceStart:
				err := handleAudioSequenceStart(msg.TrackID, wmsg)
				if err != nil {
					return nil, nil, err
				}
			case *message.AudioExCodedFrames:
				err := handleAudioCodedFrames(msg.TrackID, wmsg)
				if err != nil {
					return nil, nil, err
				}
			}
		}
		// stop once analyzePeriod of stream time has been observed.
		if (curTime - startTime) >= analyzePeriod {
			break
		}
	}
	if len(videoTracks) == 0 && len(audioTracks) == 0 {
		return nil, nil, fmt.Errorf("no tracks found")
	}
	return videoTracks, audioTracks, nil
}
// Tracks returns detected tracks
// (video tracks first, each group sorted by track ID).
func (r *Reader) Tracks() []format.Format {
	tracks := make([]format.Format, 0, len(r.videoTracks)+len(r.audioTracks))
	for _, id := range sortedKeys(r.videoTracks) {
		tracks = append(tracks, r.videoTracks[uint8(id)])
	}
	for _, id := range sortedKeys(r.audioTracks) {
		tracks = append(tracks, r.audioTracks[uint8(id)])
	}
	return tracks
}
// videoTrackID returns the ID of the given video track,
// or 255 when the track is unknown.
func (r *Reader) videoTrackID(t format.Format) uint8 {
	for curID, curTrack := range r.videoTracks {
		if curTrack != t {
			continue
		}
		return curID
	}
	return 255
}
// audioTrackID returns the ID of the given audio track,
// or 255 when the track is unknown.
func (r *Reader) audioTrackID(t format.Format) uint8 {
	for curID, curTrack := range r.audioTracks {
		if curTrack != t {
			continue
		}
		return curID
	}
	return 255
}
// OnDataAV1 sets a callback that is called when AV1 data is received.
func (r *Reader) OnDataAV1(track *format.AV1, cb OnDataAV1Func) {
	// decode unmarshals the payload into a temporal unit and forwards it.
	decode := func(payload []byte, pts time.Duration) error {
		var bs av1.Bitstream
		if err := bs.Unmarshal(payload); err != nil {
			return fmt.Errorf("unable to decode bitstream: %w", err)
		}
		cb(pts, bs)
		return nil
	}

	r.onVideoData[r.videoTrackID(track)] = func(rawMsg message.Message) error {
		switch m := rawMsg.(type) {
		case *message.VideoExFramesX:
			return decode(m.Payload, m.DTS)
		case *message.VideoExCodedFrames:
			return decode(m.Payload, m.DTS+m.PTSDelta)
		}
		return nil
	}
}
// OnDataVP9 sets a callback that is called when VP9 data is received.
func (r *Reader) OnDataVP9(track *format.VP9, cb OnDataVP9Func) {
	r.onVideoData[r.videoTrackID(track)] = func(rawMsg message.Message) error {
		// frames are forwarded as-is; only the timestamp differs per type.
		switch m := rawMsg.(type) {
		case *message.VideoExFramesX:
			cb(m.DTS, m.Payload)
		case *message.VideoExCodedFrames:
			cb(m.DTS+m.PTSDelta, m.Payload)
		}
		return nil
	}
}
// OnDataH265 sets a callback that is called when H265 data is received.
func (r *Reader) OnDataH265(track *format.H265, cb OnDataH26xFunc) {
	// emit decodes an AVCC-formatted payload into an access unit and
	// forwards it; a payload with no NAL units is silently skipped.
	emit := func(payload []byte, pts time.Duration) error {
		var au h264.AVCC
		if err := au.Unmarshal(payload); err != nil {
			if errors.Is(err, h264.ErrAVCCNoNALUs) {
				return nil
			}
			return fmt.Errorf("unable to decode AVCC: %w", err)
		}
		cb(pts, au)
		return nil
	}

	r.onVideoData[r.videoTrackID(track)] = func(rawMsg message.Message) error {
		switch m := rawMsg.(type) {
		case *message.VideoExFramesX:
			return emit(m.Payload, m.DTS)
		case *message.VideoExCodedFrames:
			return emit(m.Payload, m.DTS+m.PTSDelta)
		}
		return nil
	}
}
// OnDataH264 sets a callback that is called when H264 data is received.
func (r *Reader) OnDataH264(track *format.H264, cb OnDataH26xFunc) {
	// emitAU decodes an AVCC-formatted payload into an access unit and
	// forwards it; a payload with no NAL units is silently skipped.
	emitAU := func(payload []byte, pts time.Duration) error {
		var au h264.AVCC
		if err := au.Unmarshal(payload); err != nil {
			if errors.Is(err, h264.ErrAVCCNoNALUs) {
				return nil
			}
			return fmt.Errorf("unable to decode AVCC: %w", err)
		}
		cb(pts, au)
		return nil
	}

	r.onVideoData[r.videoTrackID(track)] = func(rawMsg message.Message) error {
		switch m := rawMsg.(type) {
		case *message.Video:
			switch m.Type {
			case message.VideoTypeConfig:
				// legacy config message: forward SPS and PPS as an access unit.
				var conf h264conf.Conf
				if err := conf.Unmarshal(m.Payload); err != nil {
					return fmt.Errorf("unable to parse H264 config: %w", err)
				}
				cb(m.DTS+m.PTSDelta, [][]byte{conf.SPS, conf.PPS})
			case message.VideoTypeAU:
				return emitAU(m.Payload, m.DTS+m.PTSDelta)
			}
			return nil
		case *message.VideoExFramesX:
			return emitAU(m.Payload, m.DTS)
		case *message.VideoExCodedFrames:
			return emitAU(m.Payload, m.DTS+m.PTSDelta)
		}
		return nil
	}
}
// OnDataOpus sets a callback that is called when Opus data is received.
func (r *Reader) OnDataOpus(track *format.Opus, cb OnDataOpusFunc) {
	r.onAudioData[r.audioTrackID(track)] = func(rawMsg message.Message) error {
		m, ok := rawMsg.(*message.AudioExCodedFrames)
		if ok {
			cb(m.DTS, m.Payload)
		}
		return nil
	}
}
// OnDataMPEG4Audio sets a callback that is called when MPEG-4 Audio data is received.
func (r *Reader) OnDataMPEG4Audio(track *format.MPEG4Audio, cb OnDataMPEG4AudioFunc) {
	r.onAudioData[r.audioTrackID(track)] = func(rawMsg message.Message) error {
		switch m := rawMsg.(type) {
		case *message.Audio:
			// legacy messages may also carry config; forward AUs only.
			if m.AACType == message.AudioAACTypeAU {
				cb(m.DTS, m.Payload)
			}
		case *message.AudioExCodedFrames:
			cb(m.DTS, m.Payload)
		}
		return nil
	}
}
// OnDataMPEG1Audio sets a callback that is called when MPEG-1 Audio data is received.
func (r *Reader) OnDataMPEG1Audio(track *format.MPEG1Audio, cb OnDataMPEG1AudioFunc) {
	r.onAudioData[r.audioTrackID(track)] = func(rawMsg message.Message) error {
		// both legacy and extended coded-frame messages carry frames.
		switch m := rawMsg.(type) {
		case *message.Audio:
			cb(m.DTS, m.Payload)
		case *message.AudioExCodedFrames:
			cb(m.DTS, m.Payload)
		}
		return nil
	}
}
// OnDataAC3 sets a callback that is called when AC-3 data is received.
func (r *Reader) OnDataAC3(track *format.AC3, cb OnDataAC3Func) {
	r.onAudioData[r.audioTrackID(track)] = func(rawMsg message.Message) error {
		m, ok := rawMsg.(*message.AudioExCodedFrames)
		if ok {
			cb(m.DTS, m.Payload)
		}
		return nil
	}
}
// OnDataG711 sets a callback that is called when G711 data is received.
func (r *Reader) OnDataG711(track *format.G711, cb OnDataG711Func) {
	r.onAudioData[r.audioTrackID(track)] = func(rawMsg message.Message) error {
		m, ok := rawMsg.(*message.Audio)
		if ok {
			cb(m.DTS, m.Payload)
		}
		return nil
	}
}
// OnDataLPCM sets a callback that is called when LPCM data is received.
func (r *Reader) OnDataLPCM(track *format.LPCM, cb OnDataLPCMFunc) {
	trackID := r.audioTrackID(track)

	if track.BitDepth != 16 {
		// 8-bit samples can be forwarded untouched.
		r.onAudioData[trackID] = func(rawMsg message.Message) error {
			if m, ok := rawMsg.(*message.Audio); ok {
				cb(m.DTS, m.Payload)
			}
			return nil
		}
		return
	}

	// 16-bit samples arrive in little endian; swap them to big endian in place.
	r.onAudioData[trackID] = func(rawMsg message.Message) error {
		m, ok := rawMsg.(*message.Audio)
		if !ok {
			return nil
		}
		payload := m.Payload
		if len(payload)%2 != 0 {
			return fmt.Errorf("invalid payload length: %d", len(payload))
		}
		for i := 0; i < len(payload); i += 2 {
			payload[i], payload[i+1] = payload[i+1], payload[i]
		}
		cb(m.DTS, payload)
		return nil
	}
}
// Read reads data.
// It reads a single message from the connection and dispatches it to the
// callback registered for its track. A data message belonging to a track
// that was not detected during initialization causes an error; message
// types not listed below (e.g. sequence starts) are ignored.
func (r *Reader) Read() error {
	msg, err := r.Conn.Read()
	if err != nil {
		return err
	}
	switch msg := msg.(type) {
	case *message.Video, *message.VideoExCodedFrames, *message.VideoExFramesX:
		// legacy and single-track extended video messages map to track 0.
		if r.videoTracks[0] == nil {
			return fmt.Errorf("received a packet for video track 0, but track is not set up")
		}
		// NOTE(review): assumes the corresponding OnData* method was called;
		// otherwise the callback entry is nil and this call would panic.
		return r.onVideoData[0](msg)
	case *message.Audio, *message.AudioExCodedFrames:
		if r.audioTracks[0] == nil {
			return fmt.Errorf("received a packet for audio track 0, but track is not set up")
		}
		return r.onAudioData[0](msg)
	case *message.VideoExMultitrack:
		// the wrapped message is dispatched using the outer track ID.
		switch wmsg := msg.Wrapped.(type) {
		case *message.VideoExCodedFrames, *message.VideoExFramesX:
			if r.videoTracks[msg.TrackID] == nil {
				return fmt.Errorf("received a packet for video track %d, but track is not set up", msg.TrackID)
			}
			return r.onVideoData[msg.TrackID](wmsg)
		}
	case *message.AudioExMultitrack:
		if wmsg, ok := msg.Wrapped.(*message.AudioExCodedFrames); ok {
			if r.audioTracks[msg.TrackID] == nil {
				return fmt.Errorf("received a packet for audio track %d, but track is not set up", msg.TrackID)
			}
			return r.onAudioData[msg.TrackID](wmsg)
		}
	}
	return nil
}