rtmp: rewrite implementation of rtmp connection (#1047)

* rtmp: improve MsgCommandAMF0

* rtmp: fix MsgSetPeerBandwidth

* rtmp: add message tests

* rtmp: replace implementation with new one

* rtmp: rename handshake functions

* rtmp: avoid calling useless function

* rtmp: use time.Duration for PTSDelta

* rtmp: fix decoding chunks with relevant size

* rtmp: rewrite implementation of rtmp connection

* rtmp: fix tests

* rtmp: improve error message

* rtmp: replace h264 config implementation

* link against github.com/notedit/rtmp

* normalize MessageStreamID

* rtmp: make acknowledge optional

* rtmp: fix decoding of chunk2 + chunk3

* avoid using encoding/binary
This commit is contained in:
Alessandro Ros
2022-07-17 15:17:18 +02:00
committed by GitHub
parent 50d205274f
commit 9e6abc6e9f
45 changed files with 2045 additions and 1064 deletions

6
go.mod
View File

@@ -5,14 +5,14 @@ go 1.17
require (
code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5
github.com/abema/go-mp4 v0.7.2
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8
github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757
github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.8.1
github.com/gookit/color v1.4.2
github.com/grafov/m3u8 v0.11.1
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
github.com/notedit/rtmp v0.0.0
github.com/notedit/rtmp v0.0.2
github.com/orcaman/writerseeker v0.0.0
github.com/pion/rtp v1.7.13
github.com/stretchr/testify v1.7.1
@@ -51,6 +51,4 @@ require (
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
replace github.com/notedit/rtmp => github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927
replace github.com/orcaman/writerseeker => github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82

8
go.sum
View File

@@ -6,10 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f h1:EC+MOSv3e8ZEvtdHoL1++HahNoiVIkvu2Ygjrx6LyOg=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8 h1:GdQOJFYbcrw8bXGClhroHTBIEJAb/jPCIV33Q966rms=
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U=
github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=
@@ -83,6 +81,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/notedit/rtmp v0.0.2 h1:5+to4yezKATiJgnrcETu9LbV5G/QsWkOV9Ts2M/p33w=
github.com/notedit/rtmp v0.0.2/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=

View File

@@ -16,13 +16,14 @@ import (
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
const (
@@ -107,7 +108,7 @@ func newRTMPConn(
runOnConnect: runOnConnect,
runOnConnectRestart: runOnConnectRestart,
wg: wg,
conn: rtmp.NewServerConn(nconn),
conn: rtmp.NewConn(nconn),
nconn: nconn,
externalCmdPool: externalCmdPool,
pathManager: pathManager,
@@ -211,19 +212,19 @@ func (c *rtmpConn) runInner(ctx context.Context) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.ServerHandshake()
u, isReading, err := c.conn.InitializeServer()
if err != nil {
return err
}
if c.conn.IsPublishing() {
return c.runPublish(ctx)
if isReading {
return c.runRead(ctx, u)
}
return c.runRead(ctx)
return c.runPublish(ctx, u)
}
func (c *rtmpConn) runRead(ctx context.Context) error {
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL())
func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onReaderSetupPlay(pathReaderSetupPlayReq{
author: c,
@@ -410,22 +411,17 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
sps := videoTrack.SafeSPS()
pps := videoTrack.SafePPS()
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
b = b[:n]
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = c.conn.WritePacket(av.Packet{
Type: av.H264DecoderConfig,
Data: b,
err = c.conn.WriteMessage(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: buf,
})
if err != nil {
return err
@@ -438,11 +434,14 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
}
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = c.conn.WritePacket(av.Packet{
Type: av.H264,
Data: avcc,
Time: dts,
CTime: pts - dts,
err = c.conn.WriteMessage(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: idrPresent,
H264Type: flvio.AVC_NALU,
Payload: avcc,
DTS: dts,
PTSDelta: pts - dts,
})
if err != nil {
return err
@@ -467,10 +466,15 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
for i, au := range aus {
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WritePacket(av.Packet{
Type: av.AAC,
Data: au,
Time: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()),
err := c.conn.WriteMessage(&message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 1,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: au,
DTS: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()),
})
if err != nil {
return err
@@ -480,7 +484,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
}
}
func (c *rtmpConn) runPublish(ctx context.Context) error {
func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
videoTrack, audioTrack, err := c.conn.ReadTracks()
if err != nil {
@@ -513,7 +517,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
tracks = append(tracks, audioTrack)
}
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL())
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onPublisherAnnounce(pathPublisherAnnounceReq{
author: c,
@@ -559,22 +563,24 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
for {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
pkt, err := c.conn.ReadPacket()
msg, err := c.conn.ReadMessage()
if err != nil {
return err
}
switch pkt.Type {
case av.H264DecoderConfig:
codec, err := nh264.FromDecoderConfig(pkt.Data)
switch tmsg := msg.(type) {
case *message.MsgVideo:
if tmsg.H264Type == flvio.AVC_SEQHDR {
var conf h264conf.Conf
err = conf.Unmarshal(tmsg.Payload)
if err != nil {
return err
return fmt.Errorf("unable to parse H264 config: %v", err)
}
pts := pkt.Time + pkt.CTime
pts := tmsg.DTS + tmsg.PTSDelta
nalus := [][]byte{
codec.SPS[0],
codec.PPS[0],
conf.SPS,
conf.PPS,
}
pkts, err := h264Encoder.Encode(nalus, pts)
@@ -600,15 +606,14 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
})
}
}
case av.H264:
} else if tmsg.H264Type == flvio.AVC_NALU {
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(pkt.Data)
nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil {
return err
return fmt.Errorf("unable to decode AVCC: %v", err)
}
// skip invalid NALUs sent by DJI
@@ -631,7 +636,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
}
}
pts := pkt.Time + pkt.CTime
pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(validNALUs, pts)
if err != nil {
@@ -656,13 +661,15 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
})
}
}
}
case av.AAC:
case *message.MsgAudio:
if tmsg.AACType == flvio.AAC_RAW {
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
@@ -676,6 +683,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
}
}
}
}
}
func (c *rtmpConn) authenticate(

View File

@@ -3,7 +3,6 @@ package core
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"strconv"
@@ -259,7 +258,7 @@ func (s *rtmpServer) newConnID() (string, error) {
return "", err
}
u := binary.LittleEndian.Uint32(b)
u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999
u += 100000000

View File

@@ -141,9 +141,9 @@ func TestRTMPServerAuth(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
require.NoError(t, err)
_, _, err = conn.ReadTracks()
@@ -229,9 +229,17 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
require.NoError(t, err)
for i := 0; i < 3; i++ {
_, err := conn.ReadMessage()
require.NoError(t, err)
}
_, err = conn.ReadMessage()
require.Equal(t, err, io.EOF)
})
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
const (
@@ -126,14 +127,14 @@ func (s *rtmpSource) runInner() bool {
return err
}
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
readDone := make(chan error)
go func() {
readDone <- func() error {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
if err != nil {
return err
}
@@ -187,23 +188,24 @@ func (s *rtmpSource) runInner() bool {
for {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
pkt, err := conn.ReadPacket()
msg, err := conn.ReadMessage()
if err != nil {
return err
}
switch pkt.Type {
case av.H264:
switch tmsg := msg.(type) {
case *message.MsgVideo:
if tmsg.H264Type == flvio.AVC_NALU {
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(pkt.Data)
nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil {
return err
return fmt.Errorf("unable to decode AVCC: %v", err)
}
pts := pkt.Time + pkt.CTime
pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil {
@@ -228,13 +230,15 @@ func (s *rtmpSource) runInner() bool {
})
}
}
}
case av.AAC:
case *message.MsgAudio:
if tmsg.AACType == flvio.AAC_RAW {
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
@@ -248,6 +252,7 @@ func (s *rtmpSource) runInner() bool {
}
}
}
}
}()
}()

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"fmt"
"strconv"
"strings"
@@ -235,7 +234,7 @@ func (s *rtspServer) newSessionID() (string, error) {
return "", err
}
u := binary.LittleEndian.Uint32(b)
u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999
u += 100000000

View File

@@ -20,7 +20,7 @@ type Chunk0 struct {
// Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 12)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@@ -37,7 +37,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
}
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

View File

@@ -21,7 +21,7 @@ type Chunk1 struct {
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 8)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@@ -37,7 +37,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
}
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

View File

@@ -17,7 +17,7 @@ type Chunk2 struct {
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 4)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@@ -26,7 +26,7 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
c.TimestampDelta = uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

View File

@@ -18,7 +18,7 @@ type Chunk3 struct {
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@@ -26,7 +26,7 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
c.ChunkStreamID = header[0] & 0x3F
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

View File

@@ -1,111 +1,590 @@
package rtmp
import (
"bufio"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/aac"
"github.com/notedit/rtmp/av"
nhaac "github.com/notedit/rtmp/codec/aac"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/notedit/rtmp/format/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
const (
readBufferSize = 4096
writeBufferSize = 4096
codecH264 = 7
codecAAC = 10
)
// Conn is a RTMP connection.
type Conn struct {
rconn *rtmp.Conn
}
// NewClientConn initializes a client-side connection.
func NewClientConn(nconn net.Conn, u *url.URL) *Conn {
c := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
c.URL = u
return &Conn{
rconn: c,
func resultIsOK1(res *message.MsgCommandAMF0) bool {
if len(res.Arguments) < 2 {
return false
}
}
// NewServerConn initializes a server-side connection.
func NewServerConn(nconn net.Conn) *Conn {
// https://github.com/aler9/rtmp/blob/master/format/rtmp/server.go#L46
c := rtmp.NewConn(&bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, readBufferSize),
Writer: bufio.NewWriterSize(nconn, writeBufferSize),
})
c.IsServer = true
return &Conn{
rconn: c,
ma, ok := res.Arguments[1].(flvio.AMFMap)
if !ok {
return false
}
}
// ClientHandshake performs the handshake of a client-side connection.
func (c *Conn) ClientHandshake(isPlaying bool) error {
var flag int
if isPlaying {
flag = rtmp.PrepareReading
} else {
flag = rtmp.PrepareWriting
v, ok := ma.GetString("level")
if !ok {
return false
}
return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, flag)
return v == "status"
}
// ServerHandshake performs the handshake of a server-side connection.
func (c *Conn) ServerHandshake() error {
return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, 0)
}
// IsPublishing returns whether the connection is publishing.
func (c *Conn) IsPublishing() bool {
return c.rconn.Publishing
}
// URL returns the URL requested by the connection.
func (c *Conn) URL() *url.URL {
return c.rconn.URL
}
// ReadPacket reads a packet.
func (c *Conn) ReadPacket() (av.Packet, error) {
return c.rconn.ReadPacket()
}
// WritePacket writes a packet.
func (c *Conn) WritePacket(pkt av.Packet) error {
err := c.rconn.WritePacket(pkt)
if err != nil {
return err
func resultIsOK2(res *message.MsgCommandAMF0) bool {
if len(res.Arguments) < 2 {
return false
}
return c.rconn.FlushWrite()
v, ok := res.Arguments[1].(float64)
if !ok {
return false
}
return v == 1
}
func trackFromH264DecoderConfig(data []byte) (*gortsplib.TrackH264, error) {
codec, err := nh264.FromDecoderConfig(data)
func splitPath(u *url.URL) (app, stream string) {
nu := *u
nu.ForceQuery = false
pathsegs := strings.Split(nu.RequestURI(), "/")
if len(pathsegs) == 2 {
app = pathsegs[1]
}
if len(pathsegs) == 3 {
app = pathsegs[1]
stream = pathsegs[2]
}
if len(pathsegs) > 3 {
app = strings.Join(pathsegs[1:3], "/")
stream = strings.Join(pathsegs[3:], "/")
}
return
}
func getTcURL(u *url.URL) string {
app, _ := splitPath(u)
nu, _ := url.Parse(u.String()) // perform a deep copy
nu.RawQuery = ""
nu.Path = "/"
return nu.String() + app
}
func createURL(tcurl, app, play string) (*url.URL, error) {
u, err := url.ParseRequestURI("/" + app + "/" + play)
if err != nil {
return nil, err
}
tu, err := url.Parse(tcurl)
if err != nil {
return nil, err
}
if tu.Host == "" {
return nil, fmt.Errorf("invalid host")
}
u.Host = tu.Host
if tu.Scheme == "" {
return nil, fmt.Errorf("invalid scheme")
}
u.Scheme = tu.Scheme
return u, nil
}
// Conn is a RTMP connection.
type Conn struct {
bc *bytecounter.ReadWriter
mrw *message.ReadWriter
}
// NewConn initializes a connection.
func NewConn(nconn net.Conn) *Conn {
c := &Conn{}
c.bc = bytecounter.NewReadWriter(nconn)
c.mrw = message.NewReadWriter(c.bc, false)
return c
}
func (c *Conn) readCommand() (*message.MsgCommandAMF0, error) {
for {
msg, err := c.mrw.Read()
if err != nil {
return nil, err
}
if cmd, ok := msg.(*message.MsgCommandAMF0); ok {
return cmd, nil
}
}
}
func (c *Conn) readCommandResult(commandName string, isValid func(*message.MsgCommandAMF0) bool) error {
for {
msg, err := c.mrw.Read()
if err != nil {
return err
}
if cmd, ok := msg.(*message.MsgCommandAMF0); ok {
if cmd.Name == commandName {
if !isValid(cmd) {
return fmt.Errorf("server refused connect request")
}
return nil
}
}
}
}
// InitializeClient performs the initialization of a client-side connection.
func (c *Conn) InitializeClient(u *url.URL, isPlaying bool) error {
connectpath, actionpath := splitPath(u)
err := handshake.DoClient(c.bc, false)
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgSetWindowAckSize{
Value: 2500000,
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgSetPeerBandwidth{
Value: 2500000,
Type: 2,
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: connectpath},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL(u)},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
{K: "videoCodecs", V: 252},
{K: "videoFunction", V: 1},
},
},
})
if err != nil {
return err
}
err = c.readCommandResult("_result", resultIsOK1)
if err != nil {
return err
}
if isPlaying {
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
})
if err != nil {
return err
}
err = c.readCommandResult("_result", resultIsOK2)
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgUserControlSetBufferLength{
BufferLength: 0x64,
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 0,
Arguments: []interface{}{
nil,
actionpath,
},
})
if err != nil {
return err
}
err = c.readCommandResult("onStatus", resultIsOK1)
if err != nil {
return err
}
} else {
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
actionpath,
},
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
actionpath,
},
})
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
})
if err != nil {
return err
}
err = c.readCommandResult("_result", resultIsOK2)
if err != nil {
return err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 0x1000000,
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
actionpath,
connectpath,
},
})
if err != nil {
return err
}
err = c.readCommandResult("onStatus", resultIsOK1)
if err != nil {
return err
}
}
return nil
}
// InitializeServer performs the initialization of a server-side connection.
func (c *Conn) InitializeServer() (*url.URL, bool, error) {
err := handshake.DoServer(c.bc, false)
if err != nil {
return nil, false, err
}
cmd, err := c.readCommand()
if err != nil {
return nil, false, err
}
if cmd.Name != "connect" {
return nil, false, fmt.Errorf("unexpected command: %+v", cmd)
}
if len(cmd.Arguments) < 1 {
return nil, false, fmt.Errorf("invalid connect command: %+v", cmd)
}
ma, ok := cmd.Arguments[0].(flvio.AMFMap)
if !ok {
return nil, false, fmt.Errorf("invalid connect command: %+v", cmd)
}
connectpath, ok := ma.GetString("app")
if !ok {
return nil, false, fmt.Errorf("invalid connect command: %+v", cmd)
}
tcURL, ok := ma.GetString("tcUrl")
if !ok {
tcURL, ok = ma.GetString("tcurl")
if !ok {
return nil, false, fmt.Errorf("invalid connect command: %+v", cmd)
}
}
err = c.mrw.Write(&message.MsgSetWindowAckSize{
Value: 2500000,
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgSetPeerBandwidth{
Value: 2500000,
Type: 2,
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
if err != nil {
return nil, false, err
}
oe, _ := ma.GetFloat64("objectEncoding")
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: cmd.ChunkStreamID,
Name: "_result",
CommandID: cmd.CommandID,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
},
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetConnection.Connect.Success"},
{K: "description", V: "Connection succeeded."},
{K: "objectEncoding", V: oe},
},
},
})
if err != nil {
return nil, false, err
}
for {
cmd, err := c.readCommand()
if err != nil {
return nil, false, err
}
switch cmd.Name {
case "createStream":
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: cmd.ChunkStreamID,
Name: "_result",
CommandID: cmd.CommandID,
Arguments: []interface{}{
nil,
float64(1),
},
})
if err != nil {
return nil, false, err
}
case "play":
if len(cmd.Arguments) < 2 {
return nil, false, fmt.Errorf("invalid play command arguments")
}
actionpath, ok := cmd.Arguments[1].(string)
if !ok {
return nil, false, fmt.Errorf("invalid play command arguments")
}
u, err := createURL(tcURL, connectpath, actionpath)
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgUserControlStreamIsRecorded{
StreamID: 1,
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgUserControlStreamBegin{
StreamID: 1,
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: cmd.CommandID,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetStream.Play.Reset"},
{K: "description", V: "play reset"},
},
},
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: cmd.CommandID,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetStream.Play.Start"},
{K: "description", V: "play start"},
},
},
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: cmd.CommandID,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetStream.Data.Start"},
{K: "description", V: "data start"},
},
},
})
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetStream.Play.PublishNotify"},
{K: "description", V: "publish notify"},
},
},
})
if err != nil {
return nil, false, err
}
return u, true, nil
case "publish":
if len(cmd.Arguments) < 2 {
return nil, false, fmt.Errorf("invalid publish command arguments")
}
actionpath, ok := cmd.Arguments[1].(string)
if !ok {
return nil, false, fmt.Errorf("invalid publish command arguments")
}
u, err := createURL(tcURL, connectpath, actionpath)
if err != nil {
return nil, false, err
}
err = c.mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
Name: "onStatus",
CommandID: cmd.CommandID,
MessageStreamID: 0x1000000,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
{K: "code", V: "NetStream.Publish.Start"},
{K: "description", V: "publish start"},
},
},
})
if err != nil {
return nil, false, err
}
return u, false, nil
}
}
}
// ReadMessage reads a message.
func (c *Conn) ReadMessage() (message.Message, error) {
return c.mrw.Read()
}
// WriteMessage writes a message.
func (c *Conn) WriteMessage(msg message.Message) error {
return c.mrw.Write(msg)
}
func trackFromH264DecoderConfig(data []byte) (*gortsplib.TrackH264, error) {
var conf h264conf.Conf
err := conf.Unmarshal(data)
if err != nil {
return nil, fmt.Errorf("unable to parse H264 config: %v", err)
}
return &gortsplib.TrackH264{
PayloadType: 96,
SPS: codec.SPS[0],
PPS: codec.PPS[0],
SPS: conf.SPS,
PPS: conf.PPS,
}, nil
}
@@ -127,17 +606,12 @@ func trackFromAACDecoderConfig(data []byte) (*gortsplib.TrackAAC, error) {
var errEmptyMetadata = errors.New("metadata is empty")
func (c *Conn) readTracksFromMetadata(pkt av.Packet) (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
arr, err := flvio.ParseAMFVals(pkt.Data, false)
if err != nil {
return nil, nil, err
}
if len(arr) != 1 {
func (c *Conn) readTracksFromMetadata(payload []interface{}) (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
if len(payload) != 1 {
return nil, nil, fmt.Errorf("invalid metadata")
}
md, ok := arr[0].(flvio.AMFMap)
md, ok := payload[0].(flvio.AMFMap)
if !ok {
return nil, nil, fmt.Errorf("invalid metadata")
}
@@ -206,14 +680,14 @@ func (c *Conn) readTracksFromMetadata(pkt av.Packet) (*gortsplib.TrackH264, *gor
var audioTrack *gortsplib.TrackAAC
for {
var pkt av.Packet
pkt, err = c.ReadPacket()
msg, err := c.ReadMessage()
if err != nil {
return nil, nil, err
}
switch pkt.Type {
case av.H264DecoderConfig:
switch tmsg := msg.(type) {
case *message.MsgVideo:
if tmsg.H264Type == flvio.AVC_SEQHDR {
if !hasVideo {
return nil, nil, fmt.Errorf("unexpected video packet")
}
@@ -222,12 +696,14 @@ func (c *Conn) readTracksFromMetadata(pkt av.Packet) (*gortsplib.TrackH264, *gor
return nil, nil, fmt.Errorf("video track setupped twice")
}
videoTrack, err = trackFromH264DecoderConfig(pkt.Data)
videoTrack, err = trackFromH264DecoderConfig(tmsg.Payload)
if err != nil {
return nil, nil, err
}
}
case av.AACDecoderConfig:
case *message.MsgAudio:
if tmsg.AACType == flvio.AVC_SEQHDR {
if !hasAudio {
return nil, nil, fmt.Errorf("unexpected audio packet")
}
@@ -236,11 +712,12 @@ func (c *Conn) readTracksFromMetadata(pkt av.Packet) (*gortsplib.TrackH264, *gor
return nil, nil, fmt.Errorf("audio track setupped twice")
}
audioTrack, err = trackFromAACDecoderConfig(pkt.Data)
audioTrack, err = trackFromAACDecoderConfig(tmsg.Payload)
if err != nil {
return nil, nil, err
}
}
}
if (!hasVideo || videoTrack != nil) &&
(!hasAudio || audioTrack != nil) {
@@ -249,18 +726,25 @@ func (c *Conn) readTracksFromMetadata(pkt av.Packet) (*gortsplib.TrackH264, *gor
}
}
func (c *Conn) readTracksFromPackets(pkt av.Packet) (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
startTime := pkt.Time
func (c *Conn) readTracksFromMessages(msg message.Message) (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
var startTime *time.Duration
var videoTrack *gortsplib.TrackH264
var audioTrack *gortsplib.TrackAAC
// analyze 1 second of packets
outer:
for {
switch pkt.Type {
case av.H264DecoderConfig:
switch tmsg := msg.(type) {
case *message.MsgVideo:
if startTime == nil {
v := tmsg.DTS
startTime = &v
}
if tmsg.H264Type == flvio.AVC_SEQHDR {
if videoTrack == nil {
var err error
videoTrack, err = trackFromH264DecoderConfig(pkt.Data)
videoTrack, err = trackFromH264DecoderConfig(tmsg.Payload)
if err != nil {
return nil, nil, err
}
@@ -270,11 +754,22 @@ func (c *Conn) readTracksFromPackets(pkt av.Packet) (*gortsplib.TrackH264, *gort
return videoTrack, audioTrack, nil
}
}
}
case av.AACDecoderConfig:
if (tmsg.DTS - *startTime) >= 1*time.Second {
break outer
}
case *message.MsgAudio:
if startTime == nil {
v := tmsg.DTS
startTime = &v
}
if tmsg.AACType == flvio.AVC_SEQHDR {
if audioTrack == nil {
var err error
audioTrack, err = trackFromAACDecoderConfig(pkt.Data)
audioTrack, err = trackFromAACDecoderConfig(tmsg.Payload)
if err != nil {
return nil, nil, err
}
@@ -286,12 +781,13 @@ func (c *Conn) readTracksFromPackets(pkt av.Packet) (*gortsplib.TrackH264, *gort
}
}
if (pkt.Time - startTime) >= 1*time.Second {
break
if (tmsg.DTS - *startTime) >= 1*time.Second {
break outer
}
}
var err error
pkt, err = c.ReadPacket()
msg, err = c.ReadMessage()
if err != nil {
return nil, nil, err
}
@@ -306,26 +802,34 @@ func (c *Conn) readTracksFromPackets(pkt av.Packet) (*gortsplib.TrackH264, *gort
// ReadTracks reads track informations.
func (c *Conn) ReadTracks() (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
pkt, err := c.ReadPacket()
msg, err := c.ReadMessage()
if err != nil {
return nil, nil, err
}
if pkt.Type == av.Metadata {
videoTrack, audioTrack, err := c.readTracksFromMetadata(pkt)
if data, ok := msg.(*message.MsgDataAMF0); ok && len(data.Payload) >= 1 {
payload := data.Payload
// skip packet
if s, ok := payload[0].(string); ok && s == "|RtmpSampleAccess" {
return c.ReadTracks()
}
if s, ok := payload[0].(string); ok && s == "@setDataFrame" {
payload = payload[1:]
}
if len(payload) >= 1 {
if s, ok := payload[0].(string); ok && s == "onMetaData" {
videoTrack, audioTrack, err := c.readTracksFromMetadata(payload[1:])
if err != nil {
if err == errEmptyMetadata {
pkt, err := c.ReadPacket()
msg, err := c.ReadMessage()
if err != nil {
return nil, nil, err
}
videoTrack, audioTrack, err := c.readTracksFromPackets(pkt)
if err != nil {
return nil, nil, err
}
return videoTrack, audioTrack, nil
return c.readTracksFromMessages(msg)
}
return nil, nil, err
@@ -333,20 +837,21 @@ func (c *Conn) ReadTracks() (*gortsplib.TrackH264, *gortsplib.TrackAAC, error) {
return videoTrack, audioTrack, nil
}
videoTrack, audioTrack, err := c.readTracksFromPackets(pkt)
if err != nil {
return nil, nil, err
}
}
return videoTrack, audioTrack, nil
return c.readTracksFromMessages(msg)
}
// WriteTracks writes track informations.
func (c *Conn) WriteTracks(videoTrack *gortsplib.TrackH264, audioTrack *gortsplib.TrackAAC) error {
err := c.WritePacket(av.Packet{
Type: av.Metadata,
Data: flvio.FillAMF0ValMalloc(flvio.AMFMap{
err := c.WriteMessage(&message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 0x1000000,
Payload: []interface{}{
"@setDataFrame",
"onMetaData",
flvio.AMFMap{
{
K: "videodatarate",
V: float64(0),
@@ -373,31 +878,27 @@ func (c *Conn) WriteTracks(videoTrack *gortsplib.TrackH264, audioTrack *gortspli
return 0
}(),
},
}),
},
},
})
if err != nil {
return err
}
// write decoder config only if SPS and PPS are available.
// if they're not available yet, they're sent later as H264 NALUs.
// if they're not available yet, they're sent later.
if videoTrack != nil && videoTrack.SafeSPS() != nil && videoTrack.SafePPS() != nil {
codec := nh264.Codec{
SPS: map[int][]byte{
0: videoTrack.SafeSPS(),
},
PPS: map[int][]byte{
0: videoTrack.SafePPS(),
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
b = b[:n]
buf, _ := h264conf.Conf{
SPS: videoTrack.SafeSPS(),
PPS: videoTrack.SafePPS(),
}.Marshal()
err = c.WritePacket(av.Packet{
Type: av.H264DecoderConfig,
Data: b,
err = c.WriteMessage(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: buf,
})
if err != nil {
return err
@@ -410,14 +911,14 @@ func (c *Conn) WriteTracks(videoTrack *gortsplib.TrackH264, audioTrack *gortspli
return err
}
err = c.WritePacket(av.Packet{
Type: av.AACDecoderConfig,
AAC: &nhaac.Codec{
Config: nhaac.MPEG4AudioConfig{
ChannelLayout: nhaac.CH_STEREO,
},
},
Data: enc,
err = c.WriteMessage(&message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 0x1000000,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_SEQHDR,
Payload: enc,
})
if err != nil {
return err

View File

@@ -3,52 +3,20 @@ package rtmp
import (
"net"
"net/url"
"strings"
"testing"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/aac"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
func splitPath(u *url.URL) (app, stream string) {
nu := *u
nu.ForceQuery = false
pathsegs := strings.Split(nu.RequestURI(), "/")
if len(pathsegs) == 2 {
app = pathsegs[1]
}
if len(pathsegs) == 3 {
app = pathsegs[1]
stream = pathsegs[2]
}
if len(pathsegs) > 3 {
app = strings.Join(pathsegs[1:3], "/")
stream = strings.Join(pathsegs[3:], "/")
}
return
}
func getTcURL(u string) string {
ur, err := url.Parse(u)
if err != nil {
panic(err)
}
app, _ := splitPath(ur)
nu := *ur
nu.RawQuery = ""
nu.Path = "/"
return nu.String() + app
}
func TestClientHandshake(t *testing.T) {
func TestInitializeClient(t *testing.T) {
for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121")
@@ -63,10 +31,10 @@ func TestClientHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoServer(bc)
err = handshake.DoServer(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S set window ack size
msg, err := mrw.Read()
@@ -79,7 +47,7 @@ func TestClientHandshake(t *testing.T) {
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 0x2625a0,
Value: 2500000,
Type: 2,
}, msg)
@@ -95,13 +63,13 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
float64(1),
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: float64(15)},
{K: "audioCodecs", V: float64(4071)},
@@ -114,9 +82,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@@ -137,9 +105,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
}, msg)
@@ -147,9 +115,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -168,10 +136,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(0),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 0,
Arguments: []interface{}{
nil,
"",
},
@@ -180,10 +148,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -199,9 +167,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@@ -212,9 +180,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@@ -225,9 +193,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
}, msg)
@@ -235,9 +203,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -249,10 +217,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"publish",
float64(5),
MessageStreamID: 0x1000000,
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"stream",
@@ -262,10 +230,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(5),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 5,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -286,9 +254,9 @@ func TestClientHandshake(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := NewClientConn(nconn, u)
conn := NewConn(nconn)
err = conn.ClientHandshake(ca == "read")
err = conn.InitializeClient(u, ca == "read")
require.NoError(t, err)
<-done
@@ -296,7 +264,7 @@ func TestClientHandshake(t *testing.T) {
}
}
func TestServerHandshake(t *testing.T) {
func TestInitializeServer(t *testing.T) {
for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121")
@@ -310,9 +278,15 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
defer nconn.Close()
conn := NewServerConn(nconn)
err = conn.ServerHandshake()
conn := NewConn(nconn)
u, isReading, err := conn.InitializeServer()
require.NoError(t, err)
require.Equal(t, &url.URL{
Scheme: "rtmp",
Host: "127.0.0.1:9121",
Path: "//stream/",
}, u)
require.Equal(t, ca == "read", isReading)
close(done)
}()
@@ -322,21 +296,21 @@ func TestServerHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@@ -374,9 +348,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@@ -400,9 +374,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
})
@@ -413,9 +387,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -430,10 +404,10 @@ func TestServerHandshake(t *testing.T) {
// C->S play
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(0),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 0,
Arguments: []interface{}{
nil,
"",
},
@@ -443,9 +417,9 @@ func TestServerHandshake(t *testing.T) {
// C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@@ -455,9 +429,9 @@ func TestServerHandshake(t *testing.T) {
// C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@@ -467,9 +441,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
})
@@ -480,9 +454,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -491,10 +465,10 @@ func TestServerHandshake(t *testing.T) {
// C->S publish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"publish",
float64(5),
MessageStreamID: 0x1000000,
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"stream",
@@ -536,8 +510,8 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
rconn := NewServerConn(conn)
err = rconn.ServerHandshake()
rconn := NewConn(conn)
_, _, err = rconn.InitializeServer()
require.NoError(t, err)
videoTrack, audioTrack, err := rconn.ReadTracks()
@@ -610,21 +584,21 @@ func TestReadTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@@ -662,9 +636,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@@ -687,9 +661,9 @@ func TestReadTracks(t *testing.T) {
// C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@@ -699,9 +673,9 @@ func TestReadTracks(t *testing.T) {
// C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@@ -711,9 +685,9 @@ func TestReadTracks(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
})
@@ -724,9 +698,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -736,9 +710,9 @@ func TestReadTracks(t *testing.T) {
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 1,
Payload: []interface{}{
"publish",
float64(5),
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"live",
@@ -751,10 +725,10 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(5),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 5,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -796,23 +770,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@@ -861,23 +828,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@@ -901,23 +861,16 @@ func TestReadTracks(t *testing.T) {
case "missing metadata":
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@@ -955,8 +908,8 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
rconn := NewServerConn(conn)
err = rconn.ServerHandshake()
rconn := NewConn(conn)
_, _, err = rconn.InitializeServer()
require.NoError(t, err)
videoTrack := &gortsplib.TrackH264{
@@ -992,21 +945,21 @@ func TestWriteTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@@ -1044,9 +997,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@@ -1075,9 +1028,9 @@ func TestWriteTracks(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
})
@@ -1088,9 +1041,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@@ -1099,9 +1052,9 @@ func TestWriteTracks(t *testing.T) {
// C->S getStreamLength
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"getStreamLength",
float64(3),
Name: "getStreamLength",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@@ -1111,10 +1064,10 @@ func TestWriteTracks(t *testing.T) {
// C->S play
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(4),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 4,
Arguments: []interface{}{
nil,
"",
float64(-2000),
@@ -1141,10 +1094,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -1159,10 +1112,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -1177,10 +1130,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -1195,10 +1148,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@@ -1213,8 +1166,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
Payload: []interface{}{
"@setDataFrame",
"onMetaData",
flvio.AMFMap{
{K: "videodatarate", V: float64(0)},
@@ -1230,7 +1184,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: []byte{
@@ -1248,7 +1202,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,

View File

@@ -0,0 +1,89 @@
package h264conf
import (
"fmt"
)
// Conf is a RTMP H264 configuration.
type Conf struct {
SPS []byte
PPS []byte
}
// Unmarshal decodes a Conf from bytes.
func (c *Conf) Unmarshal(buf []byte) error {
if len(buf) < 8 {
return fmt.Errorf("invalid size 1")
}
pos := 5
spsCount := buf[pos] & 0x1F
pos++
if spsCount != 1 {
return fmt.Errorf("sps count != 1 is unsupported")
}
spsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < spsLen {
return fmt.Errorf("invalid size 2")
}
c.SPS = buf[pos : pos+spsLen]
pos += spsLen
if (len(buf) - pos) < 3 {
return fmt.Errorf("invalid size 3")
}
ppsCount := buf[pos]
pos++
if ppsCount != 1 {
return fmt.Errorf("pps count != 1 is unsupported")
}
ppsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < ppsLen {
return fmt.Errorf("invalid size")
}
c.PPS = buf[pos : pos+ppsLen]
return nil
}
// Marshal encodes a Conf into bytes.
func (c Conf) Marshal() ([]byte, error) {
spsLen := len(c.SPS)
ppsLen := len(c.PPS)
buf := make([]byte, 11+spsLen+ppsLen)
buf[0] = 1
buf[1] = c.SPS[1]
buf[2] = c.SPS[2]
buf[3] = c.SPS[3]
buf[4] = 3 | 0xFC
buf[5] = 1 | 0xE0
pos := 6
buf[pos] = byte(spsLen >> 8)
buf[pos+1] = byte(spsLen)
pos += 2
copy(buf[pos:], c.SPS)
pos += spsLen
buf[pos] = 1
pos++
buf[pos] = byte(ppsLen >> 8)
buf[pos+1] = byte(ppsLen)
pos += 2
copy(buf[pos:], c.PPS)
return buf, nil
}

View File

@@ -0,0 +1,29 @@
package h264conf
import (
"testing"
"github.com/stretchr/testify/require"
)
var decoded = Conf{
SPS: []byte{0x45, 0x32, 0xA3, 0x08},
PPS: []byte{0x45, 0x34},
}
var encoded = []byte{
0x1, 0x32, 0xa3, 0x8, 0xff, 0xe1, 0x0, 0x4, 0x45, 0x32, 0xa3, 0x8, 0x1, 0x0, 0x2, 0x45, 0x34,
}
func TestUnmarshal(t *testing.T) {
var dec Conf
err := dec.Unmarshal(encoded)
require.NoError(t, err)
require.Equal(t, decoded, dec)
}
func TestMarshal(t *testing.T) {
enc, err := decoded.Marshal()
require.NoError(t, err)
require.Equal(t, encoded, enc)
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
)
@@ -78,14 +77,13 @@ type C1S1 struct {
}
// Read reads a C1S1.
func (c *C1S1) Read(r io.Reader, isC1 bool) error {
func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
// validate signature
var peerKey []byte
var key []byte
if isC1 {
@@ -97,12 +95,15 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
}
ok, digest := hsParse1(buf, peerKey, key)
if !ok {
if validateSignature {
return fmt.Errorf("unable to validate C1/S1 signature")
}
c.Time = binary.BigEndian.Uint32(buf)
c.Random = buf[8:]
} else {
c.Digest = digest
}
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Random = buf[8:]
return nil
}
@@ -111,7 +112,10 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
func (c *C1S1) Write(w io.Writer, isC1 bool) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
copy(buf[4:], []byte{0, 0, 0, 0})
if c.Random == nil {

View File

@@ -89,7 +89,7 @@ func TestC1S1Read(t *testing.T) {
},
} {
var c1s1 C1S1
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1)
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true)
require.NoError(t, err)
require.Equal(t, ca.dec, c1s1)
}

View File

@@ -3,7 +3,6 @@ package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
)
@@ -17,22 +16,23 @@ type C2S2 struct {
}
// Read reads a C2S2.
func (c *C2S2) Read(r io.Reader) error {
func (c *C2S2) Read(r io.Reader, validateSignature bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
// validate signature
if validateSignature {
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature")
}
}
c.Time = binary.BigEndian.Uint32(buf)
c.Time2 = binary.BigEndian.Uint32(buf[4:])
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
c.Random = buf[8:]
return nil
@@ -41,8 +41,15 @@ func (c *C2S2) Read(r io.Reader) error {
// Write writes a C2S2.
func (c C2S2) Write(w io.Writer) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
binary.BigEndian.PutUint32(buf[4:], c.Time2)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
buf[4] = byte(c.Time2 >> 24)
buf[5] = byte(c.Time2 >> 16)
buf[6] = byte(c.Time2 >> 8)
buf[7] = byte(c.Time2)
if c.Random == nil {
rand.Read(buf[8:])

View File

@@ -42,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read((bytes.NewReader(c2s2enc)))
err := c2s2.Read((bytes.NewReader(c2s2enc)), true)
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}

View File

@@ -5,7 +5,7 @@ import (
)
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter) error {
func DoClient(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Write(rw)
if err != nil {
return err
@@ -23,12 +23,12 @@ func DoClient(rw io.ReadWriter) error {
}
s1 := C1S1{}
err = s1.Read(rw, false)
err = s1.Read(rw, false, validateSignature)
if err != nil {
return err
}
err = (&C2S2{Digest: c1.Digest}).Read(rw)
err = (&C2S2{Digest: c1.Digest}).Read(rw, validateSignature)
if err != nil {
return err
}
@@ -42,14 +42,14 @@ func DoClient(rw io.ReadWriter) error {
}
// DoServer performs a server-side handshake.
func DoServer(rw io.ReadWriter) error {
func DoServer(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Read(rw)
if err != nil {
return err
}
c1 := C1S1{}
err = c1.Read(rw, true)
err = c1.Read(rw, true, validateSignature)
if err != nil {
return err
}
@@ -70,7 +70,7 @@ func DoServer(rw io.ReadWriter) error {
return err
}
err = (&C2S2{Digest: s1.Digest}).Read(rw)
err = (&C2S2{Digest: s1.Digest}).Read(rw, validateSignature)
if err != nil {
return err
}

View File

@@ -19,7 +19,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
err = DoServer(conn)
err = DoServer(conn, true)
require.NoError(t, err)
close(done)
@@ -29,7 +29,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
err = DoClient(conn)
err = DoClient(conn, true)
require.NoError(t, err)
<-done

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeAcknowledge,
Body: body,
Body: buf,
}, nil
}

View File

@@ -2,6 +2,7 @@ package message
import (
"fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
@@ -12,7 +13,7 @@ import (
// MsgAudio is an audio message.
type MsgAudio struct {
ChunkStreamID byte
DTS uint32
DTS time.Duration
MessageStreamID uint32
Rate uint8
Depth uint8

View File

@@ -1,6 +1,8 @@
package message
import (
"fmt"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -11,7 +13,9 @@ import (
type MsgCommandAMF0 struct {
ChunkStreamID byte
MessageStreamID uint32
Payload []interface{}
Name string
CommandID int
Arguments []interface{}
}
// Unmarshal implements Message.
@@ -23,7 +27,24 @@ func (m *MsgCommandAMF0) Unmarshal(raw *rawmessage.Message) error {
if err != nil {
return err
}
m.Payload = payload
if len(payload) < 3 {
return fmt.Errorf("invalid command payload")
}
var ok bool
m.Name, ok = payload[0].(string)
if !ok {
return fmt.Errorf("invalid command payload")
}
tmp, ok := payload[1].(float64)
if !ok {
return fmt.Errorf("invalid command payload")
}
m.CommandID = int(tmp)
m.Arguments = payload[2:]
return nil
}
@@ -34,6 +55,9 @@ func (m MsgCommandAMF0) Marshal() (*rawmessage.Message, error) {
ChunkStreamID: m.ChunkStreamID,
Type: chunk.MessageTypeCommandAMF0,
MessageStreamID: m.MessageStreamID,
Body: flvio.FillAMF0ValsMalloc(m.Payload),
Body: flvio.FillAMF0ValsMalloc(append([]interface{}{
m.Name,
float64(m.CommandID),
}, m.Arguments...)),
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgSetChunkSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgSetChunkSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -24,20 +23,25 @@ func (m *MsgSetPeerBandwidth) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
m.Type = raw.Body[4]
return nil
}
// Marshal implements Message.
func (m *MsgSetPeerBandwidth) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 5)
binary.BigEndian.PutUint32(body, m.Value)
body[4] = m.Type
buf := make([]byte, 5)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
buf[4] = m.Type
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize,
Body: body,
Type: chunk.MessageTypeSetPeerBandwidth,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgSetWindowAckSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgSetWindowAckSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetWindowAckSize,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlPingRequest) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:])
m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlPingRequest) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingRequest)
binary.BigEndian.PutUint32(body[2:], m.ServerTime)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypePingRequest >> 8)
buf[1] = byte(UserControlTypePingRequest)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlPingResponse) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:])
m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlPingResponse) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingResponse)
binary.BigEndian.PutUint32(body[2:], m.ServerTime)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypePingResponse >> 8)
buf[1] = byte(UserControlTypePingResponse)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -24,22 +23,30 @@ func (m *MsgUserControlSetBufferLength) Unmarshal(raw *rawmessage.Message) error
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.BufferLength = binary.BigEndian.Uint32(raw.Body[6:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
m.BufferLength = uint32(raw.Body[6])<<24 | uint32(raw.Body[7])<<16 | uint32(raw.Body[8])<<8 | uint32(raw.Body[9])
return nil
}
// Marshal implements Message.
func (m MsgUserControlSetBufferLength) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 10)
binary.BigEndian.PutUint16(body, UserControlTypeSetBufferLength)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
binary.BigEndian.PutUint32(body[6:], m.BufferLength)
buf := make([]byte, 10)
buf[0] = byte(UserControlTypeSetBufferLength >> 8)
buf[1] = byte(UserControlTypeSetBufferLength)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
buf[6] = byte(m.BufferLength >> 24)
buf[7] = byte(m.BufferLength >> 16)
buf[8] = byte(m.BufferLength >> 8)
buf[9] = byte(m.BufferLength)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamBegin) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamBegin) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamBegin)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamBegin >> 8)
buf[1] = byte(UserControlTypeStreamBegin)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamDry) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamDry) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamDry)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamDry >> 8)
buf[1] = byte(UserControlTypeStreamDry)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamEOF) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamEOF) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamEOF)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamEOF >> 8)
buf[1] = byte(UserControlTypeStreamEOF)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamIsRecorded) Unmarshal(raw *rawmessage.Message) erro
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamIsRecorded) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamIsRecorded)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamIsRecorded >> 8)
buf[1] = byte(UserControlTypeStreamIsRecorded)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

View File

@@ -2,6 +2,7 @@ package message
import (
"fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
@@ -12,11 +13,11 @@ import (
// MsgVideo is a video message.
type MsgVideo struct {
ChunkStreamID byte
DTS uint32
DTS time.Duration
MessageStreamID uint32
IsKeyFrame bool
H264Type uint8
PTSDelta uint32
PTSDelta time.Duration
Payload []byte
}
@@ -38,7 +39,10 @@ func (m *MsgVideo) Unmarshal(raw *rawmessage.Message) error {
}
m.H264Type = raw.Body[1]
m.PTSDelta = uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
tmp := uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
m.PTSDelta = time.Duration(tmp) * time.Millisecond
m.Payload = raw.Body[5:]
return nil
@@ -55,9 +59,12 @@ func (m MsgVideo) Marshal() (*rawmessage.Message, error) {
}
body[0] |= flvio.VIDEO_H264
body[1] = m.H264Type
body[2] = uint8(m.PTSDelta >> 16)
body[3] = uint8(m.PTSDelta >> 8)
body[4] = uint8(m.PTSDelta)
tmp := uint32(m.PTSDelta / time.Millisecond)
body[2] = uint8(tmp >> 16)
body[3] = uint8(tmp >> 8)
body[4] = uint8(tmp)
copy(body[5:], m.Payload)
return &rawmessage.Message{

View File

@@ -1,7 +1,6 @@
package message
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
@@ -28,7 +27,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return nil, fmt.Errorf("invalid body size")
}
subType := binary.BigEndian.Uint16(raw.Body)
subType := uint16(raw.Body[0])<<8 | uint16(raw.Body[1])
switch subType {
case UserControlTypeStreamBegin:
return &MsgUserControlStreamBegin{}, nil
@@ -68,7 +67,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return &MsgVideo{}, nil
default:
return nil, fmt.Errorf("unhandled message")
return nil, fmt.Errorf("unhandled message type (%v)", raw.Type)
}
}

View File

@@ -0,0 +1,227 @@
package message
import (
"bytes"
"testing"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
var readWriterCases = []struct {
name string
dec Message
enc []byte
}{
{
"acknowledge",
&MsgAcknowledge{
Value: 45953968,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x3,
0x0, 0x0, 0x0, 0x0, 0x2, 0xbd, 0x33, 0xb0,
},
},
{
"audio",
&MsgAudio{
ChunkStreamID: 7,
DTS: 6013806 * time.Millisecond,
MessageStreamID: 4534543,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: []byte{0x5A, 0xC0, 0x77, 0x40},
},
[]byte{
0x7, 0x5b, 0xc3, 0x6e, 0x0, 0x0, 0x6, 0x8,
0x0, 0x45, 0x31, 0xf, 0xaf, 0x1, 0x5a, 0xc0,
0x77, 0x40,
},
},
{
"command amf0",
&MsgCommandAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Name: "i8yythrergre",
CommandID: 56456,
Arguments: []interface{}{
flvio.AMFMap{
{K: "k1", V: "v1"},
{K: "k2", V: "v2"},
},
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2f, 0x14,
0x0, 0x5, 0x44, 0x9b, 0x2, 0x0, 0xc, 0x69,
0x38, 0x79, 0x79, 0x74, 0x68, 0x72, 0x65, 0x72,
0x67, 0x72, 0x65, 0x0, 0x40, 0xeb, 0x91, 0x0,
0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x2, 0x6b,
0x31, 0x2, 0x0, 0x2, 0x76, 0x31, 0x0, 0x2,
0x6b, 0x32, 0x2, 0x0, 0x2, 0x76, 0x32, 0x0,
0x0, 0x9, 0x5,
},
},
{
"data amf0",
&MsgDataAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Payload: []interface{}{
float64(234),
"string",
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x13, 0x12,
0x0, 0x5, 0x44, 0x9b, 0x0, 0x40, 0x6d, 0x40,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x6,
0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x05,
},
},
{
"set chunk size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set peer bandwidth",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set window ack size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"user control ping request",
&MsgUserControlPingRequest{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control ping response",
&MsgUserControlPingResponse{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control set buffer length",
&MsgUserControlSetBufferLength{
StreamID: 35534,
BufferLength: 235345,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0,
0x8a, 0xce, 0x0, 0x3, 0x97, 0x51,
},
},
{
"user control stream begin",
&MsgUserControlStreamBegin{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream dry",
&MsgUserControlStreamDry{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream eof",
&MsgUserControlStreamEOF{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream is recorded",
&MsgUserControlStreamIsRecorded{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"video",
&MsgVideo{
ChunkStreamID: 6,
DTS: 2543534 * time.Millisecond,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
PTSDelta: 10 * time.Millisecond,
Payload: []byte{0x01, 0x02, 0x03},
},
[]byte{
0x6, 0x26, 0xcf, 0xae, 0x0, 0x0, 0x8, 0x9,
0x1, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0,
0xa, 0x1, 0x2, 0x3,
},
},
}
func TestReader(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil)
dec, err := r.Read()
require.NoError(t, err)
require.Equal(t, ca.dec, dec)
})
}
}

View File

@@ -11,8 +11,8 @@ type ReadWriter struct {
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter {
w := NewWriter(bc.Writer)
func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter {
w := NewWriter(bc.Writer, checkAcknowledge)
r := NewReader(bc.Reader, func(count uint32) error {
return w.Write(&MsgAcknowledge{

View File

@@ -11,9 +11,9 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer {
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{
w: rawmessage.NewWriter(w),
w: rawmessage.NewWriter(w, checkAcknowledge),
}
}

View File

@@ -0,0 +1,22 @@
package message
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
func TestWriter(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
r := NewWriter(bytecounter.NewWriter(&buf), true)
err := r.Write(ca.dec)
require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes())
})
}
}

View File

@@ -1,13 +1,15 @@
package rawmessage
import (
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
// Message is a raw message.
type Message struct {
ChunkStreamID byte
Timestamp uint32
Timestamp time.Duration
Type chunk.MessageType
MessageStreamID uint32
Body []byte

View File

@@ -3,6 +3,7 @@ package rawmessage
import (
"errors"
"fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -73,7 +74,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
}
return &Message{
Timestamp: c0.Timestamp,
Timestamp: time.Duration(c0.Timestamp) * time.Millisecond,
Type: c0.Type,
MessageStreamID: c0.MessageStreamID,
Body: c0.Body,
@@ -109,7 +110,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
}
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: c1.Type,
MessageStreamID: *rc.curMessageStreamID,
Body: c1.Body,
@@ -124,7 +125,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
}
chunkBodyLen := (*rc.curBodyLen)
chunkBodyLen := *rc.curBodyLen
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
@@ -140,13 +141,13 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
v2 := c2.TimestampDelta
rc.curTimestampDelta = &v2
if chunkBodyLen != uint32(len(c2.Body)) {
if *rc.curBodyLen != uint32(len(c2.Body)) {
rc.curBody = &c2.Body
return nil, errMoreChunksNeeded
}
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: c2.Body,
@@ -179,7 +180,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curBody = nil
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: body,
@@ -201,7 +202,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestamp = &v1
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: c3.Body,

View File

@@ -3,6 +3,7 @@ package rawmessage
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -10,31 +11,19 @@ import (
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
type sequenceEntry struct {
func TestReader(t *testing.T) {
type sequenceEntry struct {
chunk chunk.Chunk
msg *Message
}
func TestReader(t *testing.T) {
testSequence := func(t *testing.T, seq []sequenceEntry) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range seq {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
}
t.Run("chunk0 + chunk1", func(t *testing.T) {
testSequence(t, []sequenceEntry{
for _, ca := range []struct {
name string
sequence []sequenceEntry
}{
{
"chunk0 + chunk1",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
@@ -46,7 +35,7 @@ func TestReader(t *testing.T) {
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
@@ -62,17 +51,17 @@ func TestReader(t *testing.T) {
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
},
})
})
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
testSequence(t, []sequenceEntry{
},
},
{
"chunk0 + chunk2 + chunk3",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
@@ -84,7 +73,7 @@ func TestReader(t *testing.T) {
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
@@ -98,7 +87,7 @@ func TestReader(t *testing.T) {
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
@@ -111,50 +100,85 @@ func TestReader(t *testing.T) {
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
})
})
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err := chunk.Chunk0{
},
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
buf2, err = chunk.Chunk3{
},
nil,
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, &Message{
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
}, msg)
},
},
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
nil,
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
},
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range ca.sequence {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
if entry.msg != nil {
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
}
})
}
}
func TestReaderAcknowledge(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package rawmessage
import (
"fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -12,14 +13,14 @@ type writerChunkStream struct {
lastMessageStreamID *uint32
lastType *chunk.MessageType
lastBodyLen *uint32
lastTimestamp *uint32
lastTimestampDelta *uint32
lastTimestamp *time.Duration
lastTimestampDelta *time.Duration
}
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
// check if we received an acknowledge
if wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - (wc.mw.ackValue)
if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - wc.mw.ackValue
if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window")
@@ -44,14 +45,13 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
pos := uint32(0)
firstChunk := true
var timestampDelta *uint32
var timestampDelta *time.Duration
if wc.lastTimestamp != nil {
diff := int64(msg.Timestamp) - int64(*wc.lastTimestamp)
diff := msg.Timestamp - *wc.lastTimestamp
// use delta only if it is positive
if diff >= 0 {
v := uint32(diff)
timestampDelta = &v
timestampDelta = &diff
}
}
@@ -68,7 +68,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp,
Timestamp: uint32(msg.Timestamp / time.Millisecond),
Type: msg.Type,
MessageStreamID: msg.MessageStreamID,
BodyLen: (bodyLen),
@@ -81,7 +81,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Type: msg.Type,
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
@@ -93,7 +93,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Body: msg.Body[pos : pos+chunkBodyLen],
})
if err != nil {
@@ -144,6 +144,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
// Writer is a raw message writer.
type Writer struct {
w *bytecounter.Writer
checkAcknowledge bool
chunkSize uint32
ackWindowSize uint32
ackValue uint32
@@ -151,9 +152,10 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer {
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{
w: w,
checkAcknowledge: checkAcknowledge,
chunkSize: 128,
chunkStreams: make(map[byte]*writerChunkStream),
}

View File

@@ -2,7 +2,9 @@ package rawmessage
import (
"bytes"
"reflect"
"testing"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -10,146 +12,168 @@ import (
)
func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk1", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
for _, ca := range []struct {
name string
messages []*Message
chunks []chunk.Chunk
chunkSizes []uint32
}{
{
"chunk0 + chunk1",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c1 chunk.Chunk1
err = c1.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk1{
},
&chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetWindowAckSize,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c1)
})
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
},
},
[]uint32{
128,
128,
},
},
{
"chunk0 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c2 chunk.Chunk2
err = c2.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk2{
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c2)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
})
require.NoError(t, err)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x05}, 64),
}, c3)
})
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
},
},
[]uint32{
128,
64,
64,
},
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
},
{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}, c0)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c3)
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
64,
128,
64,
},
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf), true)
for _, msg := range ca.messages {
err := w.Write(msg)
require.NoError(t, err)
}
for i, cach := range ca.chunks {
ch := reflect.New(reflect.TypeOf(cach).Elem()).Interface().(chunk.Chunk)
err := ch.Read(&buf, ca.chunkSizes[i])
require.NoError(t, err)
require.Equal(t, cach, ch)
}
})
}
}
func TestWriterAcknowledge(t *testing.T) {
@@ -157,7 +181,7 @@ func TestWriterAcknowledge(t *testing.T) {
t.Run(ca, func(t *testing.T) {
var buf bytes.Buffer
bcw := bytecounter.NewWriter(&buf)
w := NewWriter(bcw)
w := NewWriter(bcw, true)
if ca == "overflow" {
bcw.SetCount(4294967096)
@@ -169,7 +193,7 @@ func TestWriterAcknowledge(t *testing.T) {
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200),
@@ -178,7 +202,7 @@ func TestWriterAcknowledge(t *testing.T) {
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200),