diff --git a/pkg/avframe.go b/pkg/avframe.go index e43d2b8..d814c69 100644 --- a/pkg/avframe.go +++ b/pkg/avframe.go @@ -51,7 +51,7 @@ type ( AudioData = gomem.Memory - OBUs AudioData + OBUs = util.ReuseArray[gomem.Memory] AVFrame struct { DataFrame @@ -148,6 +148,13 @@ func (b *BaseSample) GetNalus() *Nalus { return b.Raw.(*Nalus) } +func (b *BaseSample) GetOBUs() *OBUs { + if b.Raw == nil { + b.Raw = &OBUs{} + } + return b.Raw.(*OBUs) +} + func (b *BaseSample) GetAudioData() *AudioData { if b.Raw == nil { b.Raw = &AudioData{} @@ -203,21 +210,21 @@ func (df *DataFrame) Ready() { df.Unlock() } -func (obus *OBUs) ParseAVCC(reader *gomem.MemoryReader) error { +func (b *BaseSample) ParseAV1OBUs(reader *gomem.MemoryReader) error { var obuHeader av1.OBUHeader startLen := reader.Length for reader.Length > 0 { offset := reader.Size - reader.Length - b, err := reader.ReadByte() + b0, err := reader.ReadByte() if err != nil { return err } - err = obuHeader.Unmarshal([]byte{b}) + err = obuHeader.Unmarshal([]byte{b0}) if err != nil { return err } // if log.Trace { - // vt.Trace("obu", zap.Any("type", obuHeader.Type), zap.Bool("iframe", vt.Value.IFrame)) + // vt.Trace("obu", zap.Any("type", obuHeader.Type), zap.Bool("iframe", vt.Value.IFrame)) // } obuSize, _, _ := reader.LEB128Unmarshal() end := reader.Size - reader.Length @@ -227,15 +234,7 @@ func (obus *OBUs) ParseAVCC(reader *gomem.MemoryReader) error { if err != nil { return err } - (*AudioData)(obus).PushOne(obu) + b.GetNalus().GetNextPointer().PushOne(obu) } return nil } - -func (obus *OBUs) Reset() { - ((*gomem.Memory)(obus)).Reset() -} - -func (obus *OBUs) Count() int { - return (*gomem.Memory)(obus).Count() -} diff --git a/pkg/format/raw.go b/pkg/format/raw.go index 536c28c..c60b8e5 100644 --- a/pkg/format/raw.go +++ b/pkg/format/raw.go @@ -129,3 +129,44 @@ func (r *H26xFrame) GetSize() (ret int) { func (h *H26xFrame) String() string { return fmt.Sprintf("H26xFrame{FourCC: %s, Timestamp: %s, CTS: %s}", h.FourCC, h.Timestamp, h.CTS) } + +var _ pkg.IAVFrame = (*AV1Frame)(nil) + +type AV1Frame struct { + pkg.Sample +} + +func (a *AV1Frame) CheckCodecChange() (err error) { + if a.ICodecCtx == nil { + return pkg.ErrUnsupportCodec + } + return nil +} + +func (a *AV1Frame) GetSize() (ret int) { + if obus, ok := a.Raw.(*pkg.OBUs); ok { + for obu := range obus.RangePoint { + ret += obu.Size + } + } + return +} + +func (a *AV1Frame) Demux() error { + a.Raw = &a.Memory + return nil +} + +func (a *AV1Frame) Mux(from *pkg.Sample) (err error) { + a.InitRecycleIndexes(0) + obus := from.Raw.(*pkg.OBUs) + for obu := range obus.RangePoint { + a.Push(obu.Buffers...) + } + a.ICodecCtx = from.GetBase() + return +} + +func (a *AV1Frame) String() string { + return fmt.Sprintf("AV1Frame{FourCC: %s, Timestamp: %s, CTS: %s}", a.FourCC, a.Timestamp, a.CTS) +} diff --git a/plugin/mp4/pkg/video.go b/plugin/mp4/pkg/video.go index a3e2c7e..f4fbbf0 100644 --- a/plugin/mp4/pkg/video.go +++ b/plugin/mp4/pkg/video.go @@ -32,6 +32,10 @@ func (v *VideoFrame) Demux() (err error) { if err := v.ParseAVCC(&reader, int(ctx.RecordInfo.LengthSizeMinusOne)+1); err != nil { return fmt.Errorf("failed to parse H.265 AVCC: %w", err) } + case *codec.AV1Ctx: + if err := v.ParseAV1OBUs(&reader); err != nil { + return fmt.Errorf("failed to parse AV1 OBUs: %w", err) + } default: // 对于其他格式,尝试默认的 AVCC 解析(4字节长度前缀) if err := v.ParseAVCC(&reader, 4); err != nil { @@ -48,17 +52,21 @@ func (v *VideoFrame) Mux(sample *pkg.Sample) (err error) { v.ICodecCtx = sample.GetBase() switch rawData := sample.Raw.(type) { case *pkg.Nalus: - // 根据编解码器类型确定 NALU 长度字段的大小 - var naluSizeLen int = 4 // 默认使用 4 字节 + var naluSizeLen int = 4 switch ctx := sample.ICodecCtx.(type) { + case *codec.AV1Ctx: + for obu := range rawData.RangePoint { + util.PutBE(v.NextN(4), obu.Size) + v.Push(obu.Buffers...) + } + return case *codec.H264Ctx: naluSizeLen = int(ctx.RecordInfo.LengthSizeMinusOne) + 1 case *codec.H265Ctx: naluSizeLen = int(ctx.RecordInfo.LengthSizeMinusOne) + 1 } - // 为每个 NALU 添加长度前缀 for nalu := range rawData.RangePoint { - util.PutBE(v.NextN(naluSizeLen), nalu.Size) // 写入 NALU 长度 + util.PutBE(v.NextN(naluSizeLen), nalu.Size) v.Push(nalu.Buffers...) } } diff --git a/plugin/rtmp/pkg/video.go b/plugin/rtmp/pkg/video.go index b1eee91..f65ebc6 100644 --- a/plugin/rtmp/pkg/video.go +++ b/plugin/rtmp/pkg/video.go @@ -187,12 +187,12 @@ func (avcc *VideoFrame) CheckCodecChange() (err error) { } else { // switch ctx := old.(type) { // case *codec.H264Ctx: - // avcc.filterH264(int(ctx.RecordInfo.LengthSizeMinusOne) + 1) + // avcc.filterH264(int(ctx.RecordInfo.LengthSizeMinusOne) + 1) // case *H265Ctx: - // avcc.filterH265(int(ctx.RecordInfo.LengthSizeMinusOne) + 1) + // avcc.filterH265(int(ctx.RecordInfo.LengthSizeMinusOne) + 1) // } // if avcc.Size <= 5 { - // return old, ErrSkip + // return old, ErrSkip // } } } @@ -208,12 +208,7 @@ func (avcc *VideoFrame) parseH265(ctx *H265Ctx, reader *gomem.MemoryReader) (err } func (avcc *VideoFrame) parseAV1(reader *gomem.MemoryReader) error { - var obus OBUs - if err := obus.ParseAVCC(reader); err != nil { - return err - } - avcc.Raw = &obus - return nil + return avcc.ParseAV1OBUs(reader) } func (avcc *VideoFrame) Demux() error { @@ -298,7 +293,7 @@ func (avcc *VideoFrame) muxOld26x(codecID VideoCodecID, fromBase *Sample) { naluLen := uint32(nalu.Size) binary.BigEndian.PutUint32(naluLenM, naluLen) // if nalu.Size != len(util.ConcatBuffers(nalu.Buffers)) { - // panic("nalu size mismatch") + // panic("nalu size mismatch") // } avcc.Push(nalu.Buffers...) } @@ -306,8 +301,29 @@ func (avcc *VideoFrame) muxOld26x(codecID VideoCodecID, fromBase *Sample) { func (avcc *VideoFrame) Mux(fromBase *Sample) (err error) { switch c := fromBase.GetBase().(type) { - case *AV1Ctx: - panic(c) + case *codec.AV1Ctx: + if avcc.ICodecCtx == nil || avcc.GetBase() != c { + ctx := &AV1Ctx{AV1Ctx: c} + configBytes := make([]byte, 5+len(c.ConfigOBUs)) + configBytes[0] = 0b1001_0000 | byte(PacketTypeSequenceStart) + copy(configBytes[1:], codec.FourCC_AV1[:]) + copy(configBytes[5:], c.ConfigOBUs) + ctx.SequenceFrame.PushOne(configBytes) + ctx.SequenceFrame.BaseSample = &BaseSample{} + avcc.ICodecCtx = ctx + } + obus := fromBase.Raw.(*OBUs) + avcc.InitRecycleIndexes(obus.Count()) + head := avcc.NextN(5) + if fromBase.IDR { + head[0] = 0b1001_0000 | byte(PacketTypeCodedFrames) + } else { + head[0] = 0b1010_0000 | byte(PacketTypeCodedFrames) + } + copy(head[1:], codec.FourCC_AV1[:]) + for obu := range obus.RangePoint { + avcc.Push(obu.Buffers...) + } case *codec.H264Ctx: if avcc.ICodecCtx == nil || avcc.GetBase() != c { ctx := &H264Ctx{H264Ctx: c} diff --git a/plugin/rtp/pkg/video.go b/plugin/rtp/pkg/video.go index 8760af6..e1d3e46 100644 --- a/plugin/rtp/pkg/video.go +++ b/plugin/rtp/pkg/video.go @@ -9,6 +9,8 @@ import ( "time" "unsafe" + "github.com/bluenviron/mediacommon/pkg/bits" + "github.com/bluenviron/mediacommon/pkg/codecs/av1" "github.com/deepch/vdk/codec/h264parser" "github.com/deepch/vdk/codec/h265parser" "github.com/langhuihui/gomem" @@ -37,6 +39,7 @@ type ( } AV1Ctx struct { RTPCtx + seq uint16 *codec.AV1Ctx } VP9Ctx struct { @@ -61,6 +64,10 @@ const ( endBit = 1 << 6 MTUSize = 1460 ReceiveMTU = 1500 + + // AV1 RTP payload descriptor bits (subset used) + av1ZBit = 1 << 7 // start of OBU + av1YBit = 1 << 6 // end of OBU ) func (r *VideoFrame) Recycle() { @@ -79,9 +86,9 @@ func (r *VideoFrame) CheckCodecChange() (err error) { } // 处理时间戳和序列号 pts := r.Packets[0].Timestamp - nalus := r.Raw.(*Nalus) switch ctx := old.(type) { case *H264Ctx: + nalus := r.Raw.(*Nalus) dts := ctx.dtsEst.Feed(pts) r.SetDTS(time.Duration(dts)) r.SetPTS(time.Duration(pts)) @@ -153,10 +160,10 @@ func (r *VideoFrame) CheckCodecChange() (err error) { ctx.seq++ } case *H265Ctx: + nalus := r.Raw.(*Nalus) dts := ctx.dtsEst.Feed(pts) r.SetDTS(time.Duration(dts)) r.SetPTS(time.Duration(pts)) - // 检查 VPS、SPS、PPS 和 IDR 帧 var vps, sps, pps []byte var hasVPSSPSPPS bool for nalu := range nalus.RangePoint { @@ -179,8 +186,6 @@ func (r *VideoFrame) CheckCodecChange() (err error) { r.IDR = true } } - - // 如果发现新的 VPS/SPS/PPS,更新编解码器上下文 if hasVPSSPSPPS = vps != nil && sps != nil && pps != nil; hasVPSSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(vps, ctx.VPS()) || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) { var newCodecData h265parser.CodecData if newCodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil { @@ -192,13 +197,11 @@ func (r *VideoFrame) CheckCodecChange() (err error) { CodecData: newCodecData, }, } - // 保持原有的 RTP 参数 if oldCtx, ok := old.(*H265Ctx); ok { newCtx.RTPCtx = oldCtx.RTPCtx } r.ICodecCtx = newCtx } else { - // 如果是 IDR 帧但没有 VPS/SPS/PPS,需要插入 if r.IDR && len(ctx.VPS()) > 0 && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 { vpsRTP := rtp.Packet{ Header: rtp.Header{ @@ -233,7 +236,17 @@ func (r *VideoFrame) CheckCodecChange() (err error) { r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP) } } - + for p := range r.Packets.RangePoint { + p.SequenceNumber = ctx.seq + ctx.seq++ + } + case *AV1Ctx: + r.SetPTS(time.Duration(pts)) + r.SetDTS(time.Duration(pts)) + // detect keyframe from OBUs + if obus, ok := r.Raw.(*OBUs); ok { + r.IDR = ctx.IsKeyFrame(obus) + } // 更新序列号 for p := range r.Packets.RangePoint { p.SequenceNumber = ctx.seq @@ -243,6 +256,72 @@ func (r *VideoFrame) CheckCodecChange() (err error) { return } +// AV1 helper to detect keyframe (KEY_FRAME or INTRA_ONLY) +func (av1Ctx *AV1Ctx) IsKeyFrame(obus *OBUs) bool { + for o := range obus.RangePoint { + reader := o.NewReader() + if reader.Length < 2 { // need at least header + leb + continue + } + var first byte + if b, err := reader.ReadByte(); err == nil { + first = b + } else { + continue + } + var header av1.OBUHeader + if err := header.Unmarshal([]byte{first}); err != nil { + continue + } + // read leb128 size to move to payload start + _, _, _ = reader.LEB128Unmarshal() + // only inspect frame header or frame obu + // OBU_FRAME_HEADER = 3, OBU_FRAME = 6 + switch header.Type { + case 3, 6: + // try parse a minimal frame header: show_existing_frame (1), frame_type (2) + payload := reader + var pos int + // read show_existing_frame + showExisting, ok := utilReadBits(&payload, &pos, 1) + if !ok { + continue + } + if showExisting == 1 { + return false + } + // attempt to read frame_type (2 bits) + ft, ok := utilReadBits(&payload, &pos, 2) + if !ok { + continue + } + if ft == 0 || ft == 2 { // KEY_FRAME(0) or INTRA_ONLY(2) + return true + } + case av1.OBUTypeSequenceHeader: + // sequence header often precedes keyframes; treat as keyframe + return true + } + } + return false +} + +// utilReadBits reads nbits from MemoryReader, returns value and ok +func utilReadBits(r *gomem.MemoryReader, pos *int, nbits int) (uint64, bool) { + // use mediacommon bits reader on a copy of remaining bytes + data, err := r.ReadBytes(r.Length) + if err != nil { + return 0, false + } + v, err2 := av1ReadBits(data, pos, nbits) + return v, err2 == nil +} + +// av1ReadBits uses mediacommon bits helper +func av1ReadBits(buf []byte, pos *int, nbits int) (uint64, error) { + return bits.ReadBits(buf, pos, nbits) +} + func (h264 *H264Ctx) GetInfo() string { return h264.SDPFmtpLine } @@ -362,6 +441,43 @@ func (r *VideoFrame) Mux(baseFrame *Sample) error { } } lastPacket.Header.Marker = true + case *AV1Ctx: + ctx := &c.RTPCtx + var lastPacket *rtp.Packet + for obu := range baseFrame.Raw.(*OBUs).RangePoint { + reader := obu.NewReader() + payloadCap := MTUSize - 1 + if reader.Length+1 <= MTUSize { + mem := r.NextN(reader.Length + 1) + mem[0] = av1ZBit | av1YBit + reader.Read(mem[1:]) + lastPacket = r.Append(ctx, pts, mem) + continue + } + // fragmented OBU + first := true + for reader.Length > 0 { + chunk := payloadCap + if reader.Length < chunk { + chunk = reader.Length + } + mem := r.NextN(chunk + 1) + head := byte(0) + if first { + head |= av1ZBit + first = false + } + reader.Read(mem[1:]) + if reader.Length == 0 { + head |= av1YBit + } + mem[0] = head + lastPacket = r.Append(ctx, pts, mem) + } + } + if lastPacket != nil { + lastPacket.Header.Marker = true + } } return nil } @@ -471,6 +587,28 @@ func (r *VideoFrame) Demux() (err error) { } } return nil + case *AV1Ctx: + obus := r.GetOBUs() + obus.Reset() + var cur *gomem.Memory + for _, packet := range r.Packets { + if len(packet.Payload) <= 1 { + continue + } + desc := packet.Payload[0] + payload := packet.Payload[1:] + if desc&av1ZBit != 0 { + // start of OBU + cur = obus.GetNextPointer() + } + if cur != nil { + cur.PushOne(payload) + if desc&av1YBit != 0 { + cur = nil + } + } + } + return nil } return ErrUnsupportCodec }