diff --git a/pkg/media/oggreader/oggreader.go b/pkg/media/oggreader/oggreader.go index 67451f41..e3cce6a5 100644 --- a/pkg/media/oggreader/oggreader.go +++ b/pkg/media/oggreader/oggreader.go @@ -15,7 +15,6 @@ const ( pageHeaderTypeBeginningOfStream = 0x02 pageHeaderSignature = "OggS" - idPageSignature = "OpusHead" idPageBasePayloadLength = 19 pageHeaderLen = 27 ) @@ -39,8 +38,12 @@ type OggReader struct { doChecksum bool } -// OggHeader is the metadata from the first two pages -// in the file (ID and Comment) +// OggHeader contains Opus codec metadata parsed from an Opus ID page. +// This header is extracted from an Ogg page payload that starts with the OpusHead +// signature (the first page of an Opus stream in an Ogg container). +// +// Use OggPageHeader.OpusPacketType() to classify a page payload as OpusHead, +// and OggPageHeader.ParseOpusHeader() to parse the OpusHead payload. // // https://tools.ietf.org/html/rfc7845.html#section-3 type OggHeader struct { @@ -57,6 +60,16 @@ type OggHeader struct { ChannelMapping string } +// ParseOpusHead parses an Opus head from the page payload. +func ParseOpusHead(payload []byte) (*OggHeader, error) { + header := parseBasicHeaderFields(payload) + if err := parseChannelMapping(header, payload); err != nil { + return nil, err + } + + return header, nil +} + // OggPageHeader is the metadata for a Page // Pages are the fundamental unit of multiplexing in an Ogg stream // @@ -67,17 +80,78 @@ type OggPageHeader struct { sig [4]byte version uint8 headerType uint8 - serial uint32 + Serial uint32 index uint32 segmentsCount uint8 } +type HeaderType string + +const ( + headerUnknown HeaderType = "" + HeaderOpusID HeaderType = "OpusHead" +) + +func opusPayloadSignature(payload []byte) (HeaderType, bool) { + if len(payload) < 8 { + return headerUnknown, false + } + + sig := HeaderType(payload[:8]) + if sig == HeaderOpusID { + return sig, true + } + + return headerUnknown, false +} + +// HeaderType classifies the page. +func (p *OggPageHeader) HeaderType(payload []byte) (HeaderType, bool) { + sig, ok := opusPayloadSignature(payload) + + if !ok || (sig == HeaderOpusID && p.headerType != pageHeaderTypeBeginningOfStream) { + return headerUnknown, false + } + + return sig, true +} + +type Option func(*OggReader) error + // NewWith returns a new Ogg reader and Ogg header // with an io.Reader input. +// Deprecated: Use NewWithOptions instead. func NewWith(in io.Reader) (*OggReader, *OggHeader, error) { return newWith(in /* doChecksum */, true) } +// NewWithOptions returns a new Ogg reader. +func NewWithOptions(in io.Reader, options ...Option) (*OggReader, error) { + reader := &OggReader{ + stream: in, + checksumTable: generateChecksumTable(), + doChecksum: true, + } + + for _, option := range options { + if err := option(reader); err != nil { + return nil, err + } + } + + return reader, nil +} + +// WithDoChecksum is an option to set the doChecksum flag +// Default is true. +func WithDoChecksum(doChecksum bool) Option { + return func(o *OggReader) error { + o.doChecksum = doChecksum + + return nil + } +} + func newWith(in io.Reader, doChecksum bool) (*OggReader, *OggHeader, error) { if in == nil { return nil, nil, errNilStream @@ -89,7 +163,7 @@ func newWith(in io.Reader, doChecksum bool) (*OggReader, *OggHeader, error) { doChecksum: doChecksum, } - header, err := reader.readHeaders() + header, err := reader.readOpusHeader() if err != nil { return nil, nil, err } @@ -97,13 +171,13 @@ func newWith(in io.Reader, doChecksum bool) (*OggReader, *OggHeader, error) { return reader, header, nil } -func (o *OggReader) readHeaders() (*OggHeader, error) { +func (o *OggReader) readOpusHeader() (*OggHeader, error) { payload, pageHeader, err := o.ParseNextPage() if err != nil { return nil, err } - if err := validatePageHeader(pageHeader, payload); err != nil { + if err := validateOpusPageHeader(pageHeader, payload); err != nil { return nil, err } @@ -115,7 +189,7 @@ func (o *OggReader) readHeaders() (*OggHeader, error) { return header, nil } -func validatePageHeader(pageHeader *OggPageHeader, payload []byte) error { +func validateOpusPageHeader(pageHeader *OggPageHeader, payload []byte) error { if string(pageHeader.sig[:]) != pageHeaderSignature { return errBadIDPageSignature } @@ -128,8 +202,8 @@ func validatePageHeader(pageHeader *OggPageHeader, payload []byte) error { return errBadIDPageLength } - if s := string(payload[:8]); s != idPageSignature { - return errBadIDPagePayloadSignature + if sig, ok := opusPayloadSignature(payload); !ok || sig != HeaderOpusID { + return fmt.Errorf("%w: expected OpusHead, got %s", errBadIDPagePayloadSignature, sig) } return nil @@ -203,7 +277,7 @@ func (o *OggReader) ParseNextPage() ([]byte, *OggPageHeader, error) { //nolint:c pageHeader.version = header[4] pageHeader.headerType = header[5] pageHeader.GranulePosition = binary.LittleEndian.Uint64(header[6 : 6+8]) - pageHeader.serial = binary.LittleEndian.Uint32(header[14 : 14+4]) + pageHeader.Serial = binary.LittleEndian.Uint32(header[14 : 14+4]) pageHeader.index = binary.LittleEndian.Uint32(header[18 : 18+4]) pageHeader.segmentsCount = header[26] diff --git a/pkg/media/oggreader/oggreader_test.go b/pkg/media/oggreader/oggreader_test.go index 6d876346..52010f43 100644 --- a/pkg/media/oggreader/oggreader_test.go +++ b/pkg/media/oggreader/oggreader_test.go @@ -5,6 +5,8 @@ package oggreader import ( "bytes" + "encoding/binary" + "errors" "io" "testing" @@ -152,7 +154,7 @@ func TestOggReader_ParseErrors(t *testing.T) { ogg[0] = 0 _, _, err := newWith(bytes.NewReader(ogg), false) - assert.Equal(t, err, errBadIDPageSignature) + assert.ErrorIs(t, err, errBadIDPageSignature) }) t.Run("Invalid ID Page Header Type", func(t *testing.T) { @@ -160,7 +162,7 @@ func TestOggReader_ParseErrors(t *testing.T) { ogg[5] = 0 _, _, err := newWith(bytes.NewReader(ogg), false) - assert.Equal(t, err, errBadIDPageType) + assert.ErrorIs(t, err, errBadIDPageType) }) t.Run("Invalid ID Page Payload Length", func(t *testing.T) { @@ -168,7 +170,7 @@ func TestOggReader_ParseErrors(t *testing.T) { ogg[27] = 0 _, _, err := newWith(bytes.NewReader(ogg), false) - assert.Equal(t, err, errBadIDPageLength) + assert.ErrorIs(t, err, errBadIDPageLength) }) t.Run("Invalid ID Page Payload Length", func(t *testing.T) { @@ -176,7 +178,7 @@ func TestOggReader_ParseErrors(t *testing.T) { ogg[35] = 0 _, _, err := newWith(bytes.NewReader(ogg), false) - assert.Equal(t, err, errBadIDPagePayloadSignature) + assert.ErrorIs(t, err, errBadIDPagePayloadSignature) }) t.Run("Invalid Page Checksum", func(t *testing.T) { @@ -184,17 +186,17 @@ func TestOggReader_ParseErrors(t *testing.T) { ogg[22] = 0 _, _, err := NewWith(bytes.NewReader(ogg)) - assert.Equal(t, err, errChecksumMismatch) + assert.ErrorIs(t, err, errChecksumMismatch) }) t.Run("Invalid Multichannel ID Page Payload Length", func(t *testing.T) { _, _, err := newWith(bytes.NewReader(buildSurroundOggContainerShort()), false) - assert.Equal(t, err, errBadIDPageLength) + assert.ErrorIs(t, err, errBadIDPageLength) }) t.Run("Unsupported Channel Mapping Family", func(t *testing.T) { _, _, err := newWith(bytes.NewReader(buildUnknownMappingFamilyContainer(4, 2)), false) - assert.Equal(t, err, errUnsupportedChannelMappingFamily) + assert.ErrorIs(t, err, errUnsupportedChannelMappingFamily) }) } @@ -221,11 +223,21 @@ func TestOggReader_ChannelMappingFamily1(t *testing.T) { for _, tc := range cases { tc := tc t.Run(tc.name, func(t *testing.T) { - reader, header, err := newWith(bytes.NewReader( - buildChannelMappingFamilyContainer(1, tc.channels, tc.streams, tc.coupled, tc.channelMap), - ), false) + reader, err := NewWithOptions( + bytes.NewReader(buildChannelMappingFamilyContainer(1, tc.channels, tc.streams, tc.coupled, tc.channelMap)), + WithDoChecksum(false), + ) assert.NoError(t, err) assert.NotNil(t, reader) + + payload, pageHeader, err := reader.ParseNextPage() + assert.NoError(t, err) + sig, ok := pageHeader.HeaderType(payload) + assert.True(t, ok) + assert.Equal(t, HeaderOpusID, sig) + + header, err := ParseOpusHead(payload) + assert.NoError(t, err) assert.NotNil(t, header) assert.EqualValues(t, 1, header.Version) @@ -257,11 +269,21 @@ func TestOggReader_KnownChannelMappingFamilies(t *testing.T) { for _, tc := range cases { tc := tc t.Run(tc.name, func(t *testing.T) { - reader, header, err := newWith(bytes.NewReader( - buildChannelMappingFamilyContainer(tc.mappingFamily, tc.channels, tc.streams, tc.coupled, tc.channelMap), - ), false) + container := buildChannelMappingFamilyContainer( + tc.mappingFamily, tc.channels, tc.streams, tc.coupled, tc.channelMap, + ) + reader, err := NewWithOptions(bytes.NewReader(container), WithDoChecksum(false)) assert.NoError(t, err) assert.NotNil(t, reader) + + payload, pageHeader, err := reader.ParseNextPage() + assert.NoError(t, err) + sig, ok := pageHeader.HeaderType(payload) + assert.True(t, ok) + assert.Equal(t, HeaderOpusID, sig) + + header, err := ParseOpusHead(payload) + assert.NoError(t, err) assert.NotNil(t, header) assert.EqualValues(t, tc.mappingFamily, header.ChannelMap) @@ -299,11 +321,21 @@ func TestOggReader_ParseExtraFieldsForNonZeroMappingFamily(t *testing.T) { for _, tc := range cases { tc := tc t.Run(tc.name, func(t *testing.T) { - reader, header, err := newWith(bytes.NewReader( - buildChannelMappingFamilyContainer(tc.mappingFamily, tc.channels, tc.streams, tc.coupled, tc.channelMap), - ), false) + container := buildChannelMappingFamilyContainer( + tc.mappingFamily, tc.channels, tc.streams, tc.coupled, tc.channelMap, + ) + reader, err := NewWithOptions(bytes.NewReader(container), WithDoChecksum(false)) assert.NoError(t, err) assert.NotNil(t, reader) + + payload, pageHeader, err := reader.ParseNextPage() + assert.NoError(t, err) + sig, ok := pageHeader.HeaderType(payload) + assert.True(t, ok) + assert.Equal(t, HeaderOpusID, sig) + + header, err := ParseOpusHead(payload) + assert.NoError(t, err) assert.NotNil(t, header) assert.EqualValues(t, tc.mappingFamily, header.ChannelMap) @@ -314,3 +346,194 @@ func TestOggReader_ParseExtraFieldsForNonZeroMappingFamily(t *testing.T) { }) } } + +func TestOggReader_NewWithOptions(t *testing.T) { + t.Run("With checksum enabled (default)", func(t *testing.T) { + reader, err := NewWithOptions(bytes.NewReader(buildOggContainer())) + assert.NoError(t, err) + assert.NotNil(t, reader) + assert.True(t, reader.doChecksum) + + payload, pageHeader, err := reader.ParseNextPage() + assert.NoError(t, err) + assert.NotNil(t, payload) + assert.NotNil(t, pageHeader) + assert.Equal(t, string(HeaderOpusID), string(payload[:8])) + }) + + t.Run("With checksum enabled explicitly", func(t *testing.T) { + reader, err := NewWithOptions(bytes.NewReader(buildOggContainer()), WithDoChecksum(true)) + assert.NoError(t, err) + assert.NotNil(t, reader) + assert.True(t, reader.doChecksum) + + ogg := buildOggContainer() + ogg[22] = 0 + reader2, err := NewWithOptions(bytes.NewReader(ogg), WithDoChecksum(true)) + assert.NoError(t, err) + assert.NotNil(t, reader2) + + _, _, err = reader2.ParseNextPage() + assert.Equal(t, errChecksumMismatch, err) + }) + + t.Run("With checksum disabled", func(t *testing.T) { + reader, err := NewWithOptions(bytes.NewReader(buildOggContainer()), WithDoChecksum(false)) + assert.NoError(t, err) + assert.NotNil(t, reader) + assert.False(t, reader.doChecksum) + + ogg := buildOggContainer() + ogg[22] = 0 + reader2, err := NewWithOptions(bytes.NewReader(ogg), WithDoChecksum(false)) + assert.NoError(t, err) + assert.NotNil(t, reader2) + + payload, pageHeader, err := reader2.ParseNextPage() + assert.NoError(t, err) + assert.NotNil(t, payload) + assert.NotNil(t, pageHeader) + }) +} + +// buildMultiTrackOggContainer generates a minimal two-track Ogg file +// with two Opus ID header pages (one for each track). +func buildMultiTrackOggContainer( + firstSerial, secondSerial uint32, + channels uint8, + sampleRate uint32, + preskip uint16, + version uint8, + channelMap uint8, + outputGain uint16, +) []byte { + firstSerialBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(firstSerialBytes, firstSerial) + secondSerialBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(secondSerialBytes, secondSerial) + + preskipBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(preskipBytes, preskip) + + sampleRateBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(sampleRateBytes, sampleRate) + + outputGainBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(outputGainBytes, outputGain) + + firstPageHeader := []byte{ + 0x4f, 0x67, 0x67, 0x53, // "OggS" + 0x00, // version + 0x02, // header type (beginning of stream) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // granule position + firstSerialBytes[0], firstSerialBytes[1], firstSerialBytes[2], firstSerialBytes[3], // serial number + 0x00, 0x00, 0x00, 0x00, // page sequence number + 0xd7, 0xb7, 0x51, 0x4a, // checksum + 0x01, // page segments + 0x13, // segment size (19 bytes) + } + + firstPayload := []byte{ + 0x4f, 0x70, 0x75, 0x73, 0x48, 0x65, 0x61, 0x64, // "OpusHead" + version, // version + channels, // channels + preskipBytes[0], preskipBytes[1], // preskip + sampleRateBytes[0], sampleRateBytes[1], sampleRateBytes[2], sampleRateBytes[3], // sample rate + outputGainBytes[0], outputGainBytes[1], // output gain + channelMap, // channel mapping family + } + + // Second track: Opus ID page + // Ogg page header (27 bytes) + secondPageHeader := []byte{ + 0x4f, 0x67, 0x67, 0x53, // "OggS" + 0x00, // version + 0x02, // header type (beginning of stream) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // granule position + secondSerialBytes[0], secondSerialBytes[1], secondSerialBytes[2], secondSerialBytes[3], // serial number + 0x00, 0x00, 0x00, 0x00, // page sequence number + 0xaf, 0xaa, 0x01, 0x8b, // checksum + 0x01, // page segments + 0x13, // segment size (19 bytes) + } + + // Second track: OpusHead payload (19 bytes) + secondPayload := []byte{ + 0x4f, 0x70, 0x75, 0x73, 0x48, 0x65, 0x61, 0x64, // "OpusHead" + version, // version + channels, // channels + preskipBytes[0], preskipBytes[1], // preskip + sampleRateBytes[0], sampleRateBytes[1], sampleRateBytes[2], sampleRateBytes[3], // sample rate + outputGainBytes[0], outputGainBytes[1], // output gain + channelMap, // channel mapping family + } + + container := make([]byte, 0, len(firstPageHeader)+len(firstPayload)+len(secondPageHeader)+len(secondPayload)) + container = append(container, firstPageHeader...) + container = append(container, firstPayload...) + container = append(container, secondPageHeader...) + container = append(container, secondPayload...) + + return container +} + +func TestOggReader_MultiTrackFile(t *testing.T) { + firstSerial := uint32(0xd03ed35d) + secondSerial := uint32(0xfa6e13f0) + channels := uint8(1) + sampleRate := uint32(48000) + preskip := uint16(0x0138) + version := uint8(1) + channelMap := uint8(0) + outputGain := uint16(0) + + data := buildMultiTrackOggContainer( + firstSerial, secondSerial, + channels, sampleRate, preskip, + version, channelMap, outputGain, + ) + + reader, err := NewWithOptions(bytes.NewReader(data), WithDoChecksum(false)) + assert.NoError(t, err) + assert.NotNil(t, reader) + + var headers []*OggHeader + var pageHeaders []*OggPageHeader + + for { + payload, pageHeader, err := reader.ParseNextPage() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + assert.NoError(t, err, "Error reading page") + + break + } + + sig, ok := pageHeader.HeaderType(payload) + assert.True(t, ok) + assert.Equal(t, HeaderOpusID, sig) + + header, err2 := ParseOpusHead(payload) + assert.NoError(t, err2) + assert.NotNil(t, header) + headers = append(headers, header) + pageHeaders = append(pageHeaders, pageHeader) + + t.Logf("Found header %d: Channels=%d, SampleRate=%d, Serial=%d", + len(headers), header.Channels, header.SampleRate, pageHeader.Serial) + } + + assert.Equal(t, 2, len(headers), "Should find exactly 2 headers") + assert.Equal(t, channels, headers[0].Channels, "First track should be mono") + assert.Equal(t, channels, headers[1].Channels, "Second track should be mono") + assert.Equal(t, sampleRate, headers[0].SampleRate, "First track should be 48kHz") + assert.Equal(t, sampleRate, headers[1].SampleRate, "Second track should be 48kHz") + + assert.Equal(t, firstSerial, pageHeaders[0].Serial, "First track serial should match") + assert.Equal(t, secondSerial, pageHeaders[1].Serial, "Second track serial should match") + assert.NotEqual(t, pageHeaders[0].Serial, pageHeaders[1].Serial, "Serial numbers should be different") + + t.Logf("Multi-track file: found %d headers", len(headers)) +}