diff --git a/client_read_test.go b/client_read_test.go index edf2ee82..7403539d 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -41,10 +41,10 @@ func TestClientReadTracks(t *testing.T) { track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) require.NoError(t, err) - track2, err := NewTrackAAC(96, 2, 44100, 2, nil) + track2, err := NewTrackAAC(96, 2, 44100, 2, nil, 13, 3, 3) require.NoError(t, err) - track3, err := NewTrackAAC(96, 2, 96000, 2, nil) + track3, err := NewTrackAAC(96, 2, 96000, 2, nil, 13, 3, 3) require.NoError(t, err) l, err := net.Listen("tcp", "localhost:8554") diff --git a/examples/client-publish-aac/main.go b/examples/client-publish-aac/main.go index 29828440..707d9e9a 100644 --- a/examples/client-publish-aac/main.go +++ b/examples/client-publish-aac/main.go @@ -34,7 +34,7 @@ func main() { log.Println("stream connected") // create an AAC track - track, err := gortsplib.NewTrackAAC(96, 2, 48000, 2, nil) + track, err := gortsplib.NewTrackAAC(96, 2, 48000, 2, nil, 13, 3, 3) if err != nil { panic(err) } diff --git a/examples/client-read-aac/main.go b/examples/client-read-aac/main.go index f987242d..9e3375e4 100644 --- a/examples/client-read-aac/main.go +++ b/examples/client-read-aac/main.go @@ -37,10 +37,16 @@ func main() { // find the AAC track var clockRate int + var sizeLength int + var indexLength int + var indexDeltaLength int aacTrack := func() int { for i, track := range tracks { - if _, ok := track.(*gortsplib.TrackAAC); ok { + if tt, ok := track.(*gortsplib.TrackAAC); ok { clockRate = track.ClockRate() + sizeLength = tt.SizeLength() + indexLength = tt.IndexLength() + indexDeltaLength = tt.IndexDeltaLength() return i } } @@ -52,7 +58,10 @@ func main() { // setup decoder dec := &rtpaac.Decoder{ - SampleRate: clockRate, + SampleRate: clockRate, + SizeLength: sizeLength, + IndexLength: indexLength, + IndexDeltaLength: indexDeltaLength, } dec.Init() diff --git a/pkg/rtpaac/decoder.go b/pkg/rtpaac/decoder.go index b539c261..5660d4dc 100644 --- a/pkg/rtpaac/decoder.go +++ b/pkg/rtpaac/decoder.go @@ -1,11 +1,13 @@ package rtpaac import ( + "bytes" "encoding/binary" "errors" "fmt" "time" + "github.com/icza/bitio" "github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/rtptimedec" @@ -23,6 +25,13 @@ type Decoder struct { fragmentedMode bool fragmentedParts [][]byte fragmentedSize int + + // The number of bits on which the AU-size field is encoded in the AU-header. + SizeLength int + // The number of bits on which the AU-Index is encoded in the first AU-header. + IndexLength int + // The number of bits on which the AU-Index-delta field is encoded in any non-first AU-header. + IndexDeltaLength int } // Init initializes the decoder @@ -41,12 +50,21 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { } // AU-headers-length - headersLen := binary.BigEndian.Uint16(pkt.Payload) - if (headersLen % 16) != 0 { + headersLen := int(binary.BigEndian.Uint16(pkt.Payload)) + + auHeaderSize := d.SizeLength + d.IndexLength + if auHeaderSize <= 0 { d.fragmentedParts = d.fragmentedParts[:0] d.fragmentedMode = false - return nil, 0, fmt.Errorf("invalid AU-headers-length (%d)", headersLen) + return nil, 0, fmt.Errorf("invalid AU-header-size (%d)", auHeaderSize) } + + if (headersLen % auHeaderSize) != 0 { + d.fragmentedParts = d.fragmentedParts[:0] + d.fragmentedMode = false + return nil, 0, fmt.Errorf("invalid AU-headers-length (%d) with AU-header-size (%d)", headersLen, auHeaderSize) + } + headersLenBytes := (headersLen + 7) / 8 payload := pkt.Payload[2:] if !d.fragmentedMode { @@ -55,23 +73,12 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { // AAC headers are 16 bits, where // * 13 bits are data size // * 3 bits are AU index - headerCount := headersLen / 16 - var dataLens []uint16 - for i := 0; i < int(headerCount); i++ { - if len(payload[i*2:]) < 2 { - return nil, 0, fmt.Errorf("payload is too short") - } - - header := binary.BigEndian.Uint16(payload[i*2:]) - dataLen := header >> 3 - auIndex := header & 0x03 - if auIndex != 0 { - return nil, 0, fmt.Errorf("AU-index field is not zero") - } - - dataLens = append(dataLens, dataLen) + headerCount := headersLen / auHeaderSize + dataLens, err := d.parseAuData(payload, headersLenBytes, headerCount) + if err != nil { + return nil, 0, err } - payload = payload[headerCount*2:] + payload = payload[headersLenBytes:] // AUs aus := make([][]byte, len(dataLens)) @@ -87,20 +94,21 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { return aus, d.timeDecoder.Decode(pkt.Timestamp), nil } - if headersLen != 16 { + if headersLen != auHeaderSize { return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") } // AU-header - header := binary.BigEndian.Uint16(payload) - dataLen := header >> 3 - auIndex := header & 0x03 - if auIndex != 0 { - return nil, 0, fmt.Errorf("AU-index field is not zero") + dataLens, err := d.parseAuData(payload, headersLenBytes, 1) + if err != nil { + return nil, 0, err } - payload = payload[2:] + if len(dataLens) != 1 { + return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") + } + payload = payload[headersLenBytes:] - if len(payload) < int(dataLen) { + if len(payload) < int(dataLens[0]) { return nil, 0, fmt.Errorf("payload is too short") } @@ -112,24 +120,27 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { // we are decoding a fragmented AU - if headersLen != 16 { + if headersLen != auHeaderSize { d.fragmentedParts = d.fragmentedParts[:0] d.fragmentedMode = false return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") } // AU-header - header := binary.BigEndian.Uint16(payload) - dataLen := header >> 3 - auIndex := header & 0x03 - if auIndex != 0 { + dataLens, err := d.parseAuData(payload, headersLenBytes, 1) + if err != nil { d.fragmentedParts = d.fragmentedParts[:0] d.fragmentedMode = false - return nil, 0, fmt.Errorf("AU-index field is not zero") + return nil, 0, err } - payload = payload[2:] + if len(dataLens) != 1 { + d.fragmentedParts = d.fragmentedParts[:0] + d.fragmentedMode = false + return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") + } + payload = payload[headersLenBytes:] - if len(payload) < int(dataLen) { + if len(payload) < int(dataLens[0]) { return nil, 0, fmt.Errorf("payload is too short") } @@ -156,3 +167,47 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { d.fragmentedMode = false return [][]byte{ret}, d.timeDecoder.Decode(pkt.Timestamp), nil } + +func (d *Decoder) parseAuData(payload []byte, + headersLenBytes int, + headerCount int, +) (dataLens []uint64, err error) { + if len(payload) < headersLenBytes { + return nil, fmt.Errorf("payload is too short") + } + + br := bitio.NewReader(bytes.NewBuffer(payload[:headersLenBytes])) + readAUIndex := func(index int) error { + auIndex, err := br.ReadBits(uint8(index)) + if err != nil { + return fmt.Errorf("payload is too short") + } + + if auIndex != 0 { + return fmt.Errorf("AU-index field is not zero") + } + + return nil + } + for i := 0; i < headerCount; i++ { + dataLen, err := br.ReadBits(uint8(d.SizeLength)) + if err != nil { + return nil, fmt.Errorf("payload is too short") + } + switch { + case i == 0 && d.IndexLength > 0: + err := readAUIndex(d.IndexLength) + if err != nil { + return nil, err + } + case d.IndexDeltaLength > 0: + err := readAUIndex(d.IndexDeltaLength) + if err != nil { + return nil, err + } + } + + dataLens = append(dataLens, dataLen) + } + return dataLens, nil +} diff --git a/pkg/rtpaac/rtpaac_test.go b/pkg/rtpaac/rtpaac_test.go index 46212b65..41780955 100644 --- a/pkg/rtpaac/rtpaac_test.go +++ b/pkg/rtpaac/rtpaac_test.go @@ -280,7 +280,10 @@ func TestDecode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { d := &Decoder{ - SampleRate: 48000, + SampleRate: 48000, + IndexLength: 3, + SizeLength: 13, + IndexDeltaLength: 3, } d.Init() @@ -397,7 +400,7 @@ func TestDecodeErrors(t *testing.T) { Payload: []byte{0x00, 0x09}, }, }, - "invalid AU-headers-length (9)", + "invalid AU-headers-length (9) with AU-header-size (16)", }, { "au index not zero", @@ -566,7 +569,9 @@ func TestDecodeErrors(t *testing.T) { } { t.Run(ca.name, func(t *testing.T) { d := &Decoder{ - SampleRate: 48000, + SampleRate: 48000, + IndexLength: 3, + SizeLength: 13, } d.Init() diff --git a/track_aac.go b/track_aac.go index bfb759aa..eff1d059 100644 --- a/track_aac.go +++ b/track_aac.go @@ -20,11 +20,26 @@ type TrackAAC struct { channelCount int aotSpecificConfig []byte mpegConf []byte + + // The number of bits on which the AU-size field is encoded in the AU-header. + sizeLength int + // The number of bits on which the AU-Index is encoded in the first AU-header. + // The default value of zero indicates the absence of the AU-Index field in each first AU-header. + indexLength int + // The number of bits on which the AU-Index-delta field is encoded in any non-first AU-header. + // The default value of zero indicates the absence of the AU-Index-delta field in each non-first AU-header. + indexDeltaLength int } // NewTrackAAC allocates a TrackAAC. -func NewTrackAAC(payloadType uint8, typ int, sampleRate int, - channelCount int, aotSpecificConfig []byte, +func NewTrackAAC(payloadType uint8, + typ int, + sampleRate int, + channelCount int, + aotSpecificConfig []byte, + sizeLength int, + indexLength int, + indexDeltaLength int, ) (*TrackAAC, error) { mpegConf, err := aac.MPEG4AudioConfig{ Type: aac.MPEG4AudioType(typ), @@ -43,6 +58,9 @@ func NewTrackAAC(payloadType uint8, typ int, sampleRate int, channelCount: channelCount, aotSpecificConfig: aotSpecificConfig, mpegConf: mpegConf, + sizeLength: sizeLength, + indexLength: indexLength, + indexDeltaLength: indexDeltaLength, }, nil } @@ -61,6 +79,13 @@ func newTrackAACFromMediaDescription( return nil, fmt.Errorf("invalid fmtp (%v)", v) } + track := &TrackAAC{ + trackBase: trackBase{ + control: control, + }, + payloadType: payloadType, + } + for _, kv := range strings.Split(tmp[1], ";") { kv = strings.Trim(kv, " ") @@ -73,7 +98,8 @@ func newTrackAACFromMediaDescription( return nil, fmt.Errorf("invalid fmtp (%v)", v) } - if tmp[0] == "config" { + switch strings.ToLower(tmp[0]) { + case "config": enc, err := hex.DecodeString(tmp[1]) if err != nil { return nil, fmt.Errorf("invalid AAC config (%v)", tmp[1]) @@ -88,21 +114,37 @@ func newTrackAACFromMediaDescription( // re-encode the conf to normalize it enc, _ = mpegConf.Encode() - return &TrackAAC{ - trackBase: trackBase{ - control: control, - }, - payloadType: payloadType, - typ: int(mpegConf.Type), - sampleRate: mpegConf.SampleRate, - channelCount: mpegConf.ChannelCount, - aotSpecificConfig: mpegConf.AOTSpecificConfig, - mpegConf: enc, - }, nil + track.typ = int(mpegConf.Type) + track.sampleRate = mpegConf.SampleRate + track.channelCount = mpegConf.ChannelCount + track.aotSpecificConfig = mpegConf.AOTSpecificConfig + track.mpegConf = enc + case "sizelength": + val, err := strconv.ParseUint(tmp[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid AAC sizeLength (%v)", tmp[1]) + } + track.sizeLength = int(val) + case "indexlength": + val, err := strconv.ParseUint(tmp[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid AAC indexLength (%v)", tmp[1]) + } + track.indexLength = int(val) + case "indexdeltalength": + val, err := strconv.ParseUint(tmp[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid AAC indexDeltaLength (%v)", tmp[1]) + } + track.indexDeltaLength = int(val) } } - return nil, fmt.Errorf("config is missing (%v)", v) + if len(track.mpegConf) == 0 { + return nil, fmt.Errorf("config is missing (%v)", v) + } + + return track, nil } // ClockRate returns the track clock rate. @@ -125,6 +167,21 @@ func (t *TrackAAC) AOTSpecificConfig() []byte { return t.aotSpecificConfig } +// SizeLength returns the track sizeLength. +func (t *TrackAAC) SizeLength() int { + return t.sizeLength +} + +// IndexLength returns the track indexLength. +func (t *TrackAAC) IndexLength() int { + return t.indexLength +} + +// IndexDeltaLength returns the track indexDeltaLength. +func (t *TrackAAC) IndexDeltaLength() int { + return t.indexDeltaLength +} + func (t *TrackAAC) clone() Track { return &TrackAAC{ trackBase: t.trackBase, @@ -134,6 +191,9 @@ func (t *TrackAAC) clone() Track { channelCount: t.channelCount, aotSpecificConfig: t.aotSpecificConfig, mpegConf: t.mpegConf, + sizeLength: t.sizeLength, + indexLength: t.indexLength, + indexDeltaLength: t.indexDeltaLength, } } @@ -157,9 +217,24 @@ func (t *TrackAAC) MediaDescription() *psdp.MediaDescription { Key: "fmtp", Value: typ + " profile-level-id=1; " + "mode=AAC-hbr; " + - "sizelength=13; " + - "indexlength=3; " + - "indexdeltalength=3; " + + func() string { + if t.sizeLength > 0 { + return fmt.Sprintf("sizelength=%d; ", t.sizeLength) + } + return "" + }() + + func() string { + if t.indexLength > 0 { + return fmt.Sprintf("indexlength=%d; ", t.indexLength) + } + return "" + }() + + func() string { + if t.indexDeltaLength > 0 { + return fmt.Sprintf("indexdeltalength=%d; ", t.indexDeltaLength) + } + return "" + }() + "config=" + hex.EncodeToString(t.mpegConf), }, { diff --git a/track_aac_test.go b/track_aac_test.go index d6ed670c..408c8e41 100644 --- a/track_aac_test.go +++ b/track_aac_test.go @@ -8,22 +8,25 @@ import ( ) func TestTrackAACNew(t *testing.T) { - track, err := NewTrackAAC(96, 2, 48000, 4, []byte{0x01, 0x02}) + track, err := NewTrackAAC(96, 2, 48000, 4, []byte{0x01, 0x02}, 13, 3, 3) require.NoError(t, err) require.Equal(t, "", track.GetControl()) require.Equal(t, 2, track.Type()) require.Equal(t, 48000, track.ClockRate()) require.Equal(t, 4, track.ChannelCount()) require.Equal(t, []byte{0x01, 0x02}, track.AOTSpecificConfig()) + require.Equal(t, 13, track.SizeLength()) + require.Equal(t, 3, track.IndexLength()) + require.Equal(t, 3, track.IndexDeltaLength()) } func TestTrackAACNewErrors(t *testing.T) { - _, err := NewTrackAAC(96, 2, 48000, 10, nil) + _, err := NewTrackAAC(96, 2, 48000, 10, nil, 13, 3, 3) require.EqualError(t, err, "invalid configuration: invalid channel count (10)") } func TestTrackAACClone(t *testing.T) { - track, err := NewTrackAAC(96, 2, 48000, 2, []byte{0x01, 0x02}) + track, err := NewTrackAAC(96, 2, 48000, 2, []byte{0x01, 0x02}, 13, 3, 3) require.NoError(t, err) clone := track.clone() @@ -32,7 +35,7 @@ func TestTrackAACClone(t *testing.T) { } func TestTrackAACMediaDescription(t *testing.T) { - track, err := NewTrackAAC(96, 2, 48000, 2, nil) + track, err := NewTrackAAC(96, 2, 48000, 2, nil, 13, 3, 3) require.NoError(t, err) require.Equal(t, &psdp.MediaDescription{ @@ -57,3 +60,69 @@ func TestTrackAACMediaDescription(t *testing.T) { }, }, track.MediaDescription()) } + +func TestNewTrackAACFromMediaDescription(t *testing.T) { + track, err := newTrackAACFromMediaDescription("", 2, &psdp.MediaDescription{ + MediaName: psdp.MediaName{ + Media: "audio", + Protos: []string{"RTP", "AVP"}, + Formats: []string{"96"}, + }, + Attributes: []psdp.Attribute{ + { + Key: "rtpmap", + Value: "96 mpeg4-generic/48000/2", + }, + { + Key: "fmtp", + Value: "96 profile-level-id=1; mode=AAC-hbr; sizelength=13; indexlength=3; indexdeltalength=3; config=11900810", + }, + { + Key: "control", + Value: "", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, "", track.GetControl()) + require.Equal(t, 2, track.Type()) + require.Equal(t, 48000, track.ClockRate()) + require.Equal(t, 2, track.ChannelCount()) + require.Equal(t, []byte{0x01, 0x02}, track.AOTSpecificConfig()) + require.Equal(t, 13, track.SizeLength()) + require.Equal(t, 3, track.IndexLength()) + require.Equal(t, 3, track.IndexDeltaLength()) +} + +func TestNewTrackAACFromMediaDescriptionWithoutIndex(t *testing.T) { + track, err := newTrackAACFromMediaDescription("", 2, &psdp.MediaDescription{ + MediaName: psdp.MediaName{ + Media: "audio", + Protos: []string{"RTP", "AVP"}, + Formats: []string{"96"}, + }, + Attributes: []psdp.Attribute{ + { + Key: "rtpmap", + Value: "96 mpeg4-generic/48000/2", + }, + { + Key: "fmtp", + Value: "96 streamtype=3;profile-level-id=14;mode=AAC-hbr;config=1190;sizeLength=13", + }, + { + Key: "control", + Value: "", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, "", track.GetControl()) + require.Equal(t, 2, track.Type()) + require.Equal(t, 48000, track.ClockRate()) + require.Equal(t, 2, track.ChannelCount()) + require.Equal(t, []byte(nil), track.AOTSpecificConfig()) + require.Equal(t, 13, track.SizeLength()) + require.Equal(t, 0, track.IndexLength()) + require.Equal(t, 0, track.IndexDeltaLength()) +} diff --git a/track_test.go b/track_test.go index bfb714a3..e72d66f0 100644 --- a/track_test.go +++ b/track_test.go @@ -62,11 +62,14 @@ func TestTrackNewFromMediaDescription(t *testing.T) { }, }, &TrackAAC{ - payloadType: 96, - typ: 2, - sampleRate: 48000, - channelCount: 2, - mpegConf: []byte{0x11, 0x90}, + payloadType: 96, + typ: 2, + sampleRate: 48000, + channelCount: 2, + mpegConf: []byte{0x11, 0x90}, + sizeLength: 13, + indexLength: 3, + indexDeltaLength: 3, }, }, { @@ -89,11 +92,14 @@ func TestTrackNewFromMediaDescription(t *testing.T) { }, }, &TrackAAC{ - payloadType: 96, - typ: 2, - sampleRate: 48000, - channelCount: 2, - mpegConf: []byte{0x11, 0x90}, + payloadType: 96, + typ: 2, + sampleRate: 48000, + channelCount: 2, + mpegConf: []byte{0x11, 0x90}, + sizeLength: 13, + indexLength: 3, + indexDeltaLength: 3, }, }, { @@ -116,11 +122,14 @@ func TestTrackNewFromMediaDescription(t *testing.T) { }, }, &TrackAAC{ - payloadType: 96, - typ: 2, - sampleRate: 48000, - channelCount: 2, - mpegConf: []byte{0x11, 0x90}, + payloadType: 96, + typ: 2, + sampleRate: 48000, + channelCount: 2, + mpegConf: []byte{0x11, 0x90}, + indexLength: 3, + sizeLength: 13, + indexDeltaLength: 3, }, }, {