improve AV1 decoder efficiency (#744)

This commit is contained in:
Alessandro Ros
2025-03-27 13:41:18 +01:00
committed by GitHub
parent 3414414c02
commit c9b91e629c
6 changed files with 57 additions and 43 deletions

View File

@@ -6,7 +6,6 @@ import (
"github.com/bluenviron/mediacommon/v2/pkg/codecs/av1" "github.com/bluenviron/mediacommon/v2/pkg/codecs/av1"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/pion/rtp/codecs"
) )
// ErrMorePacketsNeeded is returned when more packets are needed. // ErrMorePacketsNeeded is returned when more packets are needed.
@@ -61,20 +60,35 @@ func (d *Decoder) resetFragments() {
} }
func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
var av1header codecs.AV1Packet if len(pkt.Payload) < 2 {
_, err := av1header.Unmarshal(pkt.Payload) return nil, fmt.Errorf("invalid payload size")
if err != nil {
d.resetFragments()
return nil, fmt.Errorf("invalid header: %w", err)
} }
for _, obu := range av1header.OBUElements { z := (pkt.Payload[0] & 0b10000000) != 0
if len(obu) == 0 { y := (pkt.Payload[0] & 0b01000000) != 0
return nil, fmt.Errorf("invalid OBU size") payload := pkt.Payload[1:]
var obus [][]byte
for len(payload) > 0 {
var size av1.LEB128
n, err := size.Unmarshal(payload)
if err != nil {
d.resetFragments()
return nil, err
} }
payload = payload[n:]
if size == 0 || len(payload) < int(size) {
return nil, fmt.Errorf("invalid fragmented OBU (invalid size)")
}
var obu []byte
obu, payload = payload[:size], payload[size:]
obus = append(obus, obu)
} }
if av1header.Z { // first OBU is continuation of previous one
if z {
if d.fragmentsSize == 0 { if d.fragmentsSize == 0 {
if !d.firstPacketReceived { if !d.firstPacketReceived {
return nil, ErrNonStartingPacketAndNoPrevious return nil, ErrNonStartingPacketAndNoPrevious
@@ -83,12 +97,14 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
return nil, fmt.Errorf("received a subsequent fragment without previous fragments") return nil, fmt.Errorf("received a subsequent fragment without previous fragments")
} }
d.firstPacketReceived = true
if pkt.SequenceNumber != d.fragmentNextSeqNum { if pkt.SequenceNumber != d.fragmentNextSeqNum {
d.resetFragments() d.resetFragments()
return nil, fmt.Errorf("discarding frame since a RTP packet is missing") return nil, fmt.Errorf("discarding frame since a RTP packet is missing")
} }
d.fragmentsSize += len(av1header.OBUElements[0]) d.fragmentsSize += len(obus[0])
if d.fragmentsSize > av1.MaxTemporalUnitSize { if d.fragmentsSize > av1.MaxTemporalUnitSize {
errSize := d.fragmentsSize errSize := d.fragmentsSize
@@ -97,38 +113,31 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) {
errSize, av1.MaxTemporalUnitSize) errSize, av1.MaxTemporalUnitSize)
} }
d.fragments = append(d.fragments, av1header.OBUElements[0]) d.fragments = append(d.fragments, obus[0])
av1header.OBUElements = av1header.OBUElements[1:]
d.fragmentNextSeqNum++ d.fragmentNextSeqNum++
}
d.firstPacketReceived = true if len(obus) == 1 && y {
return nil, ErrMorePacketsNeeded
var obus [][]byte
if len(av1header.OBUElements) > 0 {
if d.fragmentsSize != 0 {
obus = append(obus, joinFragments(d.fragments, d.fragmentsSize))
d.resetFragments()
} }
if av1header.Y { obus[0] = joinFragments(d.fragments, d.fragmentsSize)
elementCount := len(av1header.OBUElements)
d.fragmentsSize = len(av1header.OBUElements[elementCount-1])
d.fragments = append(d.fragments, av1header.OBUElements[elementCount-1])
av1header.OBUElements = av1header.OBUElements[:elementCount-1]
d.fragmentNextSeqNum = pkt.SequenceNumber + 1
}
obus = append(obus, av1header.OBUElements...)
} else if !av1header.Y {
obus = append(obus, joinFragments(d.fragments, d.fragmentsSize))
d.resetFragments() d.resetFragments()
} else {
d.firstPacketReceived = true
} }
if len(obus) == 0 { // last OBU will continue in next packet
return nil, ErrMorePacketsNeeded if y {
var obu []byte
obu, obus = obus[len(obus)-1], obus[:len(obus)-1]
d.fragmentsSize = len(obu)
d.fragments = append(d.fragments, obu)
d.fragmentNextSeqNum = pkt.SequenceNumber + 1
if len(obus) == 0 {
return nil, ErrMorePacketsNeeded
}
} }
return obus, nil return obus, nil

View File

@@ -1,5 +0,0 @@
go test fuzz v1
[]byte("\x180")
bool(false)
[]byte("\xd00")
bool(false)

View File

@@ -2,4 +2,4 @@ go test fuzz v1
[]byte("0\x00") []byte("0\x00")
bool(true) bool(true)
[]byte("0") []byte("0")
bool(true) bool(false)

View File

@@ -1,5 +1,5 @@
go test fuzz v1 go test fuzz v1
[]byte("\xd00") []byte("\xb0\x010")
bool(false) bool(false)
[]byte("0") []byte("0")
bool(false) bool(false)

View File

@@ -0,0 +1,5 @@
go test fuzz v1
[]byte("0\xbb")
bool(true)
[]byte("0")
bool(false)

View File

@@ -0,0 +1,5 @@
go test fuzz v1
[]byte("00")
bool(true)
[]byte("0")
bool(true)