mirror of
				https://github.com/aler9/gortsplib
				synced 2025-10-31 02:26:57 +08:00 
			
		
		
		
	improve AV1 decoder efficiency (#744)
This commit is contained in:
		| @@ -6,7 +6,6 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/bluenviron/mediacommon/v2/pkg/codecs/av1" | 	"github.com/bluenviron/mediacommon/v2/pkg/codecs/av1" | ||||||
| 	"github.com/pion/rtp" | 	"github.com/pion/rtp" | ||||||
| 	"github.com/pion/rtp/codecs" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ErrMorePacketsNeeded is returned when more packets are needed. | // ErrMorePacketsNeeded is returned when more packets are needed. | ||||||
| @@ -61,20 +60,35 @@ func (d *Decoder) resetFragments() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { | func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { | ||||||
| 	var av1header codecs.AV1Packet | 	if len(pkt.Payload) < 2 { | ||||||
| 	_, err := av1header.Unmarshal(pkt.Payload) | 		return nil, fmt.Errorf("invalid payload size") | ||||||
| 	if err != nil { |  | ||||||
| 		d.resetFragments() |  | ||||||
| 		return nil, fmt.Errorf("invalid header: %w", err) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, obu := range av1header.OBUElements { | 	z := (pkt.Payload[0] & 0b10000000) != 0 | ||||||
| 		if len(obu) == 0 { | 	y := (pkt.Payload[0] & 0b01000000) != 0 | ||||||
| 			return nil, fmt.Errorf("invalid OBU size") | 	payload := pkt.Payload[1:] | ||||||
|  | 	var obus [][]byte | ||||||
|  |  | ||||||
|  | 	for len(payload) > 0 { | ||||||
|  | 		var size av1.LEB128 | ||||||
|  | 		n, err := size.Unmarshal(payload) | ||||||
|  | 		if err != nil { | ||||||
|  | 			d.resetFragments() | ||||||
|  | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | 		payload = payload[n:] | ||||||
|  |  | ||||||
|  | 		if size == 0 || len(payload) < int(size) { | ||||||
|  | 			return nil, fmt.Errorf("invalid fragmented OBU (invalid size)") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		var obu []byte | ||||||
|  | 		obu, payload = payload[:size], payload[size:] | ||||||
|  | 		obus = append(obus, obu) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if av1header.Z { | 	// first OBU is continuation of previous one | ||||||
|  | 	if z { | ||||||
| 		if d.fragmentsSize == 0 { | 		if d.fragmentsSize == 0 { | ||||||
| 			if !d.firstPacketReceived { | 			if !d.firstPacketReceived { | ||||||
| 				return nil, ErrNonStartingPacketAndNoPrevious | 				return nil, ErrNonStartingPacketAndNoPrevious | ||||||
| @@ -83,12 +97,14 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { | |||||||
| 			return nil, fmt.Errorf("received a subsequent fragment without previous fragments") | 			return nil, fmt.Errorf("received a subsequent fragment without previous fragments") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		d.firstPacketReceived = true | ||||||
|  |  | ||||||
| 		if pkt.SequenceNumber != d.fragmentNextSeqNum { | 		if pkt.SequenceNumber != d.fragmentNextSeqNum { | ||||||
| 			d.resetFragments() | 			d.resetFragments() | ||||||
| 			return nil, fmt.Errorf("discarding frame since a RTP packet is missing") | 			return nil, fmt.Errorf("discarding frame since a RTP packet is missing") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		d.fragmentsSize += len(av1header.OBUElements[0]) | 		d.fragmentsSize += len(obus[0]) | ||||||
|  |  | ||||||
| 		if d.fragmentsSize > av1.MaxTemporalUnitSize { | 		if d.fragmentsSize > av1.MaxTemporalUnitSize { | ||||||
| 			errSize := d.fragmentsSize | 			errSize := d.fragmentsSize | ||||||
| @@ -97,38 +113,31 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { | |||||||
| 				errSize, av1.MaxTemporalUnitSize) | 				errSize, av1.MaxTemporalUnitSize) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		d.fragments = append(d.fragments, av1header.OBUElements[0]) | 		d.fragments = append(d.fragments, obus[0]) | ||||||
| 		av1header.OBUElements = av1header.OBUElements[1:] |  | ||||||
| 		d.fragmentNextSeqNum++ | 		d.fragmentNextSeqNum++ | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	d.firstPacketReceived = true | 		if len(obus) == 1 && y { | ||||||
|  | 			return nil, ErrMorePacketsNeeded | ||||||
| 	var obus [][]byte |  | ||||||
|  |  | ||||||
| 	if len(av1header.OBUElements) > 0 { |  | ||||||
| 		if d.fragmentsSize != 0 { |  | ||||||
| 			obus = append(obus, joinFragments(d.fragments, d.fragmentsSize)) |  | ||||||
| 			d.resetFragments() |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if av1header.Y { | 		obus[0] = joinFragments(d.fragments, d.fragmentsSize) | ||||||
| 			elementCount := len(av1header.OBUElements) |  | ||||||
|  |  | ||||||
| 			d.fragmentsSize = len(av1header.OBUElements[elementCount-1]) |  | ||||||
| 			d.fragments = append(d.fragments, av1header.OBUElements[elementCount-1]) |  | ||||||
| 			av1header.OBUElements = av1header.OBUElements[:elementCount-1] |  | ||||||
| 			d.fragmentNextSeqNum = pkt.SequenceNumber + 1 |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		obus = append(obus, av1header.OBUElements...) |  | ||||||
| 	} else if !av1header.Y { |  | ||||||
| 		obus = append(obus, joinFragments(d.fragments, d.fragmentsSize)) |  | ||||||
| 		d.resetFragments() | 		d.resetFragments() | ||||||
|  | 	} else { | ||||||
|  | 		d.firstPacketReceived = true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if len(obus) == 0 { | 	// last OBU will continue in next packet | ||||||
| 		return nil, ErrMorePacketsNeeded | 	if y { | ||||||
|  | 		var obu []byte | ||||||
|  | 		obu, obus = obus[len(obus)-1], obus[:len(obus)-1] | ||||||
|  |  | ||||||
|  | 		d.fragmentsSize = len(obu) | ||||||
|  | 		d.fragments = append(d.fragments, obu) | ||||||
|  | 		d.fragmentNextSeqNum = pkt.SequenceNumber + 1 | ||||||
|  |  | ||||||
|  | 		if len(obus) == 0 { | ||||||
|  | 			return nil, ErrMorePacketsNeeded | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return obus, nil | 	return obus, nil | ||||||
|   | |||||||
| @@ -1,5 +0,0 @@ | |||||||
| go test fuzz v1 |  | ||||||
| []byte("\x180") |  | ||||||
| bool(false) |  | ||||||
| []byte("\xd00") |  | ||||||
| bool(false) |  | ||||||
| @@ -2,4 +2,4 @@ go test fuzz v1 | |||||||
| []byte("0\x00") | []byte("0\x00") | ||||||
| bool(true) | bool(true) | ||||||
| []byte("0") | []byte("0") | ||||||
| bool(true) | bool(false) | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| go test fuzz v1 | go test fuzz v1 | ||||||
| []byte("\xd00") | []byte("\xb0\x010") | ||||||
| bool(false) | bool(false) | ||||||
| []byte("0") | []byte("0") | ||||||
| bool(false) | bool(false) | ||||||
							
								
								
									
										5
									
								
								pkg/format/rtpav1/testdata/fuzz/FuzzDecoder/b24758197029d487
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								pkg/format/rtpav1/testdata/fuzz/FuzzDecoder/b24758197029d487
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | go test fuzz v1 | ||||||
|  | []byte("0\xbb") | ||||||
|  | bool(true) | ||||||
|  | []byte("0") | ||||||
|  | bool(false) | ||||||
							
								
								
									
										5
									
								
								pkg/format/rtpav1/testdata/fuzz/FuzzDecoder/bf22798e3d9a8f55
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								pkg/format/rtpav1/testdata/fuzz/FuzzDecoder/bf22798e3d9a8f55
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | go test fuzz v1 | ||||||
|  | []byte("00") | ||||||
|  | bool(true) | ||||||
|  | []byte("0") | ||||||
|  | bool(true) | ||||||
		Reference in New Issue
	
	Block a user
	 Alessandro Ros
					Alessandro Ros