diff --git a/connclientdial_test.go b/connclientdial_test.go index c7625d4a..795c42fe 100644 --- a/connclientdial_test.go +++ b/connclientdial_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "fmt" "net" "os" "os/exec" @@ -9,8 +8,9 @@ import ( "testing" "time" - "github.com/pion/rtp" "github.com/stretchr/testify/require" + + "github.com/aler9/gortsplib/rtph264" ) type container struct { @@ -58,46 +58,6 @@ func (c *container) wait() int { return int(code) } -func getH264SPSandPPS(pc net.PacketConn) ([]byte, []byte, error) { - var sps []byte - var pps []byte - - buf := make([]byte, 2048) - for { - n, _, err := pc.ReadFrom(buf) - if err != nil { - return nil, nil, err - } - - packet := &rtp.Packet{} - err = packet.Unmarshal(buf[:n]) - if err != nil { - return nil, nil, err - } - - // require h264 - if packet.PayloadType != 96 { - return nil, nil, fmt.Errorf("wrong payload type '%d', expected 96", - packet.PayloadType) - } - - // switch by NALU type - switch packet.Payload[0] & 0x1F { - case 0x07: // sps - sps = append([]byte(nil), packet.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - - case 0x08: // pps - pps = append([]byte(nil), packet.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - } - } -} - func TestConnClientDialReadUDP(t *testing.T) { cnt1, err := newContainer("rtsp-simple-server", "server", []string{}) require.NoError(t, err) @@ -198,7 +158,8 @@ func TestConnClientDialPublishUDP(t *testing.T) { require.NoError(t, err) defer cnt2.close() - sps, pps, err := getH264SPSandPPS(pc) + decoder := rtph264.NewDecoderFromPacketConn(pc) + sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) track, err := NewTrackH264(0, sps, pps) @@ -267,7 +228,8 @@ func TestConnClientDialPublishTCP(t *testing.T) { require.NoError(t, err) defer cnt2.close() - sps, pps, err := getH264SPSandPPS(pc) + decoder := rtph264.NewDecoderFromPacketConn(pc) + sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) track, err := NewTrackH264(0, sps, pps) diff --git a/examples/client-publish-tcp.go b/examples/client-publish-tcp.go index f9429939..7213ec59 100644 --- a/examples/client-publish-tcp.go +++ b/examples/client-publish-tcp.go @@ -7,53 +7,13 @@ import ( "net" "github.com/aler9/gortsplib" - "github.com/pion/rtp" + "github.com/aler9/gortsplib/rtph264" ) // This example shows how to generate RTP/H264 frames from a file with Gstreamer, // create a RTSP client, connect to a server, announce a H264 track and write // the frames with the TCP protocol. -func getRtpH264SPSandPPS(pc net.PacketConn) ([]byte, []byte, error) { - var sps []byte - var pps []byte - - buf := make([]byte, 2048) - for { - n, _, err := pc.ReadFrom(buf) - if err != nil { - return nil, nil, err - } - - pkt := &rtp.Packet{} - err = pkt.Unmarshal(buf[:n]) - if err != nil { - return nil, nil, err - } - - // require h264 - if pkt.PayloadType != 96 { - return nil, nil, fmt.Errorf("wrong payload type '%d', expected 96", - pkt.PayloadType) - } - - // switch by NALU type - switch pkt.Payload[0] & 0x1F { - case 0x07: // sps - sps = append([]byte(nil), pkt.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - - case 0x08: // pps - pps = append([]byte(nil), pkt.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - } - } -} - func main() { // open a listener to receive RTP/H264 frames pc, err := net.ListenPacket("udp4", "127.0.0.1:9000") @@ -67,7 +27,8 @@ func main() { " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") // wait for RTP/H264 frames - sps, pps, err := getRtpH264SPSandPPS(pc) + decoder := rtph264.NewDecoderFromPacketConn(pc) + sps, pps, err := decoder.ReadSPSPPS() if err != nil { panic(err) } diff --git a/examples/client-publish-udp.go b/examples/client-publish-udp.go index 01c6c4e9..4d792f1b 100644 --- a/examples/client-publish-udp.go +++ b/examples/client-publish-udp.go @@ -7,53 +7,13 @@ import ( "net" "github.com/aler9/gortsplib" - "github.com/pion/rtp" + "github.com/aler9/gortsplib/rtph264" ) // This example shows how to generate RTP/H264 frames from a file with Gstreamer, // create a RTSP client, connect to a server, announce a H264 track and write // the frames with the UDP protocol. -func getRtpH264SPSandPPS(pc net.PacketConn) ([]byte, []byte, error) { - var sps []byte - var pps []byte - - buf := make([]byte, 2048) - for { - n, _, err := pc.ReadFrom(buf) - if err != nil { - return nil, nil, err - } - - pkt := &rtp.Packet{} - err = pkt.Unmarshal(buf[:n]) - if err != nil { - return nil, nil, err - } - - // require h264 - if pkt.PayloadType != 96 { - return nil, nil, fmt.Errorf("wrong payload type '%d', expected 96", - pkt.PayloadType) - } - - // switch by NALU type - switch pkt.Payload[0] & 0x1F { - case 0x07: // sps - sps = append([]byte(nil), pkt.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - - case 0x08: // pps - pps = append([]byte(nil), pkt.Payload...) - if sps != nil && pps != nil { - return sps, pps, nil - } - } - } -} - func main() { // open a listener to receive RTP/H264 frames pc, err := net.ListenPacket("udp4", "127.0.0.1:9000") @@ -67,7 +27,8 @@ func main() { " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") // wait for RTP/H264 frames - sps, pps, err := getRtpH264SPSandPPS(pc) + decoder := rtph264.NewDecoderFromPacketConn(pc) + sps, pps, err := decoder.ReadSPSPPS() if err != nil { panic(err) } diff --git a/rtph264/decoder.go b/rtph264/decoder.go new file mode 100644 index 00000000..39c1aedc --- /dev/null +++ b/rtph264/decoder.go @@ -0,0 +1,141 @@ +// Package rtph264 contains a RTP/H264 decoder and encoder. +package rtph264 + +import ( + "fmt" + "io" + "net" + + "github.com/pion/rtp" +) + +type packetConnReader struct { + inner net.PacketConn +} + +func (r packetConnReader) Read(p []byte) (int, error) { + n, _, err := r.inner.ReadFrom(p) + return n, err +} + +// Decoder is a RTP/H264 decoder. +type Decoder struct { + r io.Reader + buf []byte +} + +// NewDecoderFromPacketConn creates a decoder around a Reader. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + r: r, + buf: make([]byte, 2048), + } +} + +// NewDecoderFromPacketConn creates a decoder around a net.PacketConn. +func NewDecoderFromPacketConn(pc net.PacketConn) *Decoder { + return NewDecoder(packetConnReader{pc}) +} + +// Read decodes NALUs from RTP/H264 packets. +func (d *Decoder) Read() ([][]byte, error) { + n, err := d.r.Read(d.buf) + if err != nil { + return nil, err + } + + pkt := rtp.Packet{} + err = pkt.Unmarshal(d.buf[:n]) + if err != nil { + return nil, err + } + payload := pkt.Payload + + typ := naluType(payload[0] & 0x1F) + + if typ >= naluTypeFirstSingle && typ <= naluTypeLastSingle { + return [][]byte{payload}, nil + } + + switch typ { + case naluTypeFuA: + return d.readFragmented(payload) + + case naluTypeStapA, naluTypeStapB, naluTypeMtap16, naluTypeMtap24, naluTypeFuB: + return nil, fmt.Errorf("NALU type not supported (%d)", typ) + } + + return nil, fmt.Errorf("invalid NALU type (%d)", typ) +} + +func (d *Decoder) readFragmented(payload []byte) ([][]byte, error) { + // A NALU can have any size; we can't preallocate it + var ret []byte + + // process first nalu + nri := (payload[0] >> 5) & 0x03 + start := payload[1] >> 7 + if start != 1 { + return nil, fmt.Errorf("first NALU does not contain the start bit") + } + typ := payload[1] & 0x1F + ret = append([]byte{(nri << 5) | typ}, payload[2:]...) + + // process other nalus + for { + n, err := d.r.Read(d.buf) + if err != nil { + return nil, err + } + + pkt := rtp.Packet{} + err = pkt.Unmarshal(d.buf[:n]) + if err != nil { + return nil, err + } + payload := pkt.Payload + + typ := naluType(payload[0] & 0x1F) + if typ != naluTypeFuA { + return nil, fmt.Errorf("non-starting NALU is not FU-A") + } + end := (payload[1] >> 6) & 0x01 + + ret = append(ret, payload[2:]...) + + if end == 1 { + break + } + } + + return [][]byte{ret}, nil +} + +// ReadSPSPPS decodes NALUs until SPS and PPS are found. +func (d *Decoder) ReadSPSPPS() ([]byte, []byte, error) { + var sps []byte + var pps []byte + + for { + nalus, err := d.Read() + if err != nil { + return nil, nil, err + } + + for _, nalu := range nalus { + switch naluType(nalu[0] & 0x1F) { + case naluTypeSPS: + sps = append([]byte(nil), nalu...) + if sps != nil && pps != nil { + return sps, pps, nil + } + + case naluTypePPS: + pps = append([]byte(nil), nalu...) + if sps != nil && pps != nil { + return sps, pps, nil + } + } + } + } +} diff --git a/rtph264/defs.go b/rtph264/defs.go new file mode 100644 index 00000000..a5492d50 --- /dev/null +++ b/rtph264/defs.go @@ -0,0 +1,16 @@ +package rtph264 + +type naluType uint8 + +const ( + naluTypeFirstSingle naluType = 1 + naluTypeSPS naluType = 7 + naluTypePPS naluType = 8 + naluTypeLastSingle naluType = 23 + naluTypeStapA naluType = 24 + naluTypeStapB naluType = 25 + naluTypeMtap16 naluType = 26 + naluTypeMtap24 naluType = 27 + naluTypeFuA naluType = 28 + naluTypeFuB naluType = 29 +) diff --git a/rtph264/encoder.go b/rtph264/encoder.go new file mode 100644 index 00000000..b31f1956 --- /dev/null +++ b/rtph264/encoder.go @@ -0,0 +1,148 @@ +package rtph264 + +import ( + "math/rand" + "time" + + "github.com/pion/rtp" +) + +const ( + rtpVersion = 0x02 + rtpPayloadMaxSize = 1460 // 1500 - ip header - udp header - rtp header +) + +// Encoder is a RTP/H264 encoder. +type Encoder struct { + payloadType uint8 + sequenceNumber uint16 + ssrc uint32 + initialTs uint32 + started time.Duration +} + +// NewEncoder allocates an Encoder. +func NewEncoder(relativeType uint8) *Encoder { + return &Encoder{ + payloadType: 96 + relativeType, + sequenceNumber: uint16(0), + ssrc: rand.Uint32(), + initialTs: rand.Uint32(), + } +} + +// Write encodes NALUs into RTP/H264 packets. +func (e *Encoder) Write(nalus [][]byte, timestamp time.Duration) ([][]byte, error) { + var frames [][]byte + + if e.started == time.Duration(0) { + e.started = timestamp + } + + // rtp/h264 uses a 90khz clock + rtpTs := e.initialTs + uint32((timestamp-e.started).Seconds()*90000) + + for i, nalu := range nalus { + naluFrames, err := e.writeNalu(nalu, rtpTs, (i == len(nalus)-1)) + if err != nil { + return nil, err + } + frames = append(frames, naluFrames...) + } + + return frames, nil +} + +func (e *Encoder) writeNalu(nalu []byte, rtpTs uint32, isFinal bool) ([][]byte, error) { + // if the NALU fits into a single RTP packet, use a single NALU payload + if len(nalu) < rtpPayloadMaxSize { + return e.writeSingle(nalu, rtpTs, isFinal) + } + + // otherwise, split the NALU into multiple fragmentation payloads + return e.writeFragmented(nalu, rtpTs, isFinal) +} + +func (e *Encoder) writeSingle(nalu []byte, rtpTs uint32, isFinal bool) ([][]byte, error) { + rpkt := rtp.Packet{ + Header: rtp.Header{ + Version: rtpVersion, + PayloadType: e.payloadType, + SequenceNumber: e.sequenceNumber, + Timestamp: rtpTs, + SSRC: e.ssrc, + }, + Payload: nalu, + } + e.sequenceNumber++ + + if isFinal { + rpkt.Header.Marker = true + } + + frame, err := rpkt.Marshal() + if err != nil { + return nil, err + } + + return [][]byte{frame}, nil +} + +func (e *Encoder) writeFragmented(nalu []byte, rtpTs uint32, isFinal bool) ([][]byte, error) { + // use only FU-A, not FU-B, since we always use non-interleaved mode + // (packetization-mode=1) + frameCount := (len(nalu) - 1) / (rtpPayloadMaxSize - 2) + lastFrameSize := (len(nalu) - 1) % (rtpPayloadMaxSize - 2) + if lastFrameSize > 0 { + frameCount++ + } + frames := make([][]byte, frameCount) + + nri := (nalu[0] >> 5) & 0x03 + typ := nalu[0] & 0x1F + nalu = nalu[1:] // remove header + + for i := 0; i < frameCount; i++ { + indicator := (nri << 5) | uint8(naluTypeFuA) + + start := uint8(0) + if i == 0 { + start = 1 + } + end := uint8(0) + le := rtpPayloadMaxSize - 2 + if i == (len(frames) - 1) { + end = 1 + le = lastFrameSize + } + header := (start << 7) | (end << 6) | typ + + data := append([]byte{indicator, header}, nalu[:le]...) + nalu = nalu[le:] + + rpkt := rtp.Packet{ + Header: rtp.Header{ + Version: rtpVersion, + PayloadType: e.payloadType, + SequenceNumber: e.sequenceNumber, + Timestamp: rtpTs, + SSRC: e.ssrc, + }, + Payload: data, + } + e.sequenceNumber++ + + if isFinal && i == (len(frames)-1) { + rpkt.Header.Marker = true + } + + frame, err := rpkt.Marshal() + if err != nil { + return nil, err + } + + frames[i] = frame + } + + return frames, nil +}