diff --git a/clientconf_test.go b/clientconf_test.go index fd8e3d8c..a9a6a263 100644 --- a/clientconf_test.go +++ b/clientconf_test.go @@ -449,8 +449,8 @@ func TestClientDialPublishSerial(t *testing.T) { require.NoError(t, err) defer cnt2.close() - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) //nolint:govet require.NoError(t, err) track, err := NewTrackH264(96, sps, pps) @@ -529,8 +529,8 @@ func TestClientDialPublishParallel(t *testing.T) { require.NoError(t, err) defer cnt2.close() - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) //nolint:govet require.NoError(t, err) track, err := NewTrackH264(96, sps, pps) @@ -623,8 +623,8 @@ func TestClientDialPublishPauseSerial(t *testing.T) { require.NoError(t, err) defer cnt2.close() - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) //nolint:govet require.NoError(t, err) track, err := NewTrackH264(96, sps, pps) @@ -695,8 +695,8 @@ func TestClientDialPublishPauseParallel(t *testing.T) { require.NoError(t, err) defer cnt2.close() - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) //nolint:govet require.NoError(t, err) track, err := NewTrackH264(96, sps, pps) diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index 810d98c7..9844f241 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -27,9 +27,9 @@ func main() { "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! 
video/x-h264" + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") - // wait for RTP/H264 frames - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + // get SPS and PPS + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) if err != nil { panic(err) } diff --git a/examples/client-publish-pause/main.go b/examples/client-publish-pause/main.go index ddb6c2a2..46b64bb9 100644 --- a/examples/client-publish-pause/main.go +++ b/examples/client-publish-pause/main.go @@ -28,9 +28,9 @@ func main() { "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") - // wait for RTP/H264 frames - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + // get SPS and PPS + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) if err != nil { panic(err) } diff --git a/examples/client-publish/main.go b/examples/client-publish/main.go index 041eb7a5..11a32cdf 100644 --- a/examples/client-publish/main.go +++ b/examples/client-publish/main.go @@ -25,9 +25,9 @@ func main() { "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") - // wait for RTP/H264 frames - decoder := rtph264.NewDecoderFromPacketConn(pc) - sps, pps, err := decoder.ReadSPSPPS() + // get SPS and PPS + decoder := rtph264.NewDecoder() + sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) if err != nil { panic(err) } diff --git a/pkg/rtph264/decoder.go b/pkg/rtph264/decoder.go index 30dce358..8e2fe0f0 100644 --- a/pkg/rtph264/decoder.go +++ b/pkg/rtph264/decoder.go @@ -1,146 +1,170 @@ -// Package rtph264 contains a RTP/H264 decoder and encoder. 
package rtph264 import ( + "errors" "fmt" "io" "net" + "time" "github.com/pion/rtp" ) -type packetConnReader struct { - inner net.PacketConn +// ErrMorePacketsNeeded is returned by Decoder.Decode when more packets are needed. +var ErrMorePacketsNeeded = errors.New("need more packets") + +// PacketConnReader creates an io.Reader around a net.PacketConn. +type PacketConnReader struct { + net.PacketConn } -func (r packetConnReader) Read(p []byte) (int, error) { - n, _, err := r.inner.ReadFrom(p) +// Read implements io.Reader. +func (r PacketConnReader) Read(p []byte) (int, error) { + n, _, err := r.PacketConn.ReadFrom(p) return n, err } +type decoderState int + +const ( + decoderStateInitial decoderState = iota + decoderStateReadingFragmented +) + // Decoder is a RTP/H264 decoder. type Decoder struct { - r io.Reader - buf []byte + state decoderState + initialTs uint32 + initialTsSet bool + fragmentedBuf []byte } // NewDecoder creates a decoder around a Reader. -func NewDecoder(r io.Reader) *Decoder { - return &Decoder{ - r: r, - buf: make([]byte, 2048), - } +func NewDecoder() *Decoder { + return &Decoder{} } -// NewDecoderFromPacketConn creates a decoder around a net.PacketConn. -func NewDecoderFromPacketConn(pc net.PacketConn) *Decoder { - return NewDecoder(packetConnReader{pc}) -} - -// Read decodes NALUs from RTP/H264 packets. 
-func (d *Decoder) Read() ([][]byte, error) { - n, err := d.r.Read(d.buf) - if err != nil { - return nil, err - } - - pkt := rtp.Packet{} - err = pkt.Unmarshal(d.buf[:n]) - if err != nil { - return nil, err - } - payload := pkt.Payload - - typ := NALUType(payload[0] & 0x1F) - - switch typ { - case NALUTypeNonIDR, NALUTypeDataPartitionA, NALUTypeDataPartitionB, - NALUTypeDataPartitionC, NALUTypeIDR, NALUTypeSei, NALUTypeSPS, - NALUTypePPS, NALUTypeAccessUnitDelimiter, NALUTypeEndOfSequence, - NALUTypeEndOfStream, NALUTypeFillerData, NALUTypeSPSExtension, - NALUTypePrefix, NALUTypeSubsetSPS, NALUTypeReserved16, NALUTypeReserved17, - NALUTypeReserved18, NALUTypeSliceLayerWithoutPartitioning, - NALUTypeSliceExtension, NALUTypeSliceExtensionDepth, NALUTypeReserved22, - NALUTypeReserved23: - return [][]byte{payload}, nil - - case NALUTypeFuA: - return d.readFragmented(payload) - - case NALUTypeStapA, NALUTypeStapB, NALUTypeMtap16, NALUTypeMtap24, NALUTypeFuB: - return nil, fmt.Errorf("NALU type not supported (%d)", typ) - } - - return nil, fmt.Errorf("invalid NALU type (%d)", typ) -} - -func (d *Decoder) readFragmented(payload []byte) ([][]byte, error) { - // A NALU can have any size; we can't preallocate it - var ret []byte - - // process first nalu - nri := (payload[0] >> 5) & 0x03 - start := payload[1] >> 7 - if start != 1 { - return nil, fmt.Errorf("first NALU does not contain the start bit") - } - typ := payload[1] & 0x1F - ret = append([]byte{(nri << 5) | typ}, payload[2:]...) - - // process other nalus - for { - n, err := d.r.Read(d.buf) - if err != nil { - return nil, err - } - +// Decode decodes a NALU from RTP/H264 packets. +// Since a NALU can require multiple RTP/H264 packets, it returns +// one NALU, or no NALU with ErrMorePacketsNeeded. 
+func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { + switch d.state { + case decoderStateInitial: pkt := rtp.Packet{} - err = pkt.Unmarshal(d.buf[:n]) + err := pkt.Unmarshal(byts) if err != nil { return nil, err } - payload := pkt.Payload - typ := NALUType(payload[0] & 0x1F) + if !d.initialTsSet { + d.initialTsSet = true + d.initialTs = pkt.Timestamp + } + + typ := NALUType(pkt.Payload[0] & 0x1F) + + switch typ { + case NALUTypeNonIDR, NALUTypeDataPartitionA, NALUTypeDataPartitionB, + NALUTypeDataPartitionC, NALUTypeIDR, NALUTypeSei, NALUTypeSPS, + NALUTypePPS, NALUTypeAccessUnitDelimiter, NALUTypeEndOfSequence, + NALUTypeEndOfStream, NALUTypeFillerData, NALUTypeSPSExtension, + NALUTypePrefix, NALUTypeSubsetSPS, NALUTypeReserved16, NALUTypeReserved17, + NALUTypeReserved18, NALUTypeSliceLayerWithoutPartitioning, + NALUTypeSliceExtension, NALUTypeSliceExtensionDepth, NALUTypeReserved22, + NALUTypeReserved23: + return &NALUAndTimestamp{ + NALU: pkt.Payload, + Timestamp: time.Duration(pkt.Timestamp-d.initialTs) * time.Second / rtpClockRate, + }, nil + + case NALUTypeFuA: // first packet of a fragmented NALU + nri := (pkt.Payload[0] >> 5) & 0x03 + start := pkt.Payload[1] >> 7 + if start != 1 { + return nil, fmt.Errorf("first NALU does not contain the start bit") + } + typ := pkt.Payload[1] & 0x1F + d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...) 
+ + d.state = decoderStateReadingFragmented + return nil, ErrMorePacketsNeeded + + case NALUTypeStapA, NALUTypeStapB, NALUTypeMtap16, NALUTypeMtap24, NALUTypeFuB: + return nil, fmt.Errorf("NALU type not supported (%d)", typ) + } + + return nil, fmt.Errorf("invalid NALU type (%d)", typ) + + default: // decoderStateReadingFragmented + pkt := rtp.Packet{} + err := pkt.Unmarshal(byts) + if err != nil { + return nil, err + } + + typ := NALUType(pkt.Payload[0] & 0x1F) if typ != NALUTypeFuA { return nil, fmt.Errorf("non-starting NALU is not FU-A") } - end := (payload[1] >> 6) & 0x01 + end := (pkt.Payload[1] >> 6) & 0x01 - ret = append(ret, payload[2:]...) + d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...) - if end == 1 { - break + if end != 1 { + return nil, ErrMorePacketsNeeded } - } - return [][]byte{ret}, nil + d.state = decoderStateInitial + return &NALUAndTimestamp{ + NALU: d.fragmentedBuf, + Timestamp: time.Duration(pkt.Timestamp-d.initialTs) * time.Second / rtpClockRate, + }, nil + } } -// ReadSPSPPS decodes NALUs until SPS and PPS are found. -func (d *Decoder) ReadSPSPPS() ([]byte, []byte, error) { +// Read reads RTP/H264 packets from a reader until a NALU is decoded. +func (d *Decoder) Read(r io.Reader) (*NALUAndTimestamp, error) { + buf := make([]byte, 2048) + for { + n, err := r.Read(buf) + if err != nil { + return nil, err + } + + nalu, err := d.Decode(buf[:n]) + if err != nil { + if err == ErrMorePacketsNeeded { + continue + } + return nil, err + } + return nalu, nil + } +} + +// ReadSPSPPS reads RTP/H264 packets from a reader until SPS and PPS are +// found, and returns them. +func (d *Decoder) ReadSPSPPS(r io.Reader) ([]byte, []byte, error) { var sps []byte var pps []byte for { - nalus, err := d.Read() + nt, err := d.Read(r) if err != nil { return nil, nil, err } - for _, nalu := range nalus { - switch NALUType(nalu[0] & 0x1F) { - case NALUTypeSPS: - sps = append([]byte(nil), nalu...) 
- if sps != nil && pps != nil { - return sps, pps, nil - } + switch NALUType(nt.NALU[0] & 0x1F) { + case NALUTypeSPS: + sps = append([]byte(nil), nt.NALU...) + if sps != nil && pps != nil { + return sps, pps, nil + } - case NALUTypePPS: - pps = append([]byte(nil), nalu...) - if sps != nil && pps != nil { - return sps, pps, nil - } + case NALUTypePPS: + pps = append([]byte(nil), nt.NALU...) + if sps != nil && pps != nil { + return sps, pps, nil } } } diff --git a/pkg/rtph264/encoder.go b/pkg/rtph264/encoder.go index 3cbb9465..7b14fb07 100644 --- a/pkg/rtph264/encoder.go +++ b/pkg/rtph264/encoder.go @@ -2,14 +2,14 @@ package rtph264 import ( "math/rand" - "time" "github.com/pion/rtp" ) const ( rtpVersion = 0x02 - rtpPayloadMaxSize = 1460 // 1500 (mtu) - 20 (ip header) - 8 (udp header) - 12 (rtp header) + rtpPayloadMaxSize = 1460 // 1500 (mtu) - 20 (ip header) - 8 (udp header) - 12 (rtp header) + rtpClockRate = 90000 // h264 always uses 90khz ) // Encoder is a RTP/H264 encoder. @@ -18,52 +18,49 @@ type Encoder struct { sequenceNumber uint16 ssrc uint32 initialTs uint32 - started time.Duration } // NewEncoder allocates an Encoder. -func NewEncoder(payloadType uint8) (*Encoder, error) { +func NewEncoder(payloadType uint8, sequenceNumber *uint16, + ssrc *uint32, initialTs *uint32) *Encoder { return &Encoder{ - payloadType: payloadType, - sequenceNumber: uint16(rand.Uint32()), - ssrc: rand.Uint32(), - initialTs: rand.Uint32(), - }, nil + payloadType: payloadType, + sequenceNumber: func() uint16 { + if sequenceNumber != nil { + return *sequenceNumber + } + return uint16(rand.Uint32()) + }(), + ssrc: func() uint32 { + if ssrc != nil { + return *ssrc + } + return rand.Uint32() + }(), + initialTs: func() uint32 { + if initialTs != nil { + return *initialTs + } + return rand.Uint32() + }(), + } } -// Write encodes NALUs into RTP/H264 packets. 
-func (e *Encoder) Write(ts time.Duration, nalus [][]byte) ([][]byte, error) { - if e.started == 0 { - e.started = ts - } +// Encode encodes a NALU into RTP/H264 packets. +// It always returns at least one RTP/H264 packet. +func (e *Encoder) Encode(nt *NALUAndTimestamp) ([][]byte, error) { + rtpTime := e.initialTs + uint32((nt.Timestamp).Seconds()*rtpClockRate) - // rtp/h264 uses a 90khz clock - rtpTime := e.initialTs + uint32((ts-e.started).Seconds()*90000) - - var frames [][]byte - - for i, nalu := range nalus { - naluFrames, err := e.writeNALU(rtpTime, nalu, (i == len(nalus)-1)) - if err != nil { - return nil, err - } - frames = append(frames, naluFrames...) - } - - return frames, nil -} - -func (e *Encoder) writeNALU(rtpTime uint32, nalu []byte, isFinal bool) ([][]byte, error) { - // if the NALU fits into a single RTP packet, use a single NALU payload - if len(nalu) < rtpPayloadMaxSize { - return e.writeSingle(rtpTime, nalu, isFinal) + // if the NALU fits into a single RTP packet, use a single payload + if len(nt.NALU) < rtpPayloadMaxSize { + return e.writeSingle(rtpTime, nt.NALU) } // otherwise, split the NALU into multiple fragmentation payloads - return e.writeFragmented(rtpTime, nalu, isFinal) + return e.writeFragmented(rtpTime, nt.NALU) } -func (e *Encoder) writeSingle(rtpTime uint32, nalu []byte, isFinal bool) ([][]byte, error) { +func (e *Encoder) writeSingle(rtpTime uint32, nalu []byte) ([][]byte, error) { rpkt := rtp.Packet{ Header: rtp.Header{ Version: rtpVersion, @@ -76,9 +73,7 @@ func (e *Encoder) writeSingle(rtpTime uint32, nalu []byte, isFinal bool) ([][]by } e.sequenceNumber++ - if isFinal { - rpkt.Header.Marker = true - } + rpkt.Header.Marker = true frame, err := rpkt.Marshal() if err != nil { @@ -88,7 +83,7 @@ func (e *Encoder) writeSingle(rtpTime uint32, nalu []byte, isFinal bool) ([][]by return [][]byte{frame}, nil } -func (e *Encoder) writeFragmented(rtpTime uint32, nalu []byte, isFinal bool) ([][]byte, error) { +func (e *Encoder) 
writeFragmented(rtpTime uint32, nalu []byte) ([][]byte, error) { // use only FU-A, not FU-B, since we always use non-interleaved mode // (packetization-mode=1) frameCount := (len(nalu) - 1) / (rtpPayloadMaxSize - 2) @@ -96,7 +91,7 @@ func (e *Encoder) writeFragmented(rtpTime uint32, nalu []byte, isFinal bool) ([] if lastFrameSize > 0 { frameCount++ } - frames := make([][]byte, frameCount) + ret := make([][]byte, frameCount) nri := (nalu[0] >> 5) & 0x03 typ := nalu[0] & 0x1F @@ -111,7 +106,7 @@ func (e *Encoder) writeFragmented(rtpTime uint32, nalu []byte, isFinal bool) ([] } end := uint8(0) le := rtpPayloadMaxSize - 2 - if i == (len(frames) - 1) { + if i == (frameCount - 1) { end = 1 le = lastFrameSize } @@ -132,7 +127,7 @@ func (e *Encoder) writeFragmented(rtpTime uint32, nalu []byte, isFinal bool) ([] } e.sequenceNumber++ - if isFinal && i == (len(frames)-1) { + if i == (frameCount - 1) { rpkt.Header.Marker = true } @@ -141,8 +136,8 @@ func (e *Encoder) writeFragmented(rtpTime uint32, nalu []byte, isFinal bool) ([] return nil, err } - frames[i] = frame + ret[i] = frame } - return frames, nil + return ret, nil } diff --git a/pkg/rtph264/defs.go b/pkg/rtph264/rtph264.go similarity index 88% rename from pkg/rtph264/defs.go rename to pkg/rtph264/rtph264.go index 0680fa32..9ccedd60 100644 --- a/pkg/rtph264/defs.go +++ b/pkg/rtph264/rtph264.go @@ -1,5 +1,16 @@ +// Package rtph264 contains a RTP/H264 decoder and encoder. package rtph264 +import ( + "time" +) + +// NALUAndTimestamp is a NALU and an associated timestamp. +type NALUAndTimestamp struct { + Timestamp time.Duration + NALU []byte +} + // NALUType is the type of a NALU. 
type NALUType uint8 diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/rtph264_test.go new file mode 100644 index 00000000..fed2f252 --- /dev/null +++ b/pkg/rtph264/rtph264_test.go @@ -0,0 +1,133 @@ +package rtph264 + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func mergeBytes(vals ...[]byte) []byte { + size := 0 + for _, v := range vals { + size += len(v) + } + res := make([]byte, size) + + pos := 0 + for _, v := range vals { + n := copy(res[pos:], v) + pos += n + } + + return res +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} + +var cases = []struct { + name string + dec *NALUAndTimestamp + enc [][]byte +}{ + { + "single", + &NALUAndTimestamp{ + Timestamp: 25 * time.Millisecond, + NALU: mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + ), + }, + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6f, 0x1f, + 0x9d, 0xbb, 0x78, 0x12, 0x05, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + ), + }, + }, + { + "fragmented", + &NALUAndTimestamp{ + Timestamp: 55 * time.Millisecond, + NALU: mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), + ), + }, + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x79, 0xab, + 0x9d, 0xbb, 0x78, 0x12, 0x1c, 0x85, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 181), + []byte{0x00, 0x01}, + ), + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xee, 0x88, 0x77, 0x79, 0xab, + 0x9d, 0xbb, 0x78, 0x12, 0x1c, 0x45, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 73), + ), + }, + }, +} + +func TestEncode(t *testing.T) { + for _, ca := range cases { + t.Run(ca.name, func(t *testing.T) { + sequenceNumber := 
uint16(0x44ed) + ssrc := uint32(0x9dbb7812) + initialTs := uint32(0x88776655) + e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs) + enc, err := e.Encode(ca.dec) + require.NoError(t, err) + require.Equal(t, ca.enc, enc) + }) + } +} + +func TestDecode(t *testing.T) { + for _, ca := range cases { + t.Run(ca.name, func(t *testing.T) { + i := 0 + r := readerFunc(func(p []byte) (int, error) { + if i == 0 { + // send an initial packet downstream + // in order to correctly compute the timestamp + n := copy(p, []byte{ + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x06, 0x00, + }) + i++ + return n, nil + } + + n := copy(p, ca.enc[i-1]) + i++ + return n, nil + }) + + d := NewDecoder() + + _, err := d.Read(r) + require.NoError(t, err) + + dec, err := d.Read(r) + require.NoError(t, err) + require.Equal(t, ca.dec, dec) + }) + } +}