rtmp: rewrite implementation of rtmp connection (#1047)

* rtmp: improve MsgCommandAMF0

* rtmp: fix MsgSetPeerBandwidth

* rtmp: add message tests

* rtmp: replace implementation with new one

* rtmp: rename handshake functions

* rtmp: avoid calling useless function

* rtmp: use time.Duration for PTSDelta

* rtmp: fix decoding chunks with relevant size

* rtmp: rewrite implementation of rtmp connection

* rtmp: fix tests

* rtmp: improve error message

* rtmp: replace h264 config implementation

* link against github.com/notedit/rtmp

* normalize MessageStreamID

* rtmp: make acknowledge optional

* rtmp: fix decoding of chunk2 + chunk3

* avoid using encoding/binary
This commit is contained in:
Alessandro Ros
2022-07-17 15:17:18 +02:00
committed by GitHub
parent 50d205274f
commit 9e6abc6e9f
45 changed files with 2045 additions and 1064 deletions

6
go.mod
View File

@@ -5,14 +5,14 @@ go 1.17
require ( require (
code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5 code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5
github.com/abema/go-mp4 v0.7.2 github.com/abema/go-mp4 v0.7.2
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8
github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757 github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.8.1 github.com/gin-gonic/gin v1.8.1
github.com/gookit/color v1.4.2 github.com/gookit/color v1.4.2
github.com/grafov/m3u8 v0.11.1 github.com/grafov/m3u8 v0.11.1
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
github.com/notedit/rtmp v0.0.0 github.com/notedit/rtmp v0.0.2
github.com/orcaman/writerseeker v0.0.0 github.com/orcaman/writerseeker v0.0.0
github.com/pion/rtp v1.7.13 github.com/pion/rtp v1.7.13
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1
@@ -51,6 +51,4 @@ require (
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )
replace github.com/notedit/rtmp => github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927
replace github.com/orcaman/writerseeker => github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 replace github.com/orcaman/writerseeker => github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82

8
go.sum
View File

@@ -6,10 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f h1:EC+MOSv3e8ZEvtdHoL1++HahNoiVIkvu2Ygjrx6LyOg= github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8 h1:GdQOJFYbcrw8bXGClhroHTBIEJAb/jPCIV33Q966rms=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo= github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4= github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U= github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U=
github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8= github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=
@@ -83,6 +81,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/notedit/rtmp v0.0.2 h1:5+to4yezKATiJgnrcETu9LbV5G/QsWkOV9Ts2M/p33w=
github.com/notedit/rtmp v0.0.2/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=

View File

@@ -16,13 +16,14 @@ import (
"github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av" "github.com/notedit/rtmp/format/flv/flvio"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd" "github.com/aler9/rtsp-simple-server/internal/externalcmd"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp" "github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
) )
const ( const (
@@ -107,7 +108,7 @@ func newRTMPConn(
runOnConnect: runOnConnect, runOnConnect: runOnConnect,
runOnConnectRestart: runOnConnectRestart, runOnConnectRestart: runOnConnectRestart,
wg: wg, wg: wg,
conn: rtmp.NewServerConn(nconn), conn: rtmp.NewConn(nconn),
nconn: nconn, nconn: nconn,
externalCmdPool: externalCmdPool, externalCmdPool: externalCmdPool,
pathManager: pathManager, pathManager: pathManager,
@@ -211,19 +212,19 @@ func (c *rtmpConn) runInner(ctx context.Context) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.ServerHandshake() u, isReading, err := c.conn.InitializeServer()
if err != nil { if err != nil {
return err return err
} }
if c.conn.IsPublishing() { if isReading {
return c.runPublish(ctx) return c.runRead(ctx, u)
} }
return c.runRead(ctx) return c.runPublish(ctx, u)
} }
func (c *rtmpConn) runRead(ctx context.Context) error { func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL()) pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onReaderSetupPlay(pathReaderSetupPlayReq{ res := c.pathManager.onReaderSetupPlay(pathReaderSetupPlayReq{
author: c, author: c,
@@ -410,22 +411,17 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
sps := videoTrack.SafeSPS() sps := videoTrack.SafeSPS()
pps := videoTrack.SafePPS() pps := videoTrack.SafePPS()
codec := nh264.Codec{ buf, _ := h264conf.Conf{
SPS: map[int][]byte{ SPS: sps,
0: sps, PPS: pps,
}, }.Marshal()
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
b = b[:n]
err = c.conn.WritePacket(av.Packet{ err = c.conn.WriteMessage(&message.MsgVideo{
Type: av.H264DecoderConfig, ChunkStreamID: 6,
Data: b, MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: buf,
}) })
if err != nil { if err != nil {
return err return err
@@ -438,11 +434,14 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
} }
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = c.conn.WritePacket(av.Packet{ err = c.conn.WriteMessage(&message.MsgVideo{
Type: av.H264, ChunkStreamID: 6,
Data: avcc, MessageStreamID: 1,
Time: dts, IsKeyFrame: idrPresent,
CTime: pts - dts, H264Type: flvio.AVC_NALU,
Payload: avcc,
DTS: dts,
PTSDelta: pts - dts,
}) })
if err != nil { if err != nil {
return err return err
@@ -467,10 +466,15 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
for i, au := range aus { for i, au := range aus {
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WritePacket(av.Packet{ err := c.conn.WriteMessage(&message.MsgAudio{
Type: av.AAC, ChunkStreamID: 4,
Data: au, MessageStreamID: 1,
Time: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()), Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: au,
DTS: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()),
}) })
if err != nil { if err != nil {
return err return err
@@ -480,7 +484,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
} }
} }
func (c *rtmpConn) runPublish(ctx context.Context) error { func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
videoTrack, audioTrack, err := c.conn.ReadTracks() videoTrack, audioTrack, err := c.conn.ReadTracks()
if err != nil { if err != nil {
@@ -513,7 +517,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
tracks = append(tracks, audioTrack) tracks = append(tracks, audioTrack)
} }
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL()) pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onPublisherAnnounce(pathPublisherAnnounceReq{ res := c.pathManager.onPublisherAnnounce(pathPublisherAnnounceReq{
author: c, author: c,
@@ -559,121 +563,125 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
for { for {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout))) c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
pkt, err := c.conn.ReadPacket() msg, err := c.conn.ReadMessage()
if err != nil { if err != nil {
return err return err
} }
switch pkt.Type { switch tmsg := msg.(type) {
case av.H264DecoderConfig: case *message.MsgVideo:
codec, err := nh264.FromDecoderConfig(pkt.Data) if tmsg.H264Type == flvio.AVC_SEQHDR {
if err != nil { var conf h264conf.Conf
return err err = conf.Unmarshal(tmsg.Payload)
} if err != nil {
return fmt.Errorf("unable to parse H264 config: %v", err)
}
pts := pkt.Time + pkt.CTime pts := tmsg.DTS + tmsg.PTSDelta
nalus := [][]byte{ nalus := [][]byte{
codec.SPS[0], conf.SPS,
codec.PPS[0], conf.PPS,
} }
pkts, err := h264Encoder.Encode(nalus, pts) pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil { if err != nil {
return fmt.Errorf("error while encoding H264: %v", err) return fmt.Errorf("error while encoding H264: %v", err)
} }
lastPkt := len(pkts) - 1 lastPkt := len(pkts) - 1
for i, pkt := range pkts { for i, pkt := range pkts {
if i != lastPkt { if i != lastPkt {
rres.stream.writeData(&data{ rres.stream.writeData(&data{
trackID: videoTrackID, trackID: videoTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: false, ptsEqualsDTS: false,
}) })
} else { } else {
rres.stream.writeData(&data{ rres.stream.writeData(&data{
trackID: videoTrackID, trackID: videoTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: false, ptsEqualsDTS: false,
h264NALUs: nalus, h264NALUs: nalus,
h264PTS: pts, h264PTS: pts,
}) })
}
}
} else if tmsg.H264Type == flvio.AVC_NALU {
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil {
return fmt.Errorf("unable to decode AVCC: %v", err)
}
// skip invalid NALUs sent by DJI
n := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
n++
}
}
if n == 0 {
continue
}
validNALUs := make([][]byte, n)
pos := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
validNALUs[pos] = nalu
pos++
}
}
pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(validNALUs, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(validNALUs),
h264NALUs: validNALUs,
h264PTS: pts,
})
}
} }
} }
case av.H264: case *message.MsgAudio:
if videoTrack == nil { if tmsg.AACType == flvio.AAC_RAW {
return fmt.Errorf("received an H264 packet, but track is not set up") if audioTrack == nil {
} return fmt.Errorf("received an AAC packet, but track is not set up")
nalus, err := h264.AVCCUnmarshal(pkt.Data)
if err != nil {
return err
}
// skip invalid NALUs sent by DJI
n := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
n++
} }
}
if n == 0 {
continue
}
validNALUs := make([][]byte, n) pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
pos := 0 if err != nil {
for _, nalu := range nalus { return fmt.Errorf("error while encoding AAC: %v", err)
if len(nalu) != 0 {
validNALUs[pos] = nalu
pos++
} }
}
pts := pkt.Time + pkt.CTime for _, pkt := range pkts {
pkts, err := h264Encoder.Encode(validNALUs, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(&data{ rres.stream.writeData(&data{
trackID: videoTrackID, trackID: audioTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: false, ptsEqualsDTS: true,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(validNALUs),
h264NALUs: validNALUs,
h264PTS: pts,
}) })
} }
} }
case av.AAC:
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
for _, pkt := range pkts {
rres.stream.writeData(&data{
trackID: audioTrackID,
rtp: pkt,
ptsEqualsDTS: true,
})
}
} }
} }
} }

View File

@@ -3,7 +3,6 @@ package core
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@@ -259,7 +258,7 @@ func (s *rtmpServer) newConnID() (string, error) {
return "", err return "", err
} }
u := binary.LittleEndian.Uint32(b) u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999 u %= 899999999
u += 100000000 u += 100000000

View File

@@ -141,9 +141,9 @@ func TestRTMPServerAuth(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host) nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err) require.NoError(t, err)
defer nconn.Close() defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u) conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true) err = conn.InitializeClient(u, true)
require.NoError(t, err) require.NoError(t, err)
_, _, err = conn.ReadTracks() _, _, err = conn.ReadTracks()
@@ -229,9 +229,17 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host) nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err) require.NoError(t, err)
defer nconn.Close() defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u) conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true) err = conn.InitializeClient(u, true)
require.NoError(t, err)
for i := 0; i < 3; i++ {
_, err := conn.ReadMessage()
require.NoError(t, err)
}
_, err = conn.ReadMessage()
require.Equal(t, err, io.EOF) require.Equal(t, err, io.EOF)
}) })
} }

View File

@@ -12,11 +12,12 @@ import (
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av" "github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp" "github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
) )
const ( const (
@@ -126,14 +127,14 @@ func (s *rtmpSource) runInner() bool {
return err return err
} }
conn := rtmp.NewClientConn(nconn, u) conn := rtmp.NewConn(nconn)
readDone := make(chan error) readDone := make(chan error)
go func() { go func() {
readDone <- func() error { readDone <- func() error {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err = conn.ClientHandshake(true) err = conn.InitializeClient(u, true)
if err != nil { if err != nil {
return err return err
} }
@@ -187,64 +188,68 @@ func (s *rtmpSource) runInner() bool {
for { for {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
pkt, err := conn.ReadPacket() msg, err := conn.ReadMessage()
if err != nil { if err != nil {
return err return err
} }
switch pkt.Type { switch tmsg := msg.(type) {
case av.H264: case *message.MsgVideo:
if videoTrack == nil { if tmsg.H264Type == flvio.AVC_NALU {
return fmt.Errorf("received an H264 packet, but track is not set up") if videoTrack == nil {
} return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(pkt.Data) nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil { if err != nil {
return err return fmt.Errorf("unable to decode AVCC: %v", err)
} }
pts := pkt.Time + pkt.CTime pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(nalus, pts) pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil { if err != nil {
return fmt.Errorf("error while encoding H264: %v", err) return fmt.Errorf("error while encoding H264: %v", err)
} }
lastPkt := len(pkts) - 1 lastPkt := len(pkts) - 1
for i, pkt := range pkts { for i, pkt := range pkts {
if i != lastPkt { if i != lastPkt {
res.stream.writeData(&data{ res.stream.writeData(&data{
trackID: videoTrackID, trackID: videoTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: false, ptsEqualsDTS: false,
}) })
} else { } else {
res.stream.writeData(&data{ res.stream.writeData(&data{
trackID: videoTrackID, trackID: videoTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(nalus), ptsEqualsDTS: h264.IDRPresent(nalus),
h264NALUs: nalus, h264NALUs: nalus,
h264PTS: pts, h264PTS: pts,
}) })
}
} }
} }
case av.AAC: case *message.MsgAudio:
if audioTrack == nil { if tmsg.AACType == flvio.AAC_RAW {
return fmt.Errorf("received an AAC packet, but track is not set up") if audioTrack == nil {
} return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime) pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
if err != nil { if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err) return fmt.Errorf("error while encoding AAC: %v", err)
} }
for _, pkt := range pkts { for _, pkt := range pkts {
res.stream.writeData(&data{ res.stream.writeData(&data{
trackID: audioTrackID, trackID: audioTrackID,
rtp: pkt, rtp: pkt,
ptsEqualsDTS: true, ptsEqualsDTS: true,
}) })
}
} }
} }
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -235,7 +234,7 @@ func (s *rtspServer) newSessionID() (string, error) {
return "", err return "", err
} }
u := binary.LittleEndian.Uint32(b) u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999 u %= 899999999
u += 100000000 u += 100000000

View File

@@ -20,7 +20,7 @@ type Chunk0 struct {
// Read reads the chunk. // Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error { func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 12) header := make([]byte, 12)
_, err := r.Read(header) _, err := io.ReadFull(r, header)
if err != nil { if err != nil {
return err return err
} }
@@ -37,7 +37,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
} }
c.Body = make([]byte, chunkBodyLen) c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body) _, err = io.ReadFull(r, c.Body)
return err return err
} }

View File

@@ -21,7 +21,7 @@ type Chunk1 struct {
// Read reads the chunk. // Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error { func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 8) header := make([]byte, 8)
_, err := r.Read(header) _, err := io.ReadFull(r, header)
if err != nil { if err != nil {
return err return err
} }
@@ -37,7 +37,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
} }
c.Body = make([]byte, chunkBodyLen) c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body) _, err = io.ReadFull(r, c.Body)
return err return err
} }

View File

@@ -17,7 +17,7 @@ type Chunk2 struct {
// Read reads the chunk. // Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error { func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 4) header := make([]byte, 4)
_, err := r.Read(header) _, err := io.ReadFull(r, header)
if err != nil { if err != nil {
return err return err
} }
@@ -26,7 +26,7 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
c.TimestampDelta = uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3]) c.TimestampDelta = uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
c.Body = make([]byte, chunkBodyLen) c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body) _, err = io.ReadFull(r, c.Body)
return err return err
} }

View File

@@ -18,7 +18,7 @@ type Chunk3 struct {
// Read reads the chunk. // Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error { func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1) header := make([]byte, 1)
_, err := r.Read(header) _, err := io.ReadFull(r, header)
if err != nil { if err != nil {
return err return err
} }
@@ -26,7 +26,7 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
c.ChunkStreamID = header[0] & 0x3F c.ChunkStreamID = header[0] & 0x3F
c.Body = make([]byte, chunkBodyLen) c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body) _, err = io.ReadFull(r, c.Body)
return err return err
} }

File diff suppressed because it is too large Load Diff

View File

@@ -3,52 +3,20 @@ package rtmp
import ( import (
"net" "net"
"net/url" "net/url"
"strings"
"testing" "testing"
"github.com/aler9/gortsplib" "github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/aac" "github.com/aler9/gortsplib/pkg/aac"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio" "github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter" "github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake" "github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message" "github.com/aler9/rtsp-simple-server/internal/rtmp/message"
) )
func splitPath(u *url.URL) (app, stream string) { func TestInitializeClient(t *testing.T) {
nu := *u
nu.ForceQuery = false
pathsegs := strings.Split(nu.RequestURI(), "/")
if len(pathsegs) == 2 {
app = pathsegs[1]
}
if len(pathsegs) == 3 {
app = pathsegs[1]
stream = pathsegs[2]
}
if len(pathsegs) > 3 {
app = strings.Join(pathsegs[1:3], "/")
stream = strings.Join(pathsegs[3:], "/")
}
return
}
func getTcURL(u string) string {
ur, err := url.Parse(u)
if err != nil {
panic(err)
}
app, _ := splitPath(ur)
nu := *ur
nu.RawQuery = ""
nu.Path = "/"
return nu.String() + app
}
func TestClientHandshake(t *testing.T) {
for _, ca := range []string{"read", "publish"} { for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121") ln, err := net.Listen("tcp", "127.0.0.1:9121")
@@ -63,10 +31,10 @@ func TestClientHandshake(t *testing.T) {
defer conn.Close() defer conn.Close()
bc := bytecounter.NewReadWriter(conn) bc := bytecounter.NewReadWriter(conn)
err = handshake.DoServer(bc) err = handshake.DoServer(bc, true)
require.NoError(t, err) require.NoError(t, err)
mrw := message.NewReadWriter(bc) mrw := message.NewReadWriter(bc, true)
// C->S set window ack size // C->S set window ack size
msg, err := mrw.Read() msg, err := mrw.Read()
@@ -79,7 +47,7 @@ func TestClientHandshake(t *testing.T) {
msg, err = mrw.Read() msg, err = mrw.Read()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{ require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 0x2625a0, Value: 2500000,
Type: 2, Type: 2,
}, msg) }, msg)
@@ -95,13 +63,13 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "connect",
"connect", CommandID: 1,
float64(1), Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "app", V: "stream"}, {K: "app", V: "stream"},
{K: "flashVer", V: "LNX 9,0,124,2"}, {K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, {K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false}, {K: "fpad", V: false},
{K: "capabilities", V: float64(15)}, {K: "capabilities", V: float64(15)},
{K: "audioCodecs", V: float64(4071)}, {K: "audioCodecs", V: float64(4071)},
@@ -114,9 +82,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result // S->C result
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 1,
float64(1), Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"}, {K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)}, {K: "capabilities", V: float64(31)},
@@ -137,9 +105,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
}, },
}, msg) }, msg)
@@ -147,9 +115,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result // S->C result
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -168,10 +136,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "play",
"play", CommandID: 0,
float64(0), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -180,10 +148,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus // S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -199,9 +167,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "releaseStream",
"releaseStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -212,9 +180,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "FCPublish",
"FCPublish", CommandID: 3,
float64(3), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -225,9 +193,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
}, },
}, msg) }, msg)
@@ -235,9 +203,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result // S->C result
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -249,10 +217,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "publish",
"publish", CommandID: 5,
float64(5), Arguments: []interface{}{
nil, nil,
"", "",
"stream", "stream",
@@ -262,10 +230,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus // S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 5,
float64(5), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -286,9 +254,9 @@ func TestClientHandshake(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host) nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err) require.NoError(t, err)
defer nconn.Close() defer nconn.Close()
conn := NewClientConn(nconn, u) conn := NewConn(nconn)
err = conn.ClientHandshake(ca == "read") err = conn.InitializeClient(u, ca == "read")
require.NoError(t, err) require.NoError(t, err)
<-done <-done
@@ -296,7 +264,7 @@ func TestClientHandshake(t *testing.T) {
} }
} }
func TestServerHandshake(t *testing.T) { func TestInitializeServer(t *testing.T) {
for _, ca := range []string{"read", "publish"} { for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121") ln, err := net.Listen("tcp", "127.0.0.1:9121")
@@ -310,9 +278,15 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer nconn.Close() defer nconn.Close()
conn := NewServerConn(nconn) conn := NewConn(nconn)
err = conn.ServerHandshake() u, isReading, err := conn.InitializeServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &url.URL{
Scheme: "rtmp",
Host: "127.0.0.1:9121",
Path: "//stream/",
}, u)
require.Equal(t, ca == "read", isReading)
close(done) close(done)
}() }()
@@ -322,21 +296,21 @@ func TestServerHandshake(t *testing.T) {
defer conn.Close() defer conn.Close()
bc := bytecounter.NewReadWriter(conn) bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc) err = handshake.DoClient(bc, true)
require.NoError(t, err) require.NoError(t, err)
mrw := message.NewReadWriter(bc) mrw := message.NewReadWriter(bc, true)
// C->S connect // C->S connect
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "connect",
"connect", CommandID: 1,
1, Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "app", V: "/stream"}, {K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"}, {K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, {K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false}, {K: "fpad", V: false},
{K: "capabilities", V: 15}, {K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071}, {K: "audioCodecs", V: 4071},
@@ -374,9 +348,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 1,
float64(1), Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"}, {K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)}, {K: "capabilities", V: float64(31)},
@@ -400,9 +374,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream // C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
}, },
}) })
@@ -413,9 +387,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -430,10 +404,10 @@ func TestServerHandshake(t *testing.T) {
// C->S play // C->S play
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "play",
"play", CommandID: 0,
float64(0), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -443,9 +417,9 @@ func TestServerHandshake(t *testing.T) {
// C->S releaseStream // C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "releaseStream",
"releaseStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -455,9 +429,9 @@ func TestServerHandshake(t *testing.T) {
// C->S FCPublish // C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "FCPublish",
"FCPublish", CommandID: 3,
float64(3), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -467,9 +441,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream // C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
}, },
}) })
@@ -480,9 +454,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -491,10 +465,10 @@ func TestServerHandshake(t *testing.T) {
// C->S publish // C->S publish
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "publish",
"publish", CommandID: 5,
float64(5), Arguments: []interface{}{
nil, nil,
"", "",
"stream", "stream",
@@ -536,8 +510,8 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
rconn := NewServerConn(conn) rconn := NewConn(conn)
err = rconn.ServerHandshake() _, _, err = rconn.InitializeServer()
require.NoError(t, err) require.NoError(t, err)
videoTrack, audioTrack, err := rconn.ReadTracks() videoTrack, audioTrack, err := rconn.ReadTracks()
@@ -610,21 +584,21 @@ func TestReadTracks(t *testing.T) {
defer conn.Close() defer conn.Close()
bc := bytecounter.NewReadWriter(conn) bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc) err = handshake.DoClient(bc, true)
require.NoError(t, err) require.NoError(t, err)
mrw := message.NewReadWriter(bc) mrw := message.NewReadWriter(bc, true)
// C->S connect // C->S connect
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "connect",
"connect", CommandID: 1,
1, Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "app", V: "/stream"}, {K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"}, {K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, {K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false}, {K: "fpad", V: false},
{K: "capabilities", V: 15}, {K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071}, {K: "audioCodecs", V: 4071},
@@ -662,9 +636,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 1,
float64(1), Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"}, {K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)}, {K: "capabilities", V: float64(31)},
@@ -687,9 +661,9 @@ func TestReadTracks(t *testing.T) {
// C->S releaseStream // C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "releaseStream",
"releaseStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -699,9 +673,9 @@ func TestReadTracks(t *testing.T) {
// C->S FCPublish // C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "FCPublish",
"FCPublish", CommandID: 3,
float64(3), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -711,9 +685,9 @@ func TestReadTracks(t *testing.T) {
// C->S createStream // C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
}, },
}) })
@@ -724,9 +698,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -736,9 +710,9 @@ func TestReadTracks(t *testing.T) {
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8, ChunkStreamID: 8,
MessageStreamID: 1, MessageStreamID: 1,
Payload: []interface{}{ Name: "publish",
"publish", CommandID: 5,
float64(5), Arguments: []interface{}{
nil, nil,
"", "",
"live", "live",
@@ -751,10 +725,10 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 5,
float64(5), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -796,23 +770,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// C->S H264 decoder config // C->S H264 decoder config
codec := nh264.Codec{ buf, _ := h264conf.Conf{
SPS: map[int][]byte{ SPS: sps,
0: sps, PPS: pps,
}, }.Marshal()
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mrw.Write(&message.MsgVideo{ err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6, ChunkStreamID: 6,
MessageStreamID: 1, MessageStreamID: 1,
IsKeyFrame: true, IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR, H264Type: flvio.AVC_SEQHDR,
Payload: b[:n], Payload: buf,
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -861,23 +828,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// C->S H264 decoder config // C->S H264 decoder config
codec := nh264.Codec{ buf, _ := h264conf.Conf{
SPS: map[int][]byte{ SPS: sps,
0: sps, PPS: pps,
}, }.Marshal()
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mrw.Write(&message.MsgVideo{ err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6, ChunkStreamID: 6,
MessageStreamID: 1, MessageStreamID: 1,
IsKeyFrame: true, IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR, H264Type: flvio.AVC_SEQHDR,
Payload: b[:n], Payload: buf,
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -901,23 +861,16 @@ func TestReadTracks(t *testing.T) {
case "missing metadata": case "missing metadata":
// C->S H264 decoder config // C->S H264 decoder config
codec := nh264.Codec{ buf, _ := h264conf.Conf{
SPS: map[int][]byte{ SPS: sps,
0: sps, PPS: pps,
}, }.Marshal()
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mrw.Write(&message.MsgVideo{ err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6, ChunkStreamID: 6,
MessageStreamID: 1, MessageStreamID: 1,
IsKeyFrame: true, IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR, H264Type: flvio.AVC_SEQHDR,
Payload: b[:n], Payload: buf,
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -955,8 +908,8 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
rconn := NewServerConn(conn) rconn := NewConn(conn)
err = rconn.ServerHandshake() _, _, err = rconn.InitializeServer()
require.NoError(t, err) require.NoError(t, err)
videoTrack := &gortsplib.TrackH264{ videoTrack := &gortsplib.TrackH264{
@@ -992,21 +945,21 @@ func TestWriteTracks(t *testing.T) {
defer conn.Close() defer conn.Close()
bc := bytecounter.NewReadWriter(conn) bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc) err = handshake.DoClient(bc, true)
require.NoError(t, err) require.NoError(t, err)
mrw := message.NewReadWriter(bc) mrw := message.NewReadWriter(bc, true)
// C->S connect // C->S connect
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "connect",
"connect", CommandID: 1,
1, Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "app", V: "/stream"}, {K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"}, {K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, {K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false}, {K: "fpad", V: false},
{K: "capabilities", V: 15}, {K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071}, {K: "audioCodecs", V: 4071},
@@ -1044,9 +997,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 1,
float64(1), Arguments: []interface{}{
flvio.AMFMap{ flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"}, {K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)}, {K: "capabilities", V: float64(31)},
@@ -1075,9 +1028,9 @@ func TestWriteTracks(t *testing.T) {
// C->S createStream // C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "createStream",
"createStream", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
}, },
}) })
@@ -1088,9 +1041,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3, ChunkStreamID: 3,
Payload: []interface{}{ Name: "_result",
"_result", CommandID: 2,
float64(2), Arguments: []interface{}{
nil, nil,
float64(1), float64(1),
}, },
@@ -1099,9 +1052,9 @@ func TestWriteTracks(t *testing.T) {
// C->S getStreamLength // C->S getStreamLength
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8, ChunkStreamID: 8,
Payload: []interface{}{ Name: "getStreamLength",
"getStreamLength", CommandID: 3,
float64(3), Arguments: []interface{}{
nil, nil,
"", "",
}, },
@@ -1111,10 +1064,10 @@ func TestWriteTracks(t *testing.T) {
// C->S play // C->S play
err = mrw.Write(&message.MsgCommandAMF0{ err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8, ChunkStreamID: 8,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "play",
"play", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
"", "",
float64(-2000), float64(-2000),
@@ -1141,10 +1094,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -1159,10 +1112,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -1177,10 +1130,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -1195,10 +1148,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{ require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5, ChunkStreamID: 5,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Name: "onStatus",
"onStatus", CommandID: 4,
float64(4), Arguments: []interface{}{
nil, nil,
flvio.AMFMap{ flvio.AMFMap{
{K: "level", V: "status"}, {K: "level", V: "status"},
@@ -1213,8 +1166,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgDataAMF0{ require.Equal(t, &message.MsgDataAMF0{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Payload: []interface{}{ Payload: []interface{}{
"@setDataFrame",
"onMetaData", "onMetaData",
flvio.AMFMap{ flvio.AMFMap{
{K: "videodatarate", V: float64(0)}, {K: "videodatarate", V: float64(0)},
@@ -1230,7 +1184,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgVideo{ require.Equal(t, &message.MsgVideo{
ChunkStreamID: 6, ChunkStreamID: 6,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
IsKeyFrame: true, IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR, H264Type: flvio.AVC_SEQHDR,
Payload: []byte{ Payload: []byte{
@@ -1248,7 +1202,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &message.MsgAudio{ require.Equal(t, &message.MsgAudio{
ChunkStreamID: 4, ChunkStreamID: 4,
MessageStreamID: 16777216, MessageStreamID: 0x1000000,
Rate: flvio.SOUND_44Khz, Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT, Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO, Channels: flvio.SOUND_STEREO,

View File

@@ -0,0 +1,89 @@
package h264conf
import (
"fmt"
)
// Conf is a RTMP H264 configuration.
type Conf struct {
SPS []byte
PPS []byte
}
// Unmarshal decodes a Conf from bytes.
func (c *Conf) Unmarshal(buf []byte) error {
if len(buf) < 8 {
return fmt.Errorf("invalid size 1")
}
pos := 5
spsCount := buf[pos] & 0x1F
pos++
if spsCount != 1 {
return fmt.Errorf("sps count != 1 is unsupported")
}
spsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < spsLen {
return fmt.Errorf("invalid size 2")
}
c.SPS = buf[pos : pos+spsLen]
pos += spsLen
if (len(buf) - pos) < 3 {
return fmt.Errorf("invalid size 3")
}
ppsCount := buf[pos]
pos++
if ppsCount != 1 {
return fmt.Errorf("pps count != 1 is unsupported")
}
ppsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < ppsLen {
return fmt.Errorf("invalid size")
}
c.PPS = buf[pos : pos+ppsLen]
return nil
}
// Marshal encodes a Conf into bytes.
func (c Conf) Marshal() ([]byte, error) {
spsLen := len(c.SPS)
ppsLen := len(c.PPS)
buf := make([]byte, 11+spsLen+ppsLen)
buf[0] = 1
buf[1] = c.SPS[1]
buf[2] = c.SPS[2]
buf[3] = c.SPS[3]
buf[4] = 3 | 0xFC
buf[5] = 1 | 0xE0
pos := 6
buf[pos] = byte(spsLen >> 8)
buf[pos+1] = byte(spsLen)
pos += 2
copy(buf[pos:], c.SPS)
pos += spsLen
buf[pos] = 1
pos++
buf[pos] = byte(ppsLen >> 8)
buf[pos+1] = byte(ppsLen)
pos += 2
copy(buf[pos:], c.PPS)
return buf, nil
}

View File

@@ -0,0 +1,29 @@
package h264conf
import (
"testing"
"github.com/stretchr/testify/require"
)
var decoded = Conf{
SPS: []byte{0x45, 0x32, 0xA3, 0x08},
PPS: []byte{0x45, 0x34},
}
var encoded = []byte{
0x1, 0x32, 0xa3, 0x8, 0xff, 0xe1, 0x0, 0x4, 0x45, 0x32, 0xa3, 0x8, 0x1, 0x0, 0x2, 0x45, 0x34,
}
func TestUnmarshal(t *testing.T) {
var dec Conf
err := dec.Unmarshal(encoded)
require.NoError(t, err)
require.Equal(t, decoded, dec)
}
func TestMarshal(t *testing.T) {
enc, err := decoded.Marshal()
require.NoError(t, err)
require.Equal(t, encoded, enc)
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary"
"fmt" "fmt"
"io" "io"
) )
@@ -78,14 +77,13 @@ type C1S1 struct {
} }
// Read reads a C1S1. // Read reads a C1S1.
func (c *C1S1) Read(r io.Reader, isC1 bool) error { func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {
return err return err
} }
// validate signature
var peerKey []byte var peerKey []byte
var key []byte var key []byte
if isC1 { if isC1 {
@@ -97,12 +95,15 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
} }
ok, digest := hsParse1(buf, peerKey, key) ok, digest := hsParse1(buf, peerKey, key)
if !ok { if !ok {
return fmt.Errorf("unable to validate C1/S1 signature") if validateSignature {
return fmt.Errorf("unable to validate C1/S1 signature")
}
} else {
c.Digest = digest
} }
c.Time = binary.BigEndian.Uint32(buf) c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Random = buf[8:] c.Random = buf[8:]
c.Digest = digest
return nil return nil
} }
@@ -111,7 +112,10 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
func (c *C1S1) Write(w io.Writer, isC1 bool) error { func (c *C1S1) Write(w io.Writer, isC1 bool) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time) buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
copy(buf[4:], []byte{0, 0, 0, 0}) copy(buf[4:], []byte{0, 0, 0, 0})
if c.Random == nil { if c.Random == nil {

View File

@@ -89,7 +89,7 @@ func TestC1S1Read(t *testing.T) {
}, },
} { } {
var c1s1 C1S1 var c1s1 C1S1
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1) err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ca.dec, c1s1) require.Equal(t, ca.dec, c1s1)
} }

View File

@@ -3,7 +3,6 @@ package handshake
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/binary"
"fmt" "fmt"
"io" "io"
) )
@@ -17,22 +16,23 @@ type C2S2 struct {
} }
// Read reads a C2S2. // Read reads a C2S2.
func (c *C2S2) Read(r io.Reader) error { func (c *C2S2) Read(r io.Reader, validateSignature bool) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {
return err return err
} }
// validate signature if validateSignature {
gap := len(buf) - 32 gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap) digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) { if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature") return fmt.Errorf("unable to validate C2/S2 signature")
}
} }
c.Time = binary.BigEndian.Uint32(buf) c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Time2 = binary.BigEndian.Uint32(buf[4:]) c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
c.Random = buf[8:] c.Random = buf[8:]
return nil return nil
@@ -41,8 +41,15 @@ func (c *C2S2) Read(r io.Reader) error {
// Write writes a C2S2. // Write writes a C2S2.
func (c C2S2) Write(w io.Writer) error { func (c C2S2) Write(w io.Writer) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
binary.BigEndian.PutUint32(buf[4:], c.Time2) buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
buf[4] = byte(c.Time2 >> 24)
buf[5] = byte(c.Time2 >> 16)
buf[6] = byte(c.Time2 >> 8)
buf[7] = byte(c.Time2)
if c.Random == nil { if c.Random == nil {
rand.Read(buf[8:]) rand.Read(buf[8:])

View File

@@ -42,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2 var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest c2s2.Digest = c2s2dec.Digest
err := c2s2.Read((bytes.NewReader(c2s2enc))) err := c2s2.Read((bytes.NewReader(c2s2enc)), true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2) require.Equal(t, c2s2dec, c2s2)
} }

View File

@@ -5,7 +5,7 @@ import (
) )
// DoClient performs a client-side handshake. // DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter) error { func DoClient(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Write(rw) err := C0S0{}.Write(rw)
if err != nil { if err != nil {
return err return err
@@ -23,12 +23,12 @@ func DoClient(rw io.ReadWriter) error {
} }
s1 := C1S1{} s1 := C1S1{}
err = s1.Read(rw, false) err = s1.Read(rw, false, validateSignature)
if err != nil { if err != nil {
return err return err
} }
err = (&C2S2{Digest: c1.Digest}).Read(rw) err = (&C2S2{Digest: c1.Digest}).Read(rw, validateSignature)
if err != nil { if err != nil {
return err return err
} }
@@ -42,14 +42,14 @@ func DoClient(rw io.ReadWriter) error {
} }
// DoServer performs a server-side handshake. // DoServer performs a server-side handshake.
func DoServer(rw io.ReadWriter) error { func DoServer(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Read(rw) err := C0S0{}.Read(rw)
if err != nil { if err != nil {
return err return err
} }
c1 := C1S1{} c1 := C1S1{}
err = c1.Read(rw, true) err = c1.Read(rw, true, validateSignature)
if err != nil { if err != nil {
return err return err
} }
@@ -70,7 +70,7 @@ func DoServer(rw io.ReadWriter) error {
return err return err
} }
err = (&C2S2{Digest: s1.Digest}).Read(rw) err = (&C2S2{Digest: s1.Digest}).Read(rw, validateSignature)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -19,7 +19,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
err = DoServer(conn) err = DoServer(conn, true)
require.NoError(t, err) require.NoError(t, err)
close(done) close(done)
@@ -29,7 +29,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
err = DoClient(conn) err = DoClient(conn, true)
require.NoError(t, err) require.NoError(t, err)
<-done <-done

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size") return fmt.Errorf("unexpected body size")
} }
m.Value = binary.BigEndian.Uint32(raw.Body) m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) { func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4) buf := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeAcknowledge, Type: chunk.MessageTypeAcknowledge,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -2,6 +2,7 @@ package message
import ( import (
"fmt" "fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio" "github.com/notedit/rtmp/format/flv/flvio"
@@ -12,7 +13,7 @@ import (
// MsgAudio is an audio message. // MsgAudio is an audio message.
type MsgAudio struct { type MsgAudio struct {
ChunkStreamID byte ChunkStreamID byte
DTS uint32 DTS time.Duration
MessageStreamID uint32 MessageStreamID uint32
Rate uint8 Rate uint8
Depth uint8 Depth uint8

View File

@@ -1,6 +1,8 @@
package message package message
import ( import (
"fmt"
"github.com/notedit/rtmp/format/flv/flvio" "github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -11,7 +13,9 @@ import (
type MsgCommandAMF0 struct { type MsgCommandAMF0 struct {
ChunkStreamID byte ChunkStreamID byte
MessageStreamID uint32 MessageStreamID uint32
Payload []interface{} Name string
CommandID int
Arguments []interface{}
} }
// Unmarshal implements Message. // Unmarshal implements Message.
@@ -23,7 +27,24 @@ func (m *MsgCommandAMF0) Unmarshal(raw *rawmessage.Message) error {
if err != nil { if err != nil {
return err return err
} }
m.Payload = payload
if len(payload) < 3 {
return fmt.Errorf("invalid command payload")
}
var ok bool
m.Name, ok = payload[0].(string)
if !ok {
return fmt.Errorf("invalid command payload")
}
tmp, ok := payload[1].(float64)
if !ok {
return fmt.Errorf("invalid command payload")
}
m.CommandID = int(tmp)
m.Arguments = payload[2:]
return nil return nil
} }
@@ -34,6 +55,9 @@ func (m MsgCommandAMF0) Marshal() (*rawmessage.Message, error) {
ChunkStreamID: m.ChunkStreamID, ChunkStreamID: m.ChunkStreamID,
Type: chunk.MessageTypeCommandAMF0, Type: chunk.MessageTypeCommandAMF0,
MessageStreamID: m.MessageStreamID, MessageStreamID: m.MessageStreamID,
Body: flvio.FillAMF0ValsMalloc(m.Payload), Body: flvio.FillAMF0ValsMalloc(append([]interface{}{
m.Name,
float64(m.CommandID),
}, m.Arguments...)),
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgSetChunkSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size") return fmt.Errorf("unexpected body size")
} }
m.Value = binary.BigEndian.Uint32(raw.Body) m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m *MsgSetChunkSize) Marshal() (*rawmessage.Message, error) { func (m *MsgSetChunkSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4) buf := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize, Type: chunk.MessageTypeSetChunkSize,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -24,20 +23,25 @@ func (m *MsgSetPeerBandwidth) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size") return fmt.Errorf("unexpected body size")
} }
m.Value = binary.BigEndian.Uint32(raw.Body) m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
m.Type = raw.Body[4] m.Type = raw.Body[4]
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m *MsgSetPeerBandwidth) Marshal() (*rawmessage.Message, error) { func (m *MsgSetPeerBandwidth) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 5) buf := make([]byte, 5)
binary.BigEndian.PutUint32(body, m.Value)
body[4] = m.Type buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
buf[4] = m.Type
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize, Type: chunk.MessageTypeSetPeerBandwidth,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,18 +22,23 @@ func (m *MsgSetWindowAckSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size") return fmt.Errorf("unexpected body size")
} }
m.Value = binary.BigEndian.Uint32(raw.Body) m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m *MsgSetWindowAckSize) Marshal() (*rawmessage.Message, error) { func (m *MsgSetWindowAckSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4) buf := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetWindowAckSize, Type: chunk.MessageTypeSetWindowAckSize,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlPingRequest) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:]) m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlPingRequest) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlPingRequest) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingRequest)
binary.BigEndian.PutUint32(body[2:], m.ServerTime) buf[0] = byte(UserControlTypePingRequest >> 8)
buf[1] = byte(UserControlTypePingRequest)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlPingResponse) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:]) m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlPingResponse) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlPingResponse) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingResponse)
binary.BigEndian.PutUint32(body[2:], m.ServerTime) buf[0] = byte(UserControlTypePingResponse >> 8)
buf[1] = byte(UserControlTypePingResponse)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -24,22 +23,30 @@ func (m *MsgUserControlSetBufferLength) Unmarshal(raw *rawmessage.Message) error
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:]) m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
m.BufferLength = binary.BigEndian.Uint32(raw.Body[6:]) m.BufferLength = uint32(raw.Body[6])<<24 | uint32(raw.Body[7])<<16 | uint32(raw.Body[8])<<8 | uint32(raw.Body[9])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlSetBufferLength) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlSetBufferLength) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 10) buf := make([]byte, 10)
binary.BigEndian.PutUint16(body, UserControlTypeSetBufferLength)
binary.BigEndian.PutUint32(body[2:], m.StreamID) buf[0] = byte(UserControlTypeSetBufferLength >> 8)
binary.BigEndian.PutUint32(body[6:], m.BufferLength) buf[1] = byte(UserControlTypeSetBufferLength)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
buf[6] = byte(m.BufferLength >> 24)
buf[7] = byte(m.BufferLength >> 16)
buf[8] = byte(m.BufferLength >> 8)
buf[9] = byte(m.BufferLength)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamBegin) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:]) m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlStreamBegin) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlStreamBegin) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamBegin)
binary.BigEndian.PutUint32(body[2:], m.StreamID) buf[0] = byte(UserControlTypeStreamBegin >> 8)
buf[1] = byte(UserControlTypeStreamBegin)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamDry) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:]) m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlStreamDry) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlStreamDry) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamDry)
binary.BigEndian.PutUint32(body[2:], m.StreamID) buf[0] = byte(UserControlTypeStreamDry >> 8)
buf[1] = byte(UserControlTypeStreamDry)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamEOF) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:]) m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlStreamEOF) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlStreamEOF) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamEOF)
binary.BigEndian.PutUint32(body[2:], m.StreamID) buf[0] = byte(UserControlTypeStreamEOF >> 8)
buf[1] = byte(UserControlTypeStreamEOF)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -1,7 +1,6 @@
package message package message //nolint:dupl
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -23,20 +22,25 @@ func (m *MsgUserControlStreamIsRecorded) Unmarshal(raw *rawmessage.Message) erro
return fmt.Errorf("invalid body size") return fmt.Errorf("invalid body size")
} }
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:]) m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil return nil
} }
// Marshal implements Message. // Marshal implements Message.
func (m MsgUserControlStreamIsRecorded) Marshal() (*rawmessage.Message, error) { func (m MsgUserControlStreamIsRecorded) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6) buf := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamIsRecorded)
binary.BigEndian.PutUint32(body[2:], m.StreamID) buf[0] = byte(UserControlTypeStreamIsRecorded >> 8)
buf[1] = byte(UserControlTypeStreamIsRecorded)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{ return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID, ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl, Type: chunk.MessageTypeUserControl,
Body: body, Body: buf,
}, nil }, nil
} }

View File

@@ -2,6 +2,7 @@ package message
import ( import (
"fmt" "fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio" "github.com/notedit/rtmp/format/flv/flvio"
@@ -12,11 +13,11 @@ import (
// MsgVideo is a video message. // MsgVideo is a video message.
type MsgVideo struct { type MsgVideo struct {
ChunkStreamID byte ChunkStreamID byte
DTS uint32 DTS time.Duration
MessageStreamID uint32 MessageStreamID uint32
IsKeyFrame bool IsKeyFrame bool
H264Type uint8 H264Type uint8
PTSDelta uint32 PTSDelta time.Duration
Payload []byte Payload []byte
} }
@@ -38,7 +39,10 @@ func (m *MsgVideo) Unmarshal(raw *rawmessage.Message) error {
} }
m.H264Type = raw.Body[1] m.H264Type = raw.Body[1]
m.PTSDelta = uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
tmp := uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
m.PTSDelta = time.Duration(tmp) * time.Millisecond
m.Payload = raw.Body[5:] m.Payload = raw.Body[5:]
return nil return nil
@@ -55,9 +59,12 @@ func (m MsgVideo) Marshal() (*rawmessage.Message, error) {
} }
body[0] |= flvio.VIDEO_H264 body[0] |= flvio.VIDEO_H264
body[1] = m.H264Type body[1] = m.H264Type
body[2] = uint8(m.PTSDelta >> 16)
body[3] = uint8(m.PTSDelta >> 8) tmp := uint32(m.PTSDelta / time.Millisecond)
body[4] = uint8(m.PTSDelta) body[2] = uint8(tmp >> 16)
body[3] = uint8(tmp >> 8)
body[4] = uint8(tmp)
copy(body[5:], m.Payload) copy(body[5:], m.Payload)
return &rawmessage.Message{ return &rawmessage.Message{

View File

@@ -1,7 +1,6 @@
package message package message
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter" "github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
@@ -28,7 +27,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return nil, fmt.Errorf("invalid body size") return nil, fmt.Errorf("invalid body size")
} }
subType := binary.BigEndian.Uint16(raw.Body) subType := uint16(raw.Body[0])<<8 | uint16(raw.Body[1])
switch subType { switch subType {
case UserControlTypeStreamBegin: case UserControlTypeStreamBegin:
return &MsgUserControlStreamBegin{}, nil return &MsgUserControlStreamBegin{}, nil
@@ -68,7 +67,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return &MsgVideo{}, nil return &MsgVideo{}, nil
default: default:
return nil, fmt.Errorf("unhandled message") return nil, fmt.Errorf("unhandled message type (%v)", raw.Type)
} }
} }

View File

@@ -0,0 +1,227 @@
package message
import (
"bytes"
"testing"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
var readWriterCases = []struct {
name string
dec Message
enc []byte
}{
{
"acknowledge",
&MsgAcknowledge{
Value: 45953968,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x3,
0x0, 0x0, 0x0, 0x0, 0x2, 0xbd, 0x33, 0xb0,
},
},
{
"audio",
&MsgAudio{
ChunkStreamID: 7,
DTS: 6013806 * time.Millisecond,
MessageStreamID: 4534543,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: []byte{0x5A, 0xC0, 0x77, 0x40},
},
[]byte{
0x7, 0x5b, 0xc3, 0x6e, 0x0, 0x0, 0x6, 0x8,
0x0, 0x45, 0x31, 0xf, 0xaf, 0x1, 0x5a, 0xc0,
0x77, 0x40,
},
},
{
"command amf0",
&MsgCommandAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Name: "i8yythrergre",
CommandID: 56456,
Arguments: []interface{}{
flvio.AMFMap{
{K: "k1", V: "v1"},
{K: "k2", V: "v2"},
},
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2f, 0x14,
0x0, 0x5, 0x44, 0x9b, 0x2, 0x0, 0xc, 0x69,
0x38, 0x79, 0x79, 0x74, 0x68, 0x72, 0x65, 0x72,
0x67, 0x72, 0x65, 0x0, 0x40, 0xeb, 0x91, 0x0,
0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x2, 0x6b,
0x31, 0x2, 0x0, 0x2, 0x76, 0x31, 0x0, 0x2,
0x6b, 0x32, 0x2, 0x0, 0x2, 0x76, 0x32, 0x0,
0x0, 0x9, 0x5,
},
},
{
"data amf0",
&MsgDataAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Payload: []interface{}{
float64(234),
"string",
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x13, 0x12,
0x0, 0x5, 0x44, 0x9b, 0x0, 0x40, 0x6d, 0x40,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x6,
0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x05,
},
},
{
"set chunk size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set peer bandwidth",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set window ack size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"user control ping request",
&MsgUserControlPingRequest{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control ping response",
&MsgUserControlPingResponse{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control set buffer length",
&MsgUserControlSetBufferLength{
StreamID: 35534,
BufferLength: 235345,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0,
0x8a, 0xce, 0x0, 0x3, 0x97, 0x51,
},
},
{
"user control stream begin",
&MsgUserControlStreamBegin{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream dry",
&MsgUserControlStreamDry{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream eof",
&MsgUserControlStreamEOF{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream is recorded",
&MsgUserControlStreamIsRecorded{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"video",
&MsgVideo{
ChunkStreamID: 6,
DTS: 2543534 * time.Millisecond,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
PTSDelta: 10 * time.Millisecond,
Payload: []byte{0x01, 0x02, 0x03},
},
[]byte{
0x6, 0x26, 0xcf, 0xae, 0x0, 0x0, 0x8, 0x9,
0x1, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0,
0xa, 0x1, 0x2, 0x3,
},
},
}
func TestReader(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil)
dec, err := r.Read()
require.NoError(t, err)
require.Equal(t, ca.dec, dec)
})
}
}

View File

@@ -11,8 +11,8 @@ type ReadWriter struct {
} }
// NewReadWriter allocates a ReadWriter. // NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter { func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter {
w := NewWriter(bc.Writer) w := NewWriter(bc.Writer, checkAcknowledge)
r := NewReader(bc.Reader, func(count uint32) error { r := NewReader(bc.Reader, func(count uint32) error {
return w.Write(&MsgAcknowledge{ return w.Write(&MsgAcknowledge{

View File

@@ -11,9 +11,9 @@ type Writer struct {
} }
// NewWriter allocates a Writer. // NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer { func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{ return &Writer{
w: rawmessage.NewWriter(w), w: rawmessage.NewWriter(w, checkAcknowledge),
} }
} }

View File

@@ -0,0 +1,22 @@
package message
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
func TestWriter(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
r := NewWriter(bytecounter.NewWriter(&buf), true)
err := r.Write(ca.dec)
require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes())
})
}
}

View File

@@ -1,13 +1,15 @@
package rawmessage package rawmessage
import ( import (
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
) )
// Message is a raw message. // Message is a raw message.
type Message struct { type Message struct {
ChunkStreamID byte ChunkStreamID byte
Timestamp uint32 Timestamp time.Duration
Type chunk.MessageType Type chunk.MessageType
MessageStreamID uint32 MessageStreamID uint32
Body []byte Body []byte

View File

@@ -3,6 +3,7 @@ package rawmessage
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter" "github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -73,7 +74,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
} }
return &Message{ return &Message{
Timestamp: c0.Timestamp, Timestamp: time.Duration(c0.Timestamp) * time.Millisecond,
Type: c0.Type, Type: c0.Type,
MessageStreamID: c0.MessageStreamID, MessageStreamID: c0.MessageStreamID,
Body: c0.Body, Body: c0.Body,
@@ -109,7 +110,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
} }
return &Message{ return &Message{
Timestamp: *rc.curTimestamp, Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: c1.Type, Type: c1.Type,
MessageStreamID: *rc.curMessageStreamID, MessageStreamID: *rc.curMessageStreamID,
Body: c1.Body, Body: c1.Body,
@@ -124,7 +125,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk") return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
} }
chunkBodyLen := (*rc.curBodyLen) chunkBodyLen := *rc.curBodyLen
if chunkBodyLen > rc.mr.chunkSize { if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize chunkBodyLen = rc.mr.chunkSize
} }
@@ -140,13 +141,13 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
v2 := c2.TimestampDelta v2 := c2.TimestampDelta
rc.curTimestampDelta = &v2 rc.curTimestampDelta = &v2
if chunkBodyLen != uint32(len(c2.Body)) { if *rc.curBodyLen != uint32(len(c2.Body)) {
rc.curBody = &c2.Body rc.curBody = &c2.Body
return nil, errMoreChunksNeeded return nil, errMoreChunksNeeded
} }
return &Message{ return &Message{
Timestamp: *rc.curTimestamp, Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType, Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID, MessageStreamID: *rc.curMessageStreamID,
Body: c2.Body, Body: c2.Body,
@@ -179,7 +180,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curBody = nil rc.curBody = nil
return &Message{ return &Message{
Timestamp: *rc.curTimestamp, Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType, Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID, MessageStreamID: *rc.curMessageStreamID,
Body: body, Body: body,
@@ -201,7 +202,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestamp = &v1 rc.curTimestamp = &v1
return &Message{ return &Message{
Timestamp: *rc.curTimestamp, Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType, Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID, MessageStreamID: *rc.curMessageStreamID,
Body: c3.Body, Body: c3.Body,

View File

@@ -3,6 +3,7 @@ package rawmessage
import ( import (
"bytes" "bytes"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -10,151 +11,174 @@ import (
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
) )
type sequenceEntry struct {
chunk chunk.Chunk
msg *Message
}
func TestReader(t *testing.T) { func TestReader(t *testing.T) {
testSequence := func(t *testing.T, seq []sequenceEntry) { type sequenceEntry struct {
var buf bytes.Buffer chunk chunk.Chunk
bcr := bytecounter.NewReader(&buf) msg *Message
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range seq {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
} }
t.Run("chunk0 + chunk1", func(t *testing.T) { for _, ca := range []struct {
testSequence(t, []sequenceEntry{ name string
{ sequence []sequenceEntry
&chunk.Chunk0{ }{
ChunkStreamID: 27, {
Timestamp: 18576, "chunk0 + chunk1",
Type: chunk.MessageTypeSetPeerBandwidth, []sequenceEntry{
MessageStreamID: 3123, {
BodyLen: 64, &chunk.Chunk0{
Body: bytes.Repeat([]byte{0x02}, 64), ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
},
}, },
&Message{ {
ChunkStreamID: 27, &chunk.Chunk1{
Timestamp: 18576, ChunkStreamID: 27,
Type: chunk.MessageTypeSetPeerBandwidth, TimestampDelta: 15,
MessageStreamID: 3123, Type: chunk.MessageTypeSetPeerBandwidth,
Body: bytes.Repeat([]byte{0x02}, 64), BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
}, },
}, },
{ },
&chunk.Chunk1{ {
ChunkStreamID: 27, "chunk0 + chunk2 + chunk3",
TimestampDelta: 15, []sequenceEntry{
Type: chunk.MessageTypeSetPeerBandwidth, {
BodyLen: 64, &chunk.Chunk0{
Body: bytes.Repeat([]byte{0x03}, 64), ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
},
}, },
&Message{ {
ChunkStreamID: 27, &chunk.Chunk2{
Timestamp: 18576 + 15, ChunkStreamID: 27,
Type: chunk.MessageTypeSetPeerBandwidth, TimestampDelta: 15,
MessageStreamID: 3123, Body: bytes.Repeat([]byte{0x03}, 64),
Body: bytes.Repeat([]byte{0x03}, 64), },
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
}, },
}, },
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
},
nil,
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
},
},
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
nil,
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
},
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range ca.sequence {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
if entry.msg != nil {
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
}
}) })
}) }
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
testSequence(t, []sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
},
},
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
})
})
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
buf2, err = chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, &Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
}, msg)
})
} }
func TestReaderAcknowledge(t *testing.T) { func TestReaderAcknowledge(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package rawmessage
import ( import (
"fmt" "fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter" "github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -12,14 +13,14 @@ type writerChunkStream struct {
lastMessageStreamID *uint32 lastMessageStreamID *uint32
lastType *chunk.MessageType lastType *chunk.MessageType
lastBodyLen *uint32 lastBodyLen *uint32
lastTimestamp *uint32 lastTimestamp *time.Duration
lastTimestampDelta *uint32 lastTimestampDelta *time.Duration
} }
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error { func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
// check if we received an acknowledge // check if we received an acknowledge
if wc.mw.ackWindowSize != 0 { if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - (wc.mw.ackValue) diff := wc.mw.w.Count() - wc.mw.ackValue
if diff > (wc.mw.ackWindowSize * 3 / 2) { if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window") return fmt.Errorf("no acknowledge received within window")
@@ -44,14 +45,13 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
pos := uint32(0) pos := uint32(0)
firstChunk := true firstChunk := true
var timestampDelta *uint32 var timestampDelta *time.Duration
if wc.lastTimestamp != nil { if wc.lastTimestamp != nil {
diff := int64(msg.Timestamp) - int64(*wc.lastTimestamp) diff := msg.Timestamp - *wc.lastTimestamp
// use delta only if it is positive // use delta only if it is positive
if diff >= 0 { if diff >= 0 {
v := uint32(diff) timestampDelta = &diff
timestampDelta = &v
} }
} }
@@ -68,7 +68,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID: case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
err := wc.writeChunk(&chunk.Chunk0{ err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp, Timestamp: uint32(msg.Timestamp / time.Millisecond),
Type: msg.Type, Type: msg.Type,
MessageStreamID: msg.MessageStreamID, MessageStreamID: msg.MessageStreamID,
BodyLen: (bodyLen), BodyLen: (bodyLen),
@@ -81,7 +81,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen: case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
err := wc.writeChunk(&chunk.Chunk1{ err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta, TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Type: msg.Type, Type: msg.Type,
BodyLen: (bodyLen), BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
@@ -93,7 +93,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta: case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
err := wc.writeChunk(&chunk.Chunk2{ err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta, TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}) })
if err != nil { if err != nil {
@@ -143,19 +143,21 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
// Writer is a raw message writer. // Writer is a raw message writer.
type Writer struct { type Writer struct {
w *bytecounter.Writer w *bytecounter.Writer
chunkSize uint32 checkAcknowledge bool
ackWindowSize uint32 chunkSize uint32
ackValue uint32 ackWindowSize uint32
chunkStreams map[byte]*writerChunkStream ackValue uint32
chunkStreams map[byte]*writerChunkStream
} }
// NewWriter allocates a Writer. // NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer { func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{ return &Writer{
w: w, w: w,
chunkSize: 128, checkAcknowledge: checkAcknowledge,
chunkStreams: make(map[byte]*writerChunkStream), chunkSize: 128,
chunkStreams: make(map[byte]*writerChunkStream),
} }
} }

View File

@@ -2,7 +2,9 @@ package rawmessage
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter" "github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@@ -10,146 +12,168 @@ import (
) )
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk1", func(t *testing.T) { for _, ca := range []struct {
var buf bytes.Buffer name string
w := NewWriter(bytecounter.NewWriter(&buf)) messages []*Message
chunks []chunk.Chunk
chunkSizes []uint32
}{
{
"chunk0 + chunk1",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetWindowAckSize,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
128,
},
},
{
"chunk0 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]uint32{
128,
64,
64,
},
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
},
{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
64,
128,
64,
},
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf), true)
err := w.Write(&Message{ for _, msg := range ca.messages {
ChunkStreamID: 27, err := w.Write(msg)
Timestamp: 18576, require.NoError(t, err)
Type: chunk.MessageTypeSetPeerBandwidth, }
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64), for i, cach := range ca.chunks {
ch := reflect.New(reflect.TypeOf(cach).Elem()).Interface().(chunk.Chunk)
err := ch.Read(&buf, ca.chunkSizes[i])
require.NoError(t, err)
require.Equal(t, cach, ch)
}
}) })
require.NoError(t, err) }
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c1 chunk.Chunk1
err = c1.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetWindowAckSize,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c1)
})
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c2 chunk.Chunk2
err = c2.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c2)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
})
require.NoError(t, err)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x05}, 64),
}, c3)
})
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}, c0)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c3)
})
} }
func TestWriterAcknowledge(t *testing.T) { func TestWriterAcknowledge(t *testing.T) {
@@ -157,7 +181,7 @@ func TestWriterAcknowledge(t *testing.T) {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
bcw := bytecounter.NewWriter(&buf) bcw := bytecounter.NewWriter(&buf)
w := NewWriter(bcw) w := NewWriter(bcw, true)
if ca == "overflow" { if ca == "overflow" {
bcw.SetCount(4294967096) bcw.SetCount(4294967096)
@@ -169,7 +193,7 @@ func TestWriterAcknowledge(t *testing.T) {
err := w.Write(&Message{ err := w.Write(&Message{
ChunkStreamID: 27, ChunkStreamID: 27,
Timestamp: 18576, Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth, Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123, MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200), Body: bytes.Repeat([]byte{0x03}, 200),
@@ -178,7 +202,7 @@ func TestWriterAcknowledge(t *testing.T) {
err = w.Write(&Message{ err = w.Write(&Message{
ChunkStreamID: 27, ChunkStreamID: 27,
Timestamp: 18576, Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth, Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123, MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200), Body: bytes.Repeat([]byte{0x03}, 200),