prevent decoding formats with zero clock rate or channels (#770)

This commit is contained in:
Alessandro Ros
2025-05-01 17:06:55 +02:00
committed by GitHub
parent 18307140ec
commit bfacd10d35
16 changed files with 76 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
package format package format
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
@@ -23,15 +24,15 @@ func (f *AC3) unmarshal(ctx *unmarshalContext) error {
tmp := strings.SplitN(ctx.clock, "/", 2) tmp := strings.SplitN(ctx.clock, "/", 2)
tmp1, err := strconv.ParseUint(tmp[0], 10, 31) tmp1, err := strconv.ParseUint(tmp[0], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid sample rate: '%s'", tmp[0])
} }
f.SampleRate = int(tmp1) f.SampleRate = int(tmp1)
if len(tmp) >= 2 { if len(tmp) >= 2 {
tmp1, err := strconv.ParseUint(tmp[1], 10, 31) tmp1, err := strconv.ParseUint(tmp[1], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid channel count: '%s'", tmp[1])
} }
f.ChannelCount = int(tmp1) f.ChannelCount = int(tmp1)
} else { } else {

View File

@@ -1267,19 +1267,45 @@ func FuzzUnmarshal(f *testing.F) {
f.Add(ca.in) f.Add(ca.in)
} }
f.Fuzz(func(_ *testing.T, in string) { f.Fuzz(func(t *testing.T, in string) {
var desc sdp.SessionDescription var desc sdp.SessionDescription
err := desc.Unmarshal([]byte(in)) err := desc.Unmarshal([]byte(in))
if err != nil || len(desc.MediaDescriptions) == 0 || len(desc.MediaDescriptions[0].MediaName.Formats) == 0 {
return
}
if err == nil && len(desc.MediaDescriptions) >= 1 && len(desc.MediaDescriptions[0].MediaName.Formats) >= 1 { f, err := Unmarshal(desc.MediaDescriptions[0], desc.MediaDescriptions[0].MediaName.Formats[0])
f, err := Unmarshal(desc.MediaDescriptions[0], desc.MediaDescriptions[0].MediaName.Formats[0]) if err != nil {
if err == nil { return
f.Codec() }
f.ClockRate()
f.PayloadType() // only Generic can return zero ClockRate
f.RTPMap() if _, ok := f.(*Generic); !ok {
f.FMTP() require.NotZero(t, f.ClockRate())
} } else {
f.ClockRate()
}
f.Codec()
f.PayloadType()
f.RTPMap()
f.FMTP()
switch f := f.(type) {
case *AC3:
require.NotZero(t, f.ChannelCount)
case *G711:
require.NotZero(t, f.ChannelCount)
case *LPCM:
require.NotZero(t, f.ChannelCount)
case *Opus:
require.NotZero(t, f.ChannelCount)
case *Vorbis:
require.NotZero(t, f.ChannelCount)
} }
}) })
} }

View File

@@ -1,6 +1,7 @@
package format package format
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
@@ -40,15 +41,15 @@ func (f *G711) unmarshal(ctx *unmarshalContext) error {
tmp := strings.SplitN(ctx.clock, "/", 2) tmp := strings.SplitN(ctx.clock, "/", 2)
tmp1, err := strconv.ParseUint(tmp[0], 10, 31) tmp1, err := strconv.ParseUint(tmp[0], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid sample rate: '%s'", tmp[0])
} }
f.SampleRate = int(tmp1) f.SampleRate = int(tmp1)
if len(tmp) >= 2 { if len(tmp) >= 2 {
tmp1, err := strconv.ParseUint(tmp[1], 10, 31) tmp1, err := strconv.ParseUint(tmp[1], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid channel count: '%s'", tmp[1])
} }
f.ChannelCount = int(tmp1) f.ChannelCount = int(tmp1)
} else { } else {

View File

@@ -1,6 +1,7 @@
package format package format
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
@@ -50,15 +51,15 @@ func (f *LPCM) unmarshal(ctx *unmarshalContext) error {
tmp := strings.SplitN(ctx.clock, "/", 2) tmp := strings.SplitN(ctx.clock, "/", 2)
tmp1, err := strconv.ParseUint(tmp[0], 10, 31) tmp1, err := strconv.ParseUint(tmp[0], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid sample rate: '%s'", tmp[0])
} }
f.SampleRate = int(tmp1) f.SampleRate = int(tmp1)
if len(tmp) >= 2 { if len(tmp) >= 2 {
tmp1, err := strconv.ParseUint(tmp[1], 10, 31) tmp1, err := strconv.ParseUint(tmp[1], 10, 31)
if err != nil { if err != nil || tmp1 == 0 {
return err return fmt.Errorf("invalid channel count: '%s'", tmp[1])
} }
f.ChannelCount = int(tmp1) f.ChannelCount = int(tmp1)
} else { } else {

View File

@@ -65,10 +65,9 @@ func (f *Opus) unmarshal(ctx *unmarshalContext) error {
} }
channelCount, err := strconv.ParseUint(tmp[1], 10, 31) channelCount, err := strconv.ParseUint(tmp[1], 10, 31)
if err != nil { if err != nil || channelCount == 0 {
return fmt.Errorf("invalid channel count: '%s'", tmp[1]) return fmt.Errorf("invalid channel count: '%s'", tmp[1])
} }
f.ChannelCount = int(channelCount) f.ChannelCount = int(channelCount)
} }

View File

@@ -19,8 +19,8 @@ func (f *Speex) unmarshal(ctx *unmarshalContext) error {
f.PayloadTyp = ctx.payloadType f.PayloadTyp = ctx.payloadType
sampleRate, err := strconv.ParseUint(ctx.clock, 10, 31) sampleRate, err := strconv.ParseUint(ctx.clock, 10, 31)
if err != nil { if err != nil || sampleRate == 0 {
return err return fmt.Errorf("invalid sample rate: '%s'", ctx.clock)
} }
f.SampleRate = int(sampleRate) f.SampleRate = int(sampleRate)

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 VORBIS/0/0\na=fmtp:96 ConfigurAtion=")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 97\na=rtpmap:97 L8/1/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 PCMA/1/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 97\na=rtpmap:97 AC3/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 AC3/1/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 multiopus/48000/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 VORBIS/1/0\na=fmtp:96 ConfigurAtion=")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 96\na=rtpmap:96 speeX/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("m=audio 0 AVP 97\na=rtpmap:97 L16/0")

View File

@@ -27,14 +27,14 @@ func (f *Vorbis) unmarshal(ctx *unmarshalContext) error {
} }
sampleRate, err := strconv.ParseUint(tmp[0], 10, 31) sampleRate, err := strconv.ParseUint(tmp[0], 10, 31)
if err != nil { if err != nil || sampleRate == 0 {
return err return fmt.Errorf("invalid sample rate: '%s'", tmp[0])
} }
f.SampleRate = int(sampleRate) f.SampleRate = int(sampleRate)
channelCount, err := strconv.ParseUint(tmp[1], 10, 31) channelCount, err := strconv.ParseUint(tmp[1], 10, 31)
if err != nil { if err != nil || channelCount == 0 {
return err return fmt.Errorf("invalid channel count: '%s'", tmp[1])
} }
f.ChannelCount = int(channelCount) f.ChannelCount = int(channelCount)