Files
monibuca/plugin/rtp/pkg/video.go
cto-new[bot] de348725b7 feat(codec): Add unified AV1 raw format and protocol mux/demux support (#354)
* cherry-pick 95191a3: AV1 raw format support and protocol mux/demux integration

* feat(rtp/av1): 完善 AV1 RTP 封装分片及关键帧检测

- Implements RTP packetization for AV1 with OBU fragmentation per RFC9304
- Adds accurate detection of AV1 keyframes using OBU inspection
- Updates AV1 RTP demuxing to reconstruct fragmented OBUs
- Ensures keyframe (IDR) flag is set correctly throughout mux/demux pipeline

---------

Co-authored-by: engine-labs-app[bot] <140088366+engine-labs-app[bot]@users.noreply.github.com>
2025-10-21 09:38:00 +08:00

615 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rtp
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"slices"
"time"
"unsafe"
"github.com/bluenviron/mediacommon/pkg/bits"
"github.com/bluenviron/mediacommon/pkg/codecs/av1"
"github.com/deepch/vdk/codec/h264parser"
"github.com/deepch/vdk/codec/h265parser"
"github.com/langhuihui/gomem"
"github.com/pion/rtp"
"github.com/pion/webrtc/v4"
. "m7s.live/v5/pkg"
"m7s.live/v5/pkg/codec"
"m7s.live/v5/pkg/util"
)
// RTP-side codec contexts and the RTP video frame type used by this package.
type (
// H26xCtx holds the RTP state shared by the H.264 and H.265 contexts:
// the embedded RTPCtx, the next sequence number handed out by
// CheckCodecChange (seq), and the estimator fed with each frame's RTP
// timestamp to obtain a DTS (dtsEst).
H26xCtx struct {
RTPCtx
seq uint16
dtsEst util.DTSEstimator
}
// H264Ctx combines the shared RTP state with the H.264 codec context.
H264Ctx struct {
H26xCtx
*codec.H264Ctx
}
// H265Ctx combines the shared RTP state with the H.265 codec context.
// When DONL is true, Demux skips an extra 16-bit field after the payload
// header of aggregation and fragmentation packets.
H265Ctx struct {
H26xCtx
*codec.H265Ctx
DONL bool
}
// AV1Ctx combines RTP state with the AV1 codec context; seq is the next
// sequence number handed out by CheckCodecChange.
AV1Ctx struct {
RTPCtx
seq uint16
*codec.AV1Ctx
}
// VP9Ctx carries only RTP state; no VP9 mux/demux logic is visible in
// this part of the file.
VP9Ctx struct {
RTPCtx
}
// VideoFrame is a video frame represented as RTP data.
VideoFrame struct {
RTPData
}
)
// Compile-time assertions that the types declared in this file implement
// the interfaces they are used through.
var (
_ IAVFrame = (*VideoFrame)(nil)
_ IVideoCodecCtx = (*H264Ctx)(nil)
_ IVideoCodecCtx = (*H265Ctx)(nil)
_ IVideoCodecCtx = (*AV1Ctx)(nil)
)
// Packetization constants.
const (
// H.265 payload types handled specially by Demux: 48 is treated as an
// aggregation packet and 49 as a fragmentation unit.
H265_NALU_AP = h265parser.NAL_UNIT_UNSPECIFIED_48
H265_NALU_FU = h265parser.NAL_UNIT_UNSPECIFIED_49
// Start/end flags OR-ed into the fragmentation-unit header byte by Mux.
startBit = 1 << 7
endBit = 1 << 6
// MTUSize is the maximum RTP payload size Mux produces before fragmenting.
MTUSize = 1460
// ReceiveMTU is not referenced in this part of the file.
ReceiveMTU = 1500
// AV1 RTP payload descriptor bits (subset used).
// NOTE(review): this file uses Z as "packet starts an OBU" and Y as
// "packet ends an OBU". The AV1 RTP payload format defines Z as "first OBU
// element continues a fragment from the previous packet" and Y as "last OBU
// element continues in the next packet" — confirm interoperability with
// external AV1 RTP peers.
av1ZBit = 1 << 7 // start of OBU
av1YBit = 1 << 6 // end of OBU
)
// Recycle calls Recycle on the frame's embedded RecyclableMemory and then
// resets the frame's packet list.
func (r *VideoFrame) Recycle() {
r.RecyclableMemory.Recycle()
r.Packets.Reset()
}
// CheckCodecChange demuxes the frame's RTP packets, derives PTS/DTS from the
// first packet's RTP timestamp, flags key frames, and renumbers the packets
// with the context's running sequence counter.
//
// It returns ErrSkip when the frame has no packets, and propagates any error
// from Demux or from building new codec data out of in-band parameter sets.
//
// For H.264/H.265, when the frame carries a complete set of parameter sets
// that differ from the current context (or the context has no Record yet), a
// new context is stored in r.ICodecCtx. Otherwise, if the frame is a key
// frame and the context already holds parameter sets, packets carrying those
// parameter sets are inserted at the front of r.Packets. Contexts of any
// other type get no processing beyond Demux.
func (r *VideoFrame) CheckCodecChange() (err error) {
if len(r.Packets) == 0 {
return ErrSkip
}
// Remember the context in effect before this frame; the type switch below
// dispatches on it.
old := r.ICodecCtx
// Demux the packets into raw units.
if err = r.Demux(); err != nil {
return
}
// Derive timestamps and renumber the packets.
// The raw RTP timestamp value is used as-is; no clock-rate conversion
// happens here.
pts := r.Packets[0].Timestamp
switch ctx := old.(type) {
case *H264Ctx:
nalus := r.Raw.(*Nalus)
dts := ctx.dtsEst.Feed(pts)
r.SetDTS(time.Duration(dts))
r.SetPTS(time.Duration(pts))
// Look for SPS, PPS and IDR NAL units.
var sps, pps []byte
var hasSPSPPS bool
for nalu := range nalus.RangePoint {
nalType := codec.ParseH264NALUType(nalu.Buffers[0][0])
switch nalType {
case h264parser.NALU_SPS:
sps = nalu.ToBytes()
// Deferred: the removal runs when the function returns, not while
// the list is being ranged over.
defer nalus.Remove(nalu)
case h264parser.NALU_PPS:
pps = nalu.ToBytes()
defer nalus.Remove(nalu)
case codec.NALU_IDR_Picture:
r.IDR = true
}
}
// If new SPS/PPS were found, replace the codec context.
if hasSPSPPS = sps != nil && pps != nil; hasSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
var newCodecData h264parser.CodecData
if newCodecData, err = h264parser.NewCodecDataFromSPSAndPPS(sps, pps); err != nil {
return
}
newCtx := &H264Ctx{
H26xCtx: ctx.H26xCtx,
H264Ctx: &codec.H264Ctx{
CodecData: newCodecData,
},
}
// Keep the existing RTP parameters.
if oldCtx, ok := old.(*H264Ctx); ok {
newCtx.RTPCtx = oldCtx.RTPCtx
}
r.ICodecCtx = newCtx
} else {
// An IDR frame without usable in-band SPS/PPS: insert packets carrying
// the context's stored SPS/PPS at the front.
if r.IDR && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
spsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.SPS(),
}
ppsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
}
r.Packets = slices.Insert(r.Packets, 0, spsRTP, ppsRTP)
}
}
// Renumber every packet (including inserted ones) from the running counter.
for p := range r.Packets.RangePoint {
p.SequenceNumber = ctx.seq
ctx.seq++
}
case *H265Ctx:
nalus := r.Raw.(*Nalus)
dts := ctx.dtsEst.Feed(pts)
r.SetDTS(time.Duration(dts))
r.SetPTS(time.Duration(pts))
// Look for VPS, SPS, PPS and the slice types treated as key frames.
var vps, sps, pps []byte
var hasVPSSPSPPS bool
for nalu := range nalus.RangePoint {
switch codec.ParseH265NALUType(nalu.Buffers[0][0]) {
case h265parser.NAL_UNIT_VPS:
vps = nalu.ToBytes()
// Deferred: the removal runs when the function returns, not while
// the list is being ranged over.
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_SPS:
sps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_PPS:
pps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_CODED_SLICE_BLA_W_LP,
h265parser.NAL_UNIT_CODED_SLICE_BLA_W_RADL,
h265parser.NAL_UNIT_CODED_SLICE_BLA_N_LP,
h265parser.NAL_UNIT_CODED_SLICE_IDR_W_RADL,
h265parser.NAL_UNIT_CODED_SLICE_IDR_N_LP,
h265parser.NAL_UNIT_CODED_SLICE_CRA:
r.IDR = true
}
}
// If new VPS/SPS/PPS were found, replace the codec context.
if hasVPSSPSPPS = vps != nil && sps != nil && pps != nil; hasVPSSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(vps, ctx.VPS()) || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
var newCodecData h265parser.CodecData
if newCodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil {
return
}
newCtx := &H265Ctx{
H26xCtx: ctx.H26xCtx,
H265Ctx: &codec.H265Ctx{
CodecData: newCodecData,
},
}
// Keep the existing RTP parameters.
if oldCtx, ok := old.(*H265Ctx); ok {
newCtx.RTPCtx = oldCtx.RTPCtx
}
r.ICodecCtx = newCtx
} else {
// A key frame without usable in-band parameter sets: insert packets
// carrying the context's stored VPS/SPS/PPS at the front.
if r.IDR && len(ctx.VPS()) > 0 && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
vpsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.VPS(),
}
spsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.SPS(),
}
ppsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
}
r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP)
}
}
// Renumber every packet (including inserted ones) from the running counter.
for p := range r.Packets.RangePoint {
p.SequenceNumber = ctx.seq
ctx.seq++
}
case *AV1Ctx:
// AV1 uses the RTP timestamp for both PTS and DTS.
r.SetPTS(time.Duration(pts))
r.SetDTS(time.Duration(pts))
// detect keyframe from OBUs
if obus, ok := r.Raw.(*OBUs); ok {
r.IDR = ctx.IsKeyFrame(obus)
}
// Renumber every packet from the running counter.
for p := range r.Packets.RangePoint {
p.SequenceNumber = ctx.seq
ctx.seq++
}
}
return
}
// IsKeyFrame reports whether the given OBUs look like a key frame.
//
// It returns true as soon as it encounters a sequence header OBU, or a
// frame / frame-header OBU whose frame_type is KEY_FRAME (0) or INTRA_ONLY (2).
// It returns false when a frame OBU has show_existing_frame set, or when no
// OBU matches. OBUs that are shorter than two bytes or whose header cannot be
// parsed are skipped.
func (av1Ctx *AV1Ctx) IsKeyFrame(obus *OBUs) bool {
	const (
		obuTypeFrameHeader = 3 // OBU_FRAME_HEADER
		obuTypeFrame       = 6 // OBU_FRAME
		frameTypeKey       = 0 // KEY_FRAME
		frameTypeIntraOnly = 2 // INTRA_ONLY_FRAME
	)
	for o := range obus.RangePoint {
		reader := o.NewReader()
		if reader.Length < 2 { // need at least the header plus one more byte
			continue
		}
		first, err := reader.ReadByte()
		if err != nil {
			continue
		}
		var header av1.OBUHeader
		if err := header.Unmarshal([]byte{first}); err != nil {
			continue
		}
		// Skip the leb128 size field only when the header says it is present;
		// otherwise the next byte already belongs to the OBU payload.
		if header.HasSize {
			// The result is deliberately ignored: if the size field is truncated
			// the reader is left short and the payload read below fails.
			_, _, _ = reader.LEB128Unmarshal()
		}
		switch header.Type {
		case obuTypeFrameHeader, obuTypeFrame:
			// show_existing_frame (1 bit) and frame_type (2 bits) are the top
			// three bits of the first payload byte, so one byte is enough and
			// the payload does not have to be copied out.
			b, err := reader.ReadByte()
			if err != nil {
				continue
			}
			if b&0x80 != 0 { // show_existing_frame
				return false
			}
			switch (b >> 5) & 0x03 { // frame_type
			case frameTypeKey, frameTypeIntraOnly:
				return true
			}
		case av1.OBUTypeSequenceHeader:
			// sequence header often precedes keyframes; treat as keyframe
			return true
		}
	}
	return false
}
// utilReadBits reads nbits bits at bit offset *pos from the bytes remaining in
// r and advances *pos. It returns the value and true on success, or 0 and
// false when the bytes cannot be read or fewer than nbits bits remain.
//
// r itself is not advanced: the bytes are read through a copy of the reader.
// Draining r would make every later call see an empty buffer while *pos keeps
// growing, so only the first call on a reader could ever succeed.
// Each call copies out the remaining bytes, so keep it off hot paths.
func utilReadBits(r *gomem.MemoryReader, pos *int, nbits int) (uint64, bool) {
	peek := *r
	data, err := peek.ReadBytes(peek.Length)
	if err != nil {
		return 0, false
	}
	v, err := av1ReadBits(data, pos, nbits)
	return v, err == nil
}
// av1ReadBits reads nbits bits from buf starting at bit offset *pos by
// delegating to the mediacommon bits.ReadBits helper, and returns its result
// unchanged.
func av1ReadBits(buf []byte, pos *int, nbits int) (uint64, error) {
return bits.ReadBits(buf, pos, nbits)
}
// GetInfo returns the context's SDP fmtp line.
func (h264 *H264Ctx) GetInfo() string {
return h264.SDPFmtpLine
}
// GetInfo returns the context's SDP fmtp line.
func (h265 *H265Ctx) GetInfo() string {
return h265.SDPFmtpLine
}
// GetInfo returns the context's SDP fmtp line.
//
// The receiver is named av1Ctx rather than av1 so that it does not shadow the
// imported av1 package inside the method.
func (av1Ctx *AV1Ctx) GetInfo() string {
	return av1Ctx.SDPFmtpLine
}
// Mux packetizes baseFrame into RTP packets appended to r.
//
// When r has no codec context yet, or its base context differs from
// baseFrame's, a new RTP context is built for H.264 (payload type 96) or
// H.265 (payload type 98) and stored in r.ICodecCtx.
// NOTE(review): no RTP context is created here for an AV1 base context, so
// the AV1 branch only runs when r.ICodecCtx was already set to an *AV1Ctx
// with the same base — confirm this is intended.
//
// H.264/H.265: on key frames the stored parameter sets are sent first; each
// NAL unit larger than MTUSize is split into fragmentation units, smaller
// ones are sent as single packets. AV1: each OBU is sent with a one-byte
// descriptor and split across packets when it does not fit in MTUSize.
// The last packet produced gets the RTP marker bit.
//
// Mux always returns nil. It panics if baseFrame.Raw is not of the type the
// selected codec branch expects.
func (r *VideoFrame) Mux(baseFrame *Sample) error {
	// Resolve the codec context, rebuilding it if the base context changed.
	codecCtx := r.ICodecCtx
	if baseCtx := baseFrame.GetBase(); codecCtx == nil || codecCtx.GetBase() != baseCtx {
		switch base := baseCtx.(type) {
		case *codec.H264Ctx:
			var ctx H264Ctx
			ctx.H264Ctx = base
			ctx.PayloadType = 96
			ctx.MimeType = webrtc.MimeTypeH264
			ctx.ClockRate = 90000
			spsInfo := ctx.SPSInfo
			ctx.SDPFmtpLine = fmt.Sprintf("sprop-parameter-sets=%s,%s;profile-level-id=%02x%02x%02x;level-asymmetry-allowed=1;packetization-mode=1", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), spsInfo.ProfileIdc, spsInfo.ConstraintSetFlag, spsInfo.LevelIdc)
			// The context's address is used as a cheap per-context SSRC value.
			ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
			codecCtx = &ctx
		case *codec.H265Ctx:
			var ctx H265Ctx
			ctx.H265Ctx = base
			ctx.PayloadType = 98
			ctx.MimeType = webrtc.MimeTypeH265
			ctx.SDPFmtpLine = fmt.Sprintf("profile-id=1;sprop-sps=%s;sprop-pps=%s;sprop-vps=%s", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), base64.StdEncoding.EncodeToString(ctx.VPS()))
			ctx.ClockRate = 90000
			ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
			codecCtx = &ctx
		}
		r.ICodecCtx = codecCtx
	}
	// The frame's PTS is used as the RTP timestamp of every packet.
	pts := uint32(baseFrame.GetPTS())
	switch c := codecCtx.(type) {
	case *H264Ctx:
		ctx := &c.RTPCtx
		var lastPacket *rtp.Packet
		if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 {
			r.Append(ctx, pts, c.SPS())
			r.Append(ctx, pts, c.PPS())
		}
		for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
			if reader := nalu.NewReader(); reader.Length > MTUSize {
				payloadLen := MTUSize
				if reader.Length+1 < payloadLen {
					payloadLen = reader.Length + 1
				}
				// FU-A: the first fragment is read one byte in, so the original
				// NAL header lands in mem[1] and can be split into the FU
				// indicator (mem[0]) and FU header (mem[1]).
				mem := r.NextN(payloadLen)
				reader.Read(mem[1:])
				fuaHead, naluType := codec.NALU_FUA.Or(mem[1]&0x60), mem[1]&0x1f
				mem[0], mem[1] = fuaHead, naluType|startBit
				lastPacket = r.Append(ctx, pts, mem)
				// Following fragments carry two header bytes plus payload.
				for payloadLen = MTUSize; reader.Length > 0; lastPacket = r.Append(ctx, pts, mem) {
					if reader.Length+2 < payloadLen {
						payloadLen = reader.Length + 2
					}
					mem = r.NextN(payloadLen)
					reader.Read(mem[2:])
					mem[0], mem[1] = fuaHead, naluType
				}
				lastPacket.Payload[1] |= endBit
			} else {
				mem := r.NextN(reader.Length)
				reader.Read(mem)
				lastPacket = r.Append(ctx, pts, mem)
			}
		}
		// Nothing may have been appended (no NAL units on a non-key frame);
		// only mark the last packet when there is one.
		if lastPacket != nil {
			lastPacket.Header.Marker = true
		}
	case *H265Ctx:
		ctx := &c.RTPCtx
		var lastPacket *rtp.Packet
		if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 && len(c.RecordInfo.VPS) > 0 {
			r.Append(ctx, pts, c.VPS())
			r.Append(ctx, pts, c.SPS())
			r.Append(ctx, pts, c.PPS())
		}
		for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
			if reader := nalu.NewReader(); reader.Length > MTUSize {
				// FU: consume the two-byte NAL header, keep its type for the FU
				// header, and rewrite the first byte with the FU payload type.
				var b0, b1 byte
				_ = reader.ReadByteTo(&b0, &b1)
				naluType := byte(codec.ParseH265NALUType(b0))
				b0 = (byte(H265_NALU_FU) << 1) | (b0 & 0b10000001)
				payloadLen := MTUSize
				if reader.Length+3 < payloadLen {
					payloadLen = reader.Length + 3
				}
				mem := r.NextN(payloadLen)
				reader.Read(mem[3:])
				mem[0], mem[1], mem[2] = b0, b1, naluType|startBit
				lastPacket = r.Append(ctx, pts, mem)
				for payloadLen = MTUSize; reader.Length > 0; lastPacket = r.Append(ctx, pts, mem) {
					if reader.Length+3 < payloadLen {
						payloadLen = reader.Length + 3
					}
					mem = r.NextN(payloadLen)
					reader.Read(mem[3:])
					mem[0], mem[1], mem[2] = b0, b1, naluType
				}
				lastPacket.Payload[2] |= endBit
			} else {
				mem := r.NextN(reader.Length)
				reader.Read(mem)
				lastPacket = r.Append(ctx, pts, mem)
			}
		}
		// Same guard as above: there may be no packet to mark.
		if lastPacket != nil {
			lastPacket.Header.Marker = true
		}
	case *AV1Ctx:
		ctx := &c.RTPCtx
		var lastPacket *rtp.Packet
		for obu := range baseFrame.Raw.(*OBUs).RangePoint {
			reader := obu.NewReader()
			payloadCap := MTUSize - 1
			if reader.Length+1 <= MTUSize {
				// Whole OBU fits: one packet flagged as both start and end.
				mem := r.NextN(reader.Length + 1)
				mem[0] = av1ZBit | av1YBit
				reader.Read(mem[1:])
				lastPacket = r.Append(ctx, pts, mem)
				continue
			}
			// fragmented OBU
			first := true
			for reader.Length > 0 {
				chunk := payloadCap
				if reader.Length < chunk {
					chunk = reader.Length
				}
				mem := r.NextN(chunk + 1)
				head := byte(0)
				if first {
					head |= av1ZBit
					first = false
				}
				reader.Read(mem[1:])
				// The end flag can only be decided after the read has shown that
				// nothing is left.
				if reader.Length == 0 {
					head |= av1YBit
				}
				mem[0] = head
				lastPacket = r.Append(ctx, pts, mem)
			}
		}
		if lastPacket != nil {
			lastPacket.Header.Marker = true
		}
	}
	return nil
}
// Demux reassembles the frame's RTP packets into raw codec units.
//
// H.264: single NAL unit packets are taken as-is, STAP packets are split into
// their NAL units, and FU packets are stitched back into one NAL unit.
// H.265: the same for single NAL units, aggregation packets (type 48) and
// fragmentation units (type 49); when the context's DONL flag is set an extra
// 16-bit field is skipped after the payload header.
// AV1: the first payload byte is a descriptor; a packet with the Z bit starts
// a new OBU, a packet with the Y bit ends it.
//
// It returns ErrSkip when the frame has no packets, ErrUnsupportCodec when
// the codec context is none of the three handled types, and an error when a
// packet is malformed or uses an unsupported payload type. Packet payloads
// come from the network, so every length field is validated before it is
// used; a truncated packet yields an error rather than an out-of-range panic.
func (r *VideoFrame) Demux() (err error) {
	if len(r.Packets) == 0 {
		return ErrSkip
	}
	switch c := r.ICodecCtx.(type) {
	case *H264Ctx:
		nalus := r.GetNalus()
		var nalu *gomem.Memory
		var naluType codec.H264NALUType
		for packet := range r.Packets.RangePoint {
			if len(packet.Payload) < 2 {
				continue
			}
			if packet.Padding {
				packet.Padding = false
			}
			b0 := packet.Payload[0]
			if t := codec.ParseH264NALUType(b0); t < 24 {
				nalus.GetNextPointer().PushOne(packet.Payload)
			} else {
				offset := t.Offset()
				switch t {
				case codec.NALU_STAPA, codec.NALU_STAPB:
					if len(packet.Payload) <= offset {
						return fmt.Errorf("invalid nalu size %d", len(packet.Payload))
					}
					for buffer := util.Buffer(packet.Payload[offset:]); buffer.CanRead(); {
						// A lone trailing byte cannot hold the 16-bit size field.
						if buffer.Len() < 2 {
							return io.ErrShortBuffer
						}
						if nextSize := int(buffer.ReadUint16()); buffer.Len() >= nextSize {
							nalus.GetNextPointer().PushOne(buffer.ReadN(nextSize))
						} else {
							return fmt.Errorf("invalid nalu size %d", nextSize)
						}
					}
				case codec.NALU_FUA, codec.NALU_FUB:
					b1 := packet.Payload[1]
					if util.Bit1(b1, 0) {
						// Start fragment: rebuild the original NAL header byte.
						nalu = nalus.GetNextPointer()
						naluType.Parse(b1)
						nalu.PushOne([]byte{naluType.Or(b0 & 0x60)})
					}
					if nalu != nil && nalu.Size > 0 {
						// The fragment must be at least as long as its headers.
						if len(packet.Payload) < offset {
							return io.ErrShortBuffer
						}
						nalu.PushOne(packet.Payload[offset:])
						if util.Bit1(b1, 1) {
							nalu = nil
						}
					} else {
						// Fragment without a preceding start fragment: drop it.
						continue
					}
				default:
					return fmt.Errorf("unsupported nalu type %d", t)
				}
			}
		}
		return nil
	case *H265Ctx:
		nalus := r.GetNalus()
		var nalu *gomem.Memory
		for _, packet := range r.Packets {
			if len(packet.Payload) == 0 {
				continue
			}
			b0 := packet.Payload[0]
			if t := codec.ParseH265NALUType(b0); t < H265_NALU_AP {
				nalus.GetNextPointer().PushOne(packet.Payload)
			} else {
				var buffer = util.Buffer(packet.Payload)
				switch t {
				case H265_NALU_AP:
					// Skip the two-byte payload header.
					if buffer.Len() < 2 {
						return io.ErrShortBuffer
					}
					buffer.ReadUint16()
					if c.DONL {
						if buffer.Len() < 2 {
							return io.ErrShortBuffer
						}
						buffer.ReadUint16()
					}
					for buffer.CanRead() {
						if buffer.Len() < 2 {
							return io.ErrShortBuffer
						}
						nextSize := int(buffer.ReadUint16())
						if buffer.Len() < nextSize {
							return fmt.Errorf("invalid nalu size %d", nextSize)
						}
						nalus.GetNextPointer().PushOne(buffer.ReadN(nextSize))
					}
					// NOTE(review): the loop above only exits once the buffer is
					// empty, so this read happens on an empty buffer — confirm
					// the intended DONL handling.
					if c.DONL {
						buffer.ReadByte()
					}
				case H265_NALU_FU:
					if buffer.Len() < 3 {
						return io.ErrShortBuffer
					}
					first3 := buffer.ReadN(3)
					fuHeader := first3[2]
					if c.DONL {
						if buffer.Len() < 2 {
							return io.ErrShortBuffer
						}
						buffer.ReadUint16()
					}
					if naluType := fuHeader & 0b00111111; util.Bit1(fuHeader, 0) {
						// Start fragment: rebuild the two-byte NAL header.
						nalu = nalus.GetNextPointer()
						nalu.PushOne([]byte{first3[0]&0b10000001 | (naluType << 1), first3[1]})
					}
					if nalu != nil && nalu.Size > 0 {
						nalu.PushOne(buffer)
						if util.Bit1(fuHeader, 1) {
							nalu = nil
						}
					} else {
						// Fragment without a preceding start fragment: drop it.
						continue
					}
				default:
					return fmt.Errorf("unsupported nalu type %d", t)
				}
			}
		}
		return nil
	case *AV1Ctx:
		obus := r.GetOBUs()
		obus.Reset()
		var cur *gomem.Memory
		for _, packet := range r.Packets {
			// Need the descriptor byte plus at least one payload byte.
			if len(packet.Payload) <= 1 {
				continue
			}
			desc := packet.Payload[0]
			payload := packet.Payload[1:]
			if desc&av1ZBit != 0 {
				// start of OBU
				cur = obus.GetNextPointer()
			}
			// Payload without a preceding start packet is dropped.
			if cur != nil {
				cur.PushOne(payload)
				if desc&av1YBit != 0 {
					cur = nil
				}
			}
		}
		return nil
	}
	return ErrUnsupportCodec
}