mirror of
https://github.com/aler9/rtsp-simple-server
synced 2025-10-24 16:20:23 +08:00
rtmp: implement acknowledge mechanism
This commit is contained in:
37
internal/rtmp/bytecounter/reader.go
Normal file
37
internal/rtmp/bytecounter/reader.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type readerInner struct {
|
||||
r io.Reader
|
||||
count uint32
|
||||
}
|
||||
|
||||
func (r *readerInner) Read(p []byte) (int, error) {
|
||||
n, err := r.r.Read(p)
|
||||
r.count += uint32(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Reader allows to count read bytes.
|
||||
type Reader struct {
|
||||
ri *readerInner
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
// NewReader allocates a Reader.
|
||||
func NewReader(r io.Reader) *Reader {
|
||||
ri := &readerInner{r: r}
|
||||
return &Reader{
|
||||
ri: ri,
|
||||
Reader: bufio.NewReader(ri),
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns read bytes.
|
||||
func (r Reader) Count() uint32 {
|
||||
return r.ri.count
|
||||
}
|
||||
19
internal/rtmp/bytecounter/readwriter.go
Normal file
19
internal/rtmp/bytecounter/readwriter.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// ReadWriter allows to count read and written bytes.
|
||||
type ReadWriter struct {
|
||||
*Reader
|
||||
*Writer
|
||||
}
|
||||
|
||||
// NewReadWriter allocates a ReadWriter.
|
||||
func NewReadWriter(rw io.ReadWriter) *ReadWriter {
|
||||
return &ReadWriter{
|
||||
Reader: NewReader(rw),
|
||||
Writer: NewWriter(rw),
|
||||
}
|
||||
}
|
||||
30
internal/rtmp/bytecounter/writer.go
Normal file
30
internal/rtmp/bytecounter/writer.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// Writer allows to count written bytes.
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
count uint32
|
||||
}
|
||||
|
||||
// NewWriter allocates a Writer.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (w *Writer) Write(p []byte) (int, error) {
|
||||
n, err := w.w.Write(p)
|
||||
w.count += uint32(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Count returns written bytes.
|
||||
func (w Writer) Count() uint32 {
|
||||
return w.count
|
||||
}
|
||||
11
internal/rtmp/chunk/chunk.go
Normal file
11
internal/rtmp/chunk/chunk.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package chunk
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// Chunk is a chunk.
|
||||
type Chunk interface {
|
||||
Read(io.Reader, uint32) error
|
||||
Write() ([]byte, error)
|
||||
}
|
||||
@@ -18,7 +18,7 @@ type Chunk0 struct {
|
||||
}
|
||||
|
||||
// Read reads the chunk.
|
||||
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
|
||||
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
|
||||
header := make([]byte, 12)
|
||||
_, err := r.Read(header)
|
||||
if err != nil {
|
||||
@@ -31,7 +31,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
|
||||
c.Type = MessageType(header[7])
|
||||
c.MessageStreamID = uint32(header[8])<<24 | uint32(header[9])<<16 | uint32(header[10])<<8 | uint32(header[11])
|
||||
|
||||
chunkBodyLen := int(c.BodyLen)
|
||||
chunkBodyLen := c.BodyLen
|
||||
if chunkBodyLen > chunkMaxBodyLen {
|
||||
chunkBodyLen = chunkMaxBodyLen
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Chunk1 struct {
|
||||
}
|
||||
|
||||
// Read reads the chunk.
|
||||
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
|
||||
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
|
||||
header := make([]byte, 8)
|
||||
_, err := r.Read(header)
|
||||
if err != nil {
|
||||
@@ -31,7 +31,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
|
||||
c.BodyLen = uint32(header[4])<<16 | uint32(header[5])<<8 | uint32(header[6])
|
||||
c.Type = MessageType(header[7])
|
||||
|
||||
chunkBodyLen := int(c.BodyLen)
|
||||
chunkBodyLen := (c.BodyLen)
|
||||
if chunkBodyLen > chunkMaxBodyLen {
|
||||
chunkBodyLen = chunkMaxBodyLen
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ type Chunk2 struct {
|
||||
}
|
||||
|
||||
// Read reads the chunk.
|
||||
func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
|
||||
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
|
||||
header := make([]byte, 4)
|
||||
_, err := r.Read(header)
|
||||
if err != nil {
|
||||
|
||||
@@ -16,7 +16,7 @@ type Chunk3 struct {
|
||||
}
|
||||
|
||||
// Read reads the chunk.
|
||||
func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
|
||||
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
|
||||
header := make([]byte, 1)
|
||||
_, err := r.Read(header)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rtmp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"github.com/notedit/rtmp/format/flv/flvio"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
|
||||
)
|
||||
@@ -114,7 +114,7 @@ func TestReadTracks(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9121")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
br := bufio.NewReader(conn)
|
||||
bc := bytecounter.NewReadWriter(conn)
|
||||
|
||||
// C->S handshake C0
|
||||
err = handshake.C0S0{}.Write(conn)
|
||||
@@ -126,27 +126,26 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S0
|
||||
err = handshake.C0S0{}.Read(br)
|
||||
err = handshake.C0S0{}.Read(bc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S1
|
||||
s1 := handshake.C1S1{}
|
||||
err = s1.Read(br, false)
|
||||
err = s1.Read(bc, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S2
|
||||
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
|
||||
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S handshake C2
|
||||
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw := message.NewWriter(conn)
|
||||
mr := message.NewReader(br)
|
||||
mrw := message.NewReadWriter(bc)
|
||||
|
||||
// C->S connect
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"connect",
|
||||
@@ -166,14 +165,14 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C window acknowledgement size
|
||||
msg, err := mr.Read()
|
||||
msg, err := mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetWindowAckSize{
|
||||
Value: 2500000,
|
||||
}, msg)
|
||||
|
||||
// S->C set peer bandwidth
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetPeerBandwidth{
|
||||
Value: 2500000,
|
||||
@@ -181,16 +180,14 @@ func TestReadTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C set chunk size
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetChunkSize{
|
||||
Value: 65536,
|
||||
}, msg)
|
||||
|
||||
mr.SetChunkSize(65536)
|
||||
|
||||
// S->C result
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
@@ -211,15 +208,13 @@ func TestReadTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// C->S set chunk size
|
||||
err = mw.Write(&message.MsgSetChunkSize{
|
||||
err = mrw.Write(&message.MsgSetChunkSize{
|
||||
Value: 65536,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.SetChunkSize(65536)
|
||||
|
||||
// C->S releaseStream
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"releaseStream",
|
||||
@@ -231,7 +226,7 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S FCPublish
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"FCPublish",
|
||||
@@ -243,7 +238,7 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S createStream
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"createStream",
|
||||
@@ -254,7 +249,7 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C result
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
@@ -267,7 +262,7 @@ func TestReadTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// C->S publish
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 8,
|
||||
MessageStreamID: 1,
|
||||
Payload: []interface{}{
|
||||
@@ -281,7 +276,7 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C onStatus
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 5,
|
||||
@@ -301,7 +296,7 @@ func TestReadTracks(t *testing.T) {
|
||||
switch ca {
|
||||
case "standard":
|
||||
// C->S metadata
|
||||
err = mw.Write(&message.MsgDataAMF0{
|
||||
err = mrw.Write(&message.MsgDataAMF0{
|
||||
ChunkStreamID: 4,
|
||||
MessageStreamID: 1,
|
||||
Payload: []interface{}{
|
||||
@@ -341,7 +336,7 @@ func TestReadTracks(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
var n int
|
||||
codec.ToConfig(b, &n)
|
||||
err = mw.Write(&message.MsgVideo{
|
||||
err = mrw.Write(&message.MsgVideo{
|
||||
ChunkStreamID: 6,
|
||||
MessageStreamID: 1,
|
||||
IsKeyFrame: true,
|
||||
@@ -357,7 +352,7 @@ func TestReadTracks(t *testing.T) {
|
||||
ChannelCount: 2,
|
||||
}.Encode()
|
||||
require.NoError(t, err)
|
||||
err = mw.Write(&message.MsgAudio{
|
||||
err = mrw.Write(&message.MsgAudio{
|
||||
ChunkStreamID: 4,
|
||||
MessageStreamID: 1,
|
||||
Rate: flvio.SOUND_44Khz,
|
||||
@@ -370,7 +365,7 @@ func TestReadTracks(t *testing.T) {
|
||||
|
||||
case "metadata without codec id":
|
||||
// C->S metadata
|
||||
err = mw.Write(&message.MsgDataAMF0{
|
||||
err = mrw.Write(&message.MsgDataAMF0{
|
||||
ChunkStreamID: 4,
|
||||
MessageStreamID: 1,
|
||||
Payload: []interface{}{
|
||||
@@ -406,7 +401,7 @@ func TestReadTracks(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
var n int
|
||||
codec.ToConfig(b, &n)
|
||||
err = mw.Write(&message.MsgVideo{
|
||||
err = mrw.Write(&message.MsgVideo{
|
||||
ChunkStreamID: 6,
|
||||
MessageStreamID: 1,
|
||||
IsKeyFrame: true,
|
||||
@@ -428,7 +423,7 @@ func TestReadTracks(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
var n int
|
||||
codec.ToConfig(b, &n)
|
||||
err = mw.Write(&message.MsgVideo{
|
||||
err = mrw.Write(&message.MsgVideo{
|
||||
ChunkStreamID: 6,
|
||||
MessageStreamID: 1,
|
||||
IsKeyFrame: true,
|
||||
@@ -479,7 +474,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9121")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
br := bufio.NewReader(conn)
|
||||
bc := bytecounter.NewReadWriter(conn)
|
||||
|
||||
// C->S handshake C0
|
||||
err = handshake.C0S0{}.Write(conn)
|
||||
@@ -491,27 +486,26 @@ func TestWriteTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S0
|
||||
err = handshake.C0S0{}.Read(br)
|
||||
err = handshake.C0S0{}.Read(bc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S1
|
||||
s1 := handshake.C1S1{}
|
||||
err = s1.Read(br, false)
|
||||
err = s1.Read(bc, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C handshake S2
|
||||
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
|
||||
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S handshake C2
|
||||
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw := message.NewWriter(conn)
|
||||
mr := message.NewReader(br)
|
||||
mrw := message.NewReadWriter(bc)
|
||||
|
||||
// C->S connect
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"connect",
|
||||
@@ -531,14 +525,14 @@ func TestWriteTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C window acknowledgement size
|
||||
msg, err := mr.Read()
|
||||
msg, err := mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetWindowAckSize{
|
||||
Value: 2500000,
|
||||
}, msg)
|
||||
|
||||
// S->C set peer bandwidth
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetPeerBandwidth{
|
||||
Value: 2500000,
|
||||
@@ -546,16 +540,14 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C set chunk size
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgSetChunkSize{
|
||||
Value: 65536,
|
||||
}, msg)
|
||||
|
||||
mr.SetChunkSize(65536)
|
||||
|
||||
// S->C result
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
@@ -576,21 +568,19 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// C->S window acknowledgement size
|
||||
err = mw.Write(&message.MsgSetWindowAckSize{
|
||||
err = mrw.Write(&message.MsgSetWindowAckSize{
|
||||
Value: 2500000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S set chunk size
|
||||
err = mw.Write(&message.MsgSetChunkSize{
|
||||
err = mrw.Write(&message.MsgSetChunkSize{
|
||||
Value: 65536,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.SetChunkSize(65536)
|
||||
|
||||
// C->S createStream
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Payload: []interface{}{
|
||||
"createStream",
|
||||
@@ -601,7 +591,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C result
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
@@ -614,7 +604,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// C->S getStreamLength
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 8,
|
||||
Payload: []interface{}{
|
||||
"getStreamLength",
|
||||
@@ -626,7 +616,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// C->S play
|
||||
err = mw.Write(&message.MsgCommandAMF0{
|
||||
err = mrw.Write(&message.MsgCommandAMF0{
|
||||
ChunkStreamID: 8,
|
||||
Payload: []interface{}{
|
||||
"play",
|
||||
@@ -639,21 +629,21 @@ func TestWriteTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// S->C event "stream is recorded"
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgUserControlStreamIsRecorded{
|
||||
StreamID: 1,
|
||||
}, msg)
|
||||
|
||||
// S->C event "stream begin 1"
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgUserControlStreamBegin{
|
||||
StreamID: 1,
|
||||
}, msg)
|
||||
|
||||
// S->C onStatus
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 5,
|
||||
@@ -671,7 +661,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C onStatus
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 5,
|
||||
@@ -689,7 +679,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C onStatus
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 5,
|
||||
@@ -707,7 +697,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C onStatus
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgCommandAMF0{
|
||||
ChunkStreamID: 5,
|
||||
@@ -725,7 +715,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C onMetadata
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgDataAMF0{
|
||||
ChunkStreamID: 4,
|
||||
@@ -742,7 +732,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C H264 decoder config
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgVideo{
|
||||
ChunkStreamID: 6,
|
||||
@@ -760,7 +750,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}, msg)
|
||||
|
||||
// S->C AAC decoder config
|
||||
msg, err = mr.Read()
|
||||
msg, err = mrw.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &message.MsgAudio{
|
||||
ChunkStreamID: 4,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
@@ -14,7 +13,7 @@ const (
|
||||
type C0S0 struct{}
|
||||
|
||||
// Read reads a C0S0.
|
||||
func (C0S0) Read(r *bufio.Reader) error {
|
||||
func (C0S0) Read(r io.Reader) error {
|
||||
buf := make([]byte, 1)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
@@ -14,7 +13,7 @@ var c0s0dec = C0S0{}
|
||||
|
||||
func TestC0S0Read(t *testing.T) {
|
||||
var c0s0 C0S0
|
||||
err := c0s0.Read(bufio.NewReader(bytes.NewReader(c0s0enc)))
|
||||
err := c0s0.Read((bytes.NewReader(c0s0enc)))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c0s0dec, c0s0)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
@@ -79,7 +78,7 @@ type C1S1 struct {
|
||||
}
|
||||
|
||||
// Read reads a C1S1.
|
||||
func (c *C1S1) Read(r *bufio.Reader, isC1 bool) error {
|
||||
func (c *C1S1) Read(r io.Reader, isC1 bool) error {
|
||||
buf := make([]byte, 1536)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
@@ -44,7 +43,7 @@ func TestC1S1Read(t *testing.T) {
|
||||
)
|
||||
|
||||
var c1s1 C1S1
|
||||
err := c1s1.Read(bufio.NewReader(bytes.NewReader(c1s1enc)), true)
|
||||
err := c1s1.Read((bytes.NewReader(c1s1enc)), true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c1s1dec, c1s1)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
@@ -18,7 +17,7 @@ type C2S2 struct {
|
||||
}
|
||||
|
||||
// Read reads a C2S2.
|
||||
func (c *C2S2) Read(r *bufio.Reader) error {
|
||||
func (c *C2S2) Read(r io.Reader) error {
|
||||
buf := make([]byte, 1536)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
@@ -43,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
|
||||
|
||||
var c2s2 C2S2
|
||||
c2s2.Digest = c2s2dec.Digest
|
||||
err := c2s2.Read(bufio.NewReader(bytes.NewReader(c2s2enc)))
|
||||
err := c2s2.Read((bytes.NewReader(c2s2enc)))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c2s2dec, c2s2)
|
||||
}
|
||||
|
||||
40
internal/rtmp/message/msg_acknowledge.go
Normal file
40
internal/rtmp/message/msg_acknowledge.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
|
||||
)
|
||||
|
||||
// MsgAcknowledge is an acknowledgement message.
|
||||
type MsgAcknowledge struct {
|
||||
Value uint32
|
||||
}
|
||||
|
||||
// Unmarshal implements Message.
|
||||
func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
|
||||
if raw.ChunkStreamID != ControlChunkStreamID {
|
||||
return fmt.Errorf("unexpected chunk stream ID")
|
||||
}
|
||||
|
||||
if len(raw.Body) != 4 {
|
||||
return fmt.Errorf("unexpected body size")
|
||||
}
|
||||
|
||||
m.Value = binary.BigEndian.Uint32(raw.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal implements Message.
|
||||
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
|
||||
body := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(body, m.Value)
|
||||
|
||||
return &rawmessage.Message{
|
||||
ChunkStreamID: ControlChunkStreamID,
|
||||
Type: chunk.MessageTypeAcknowledge,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
|
||||
)
|
||||
@@ -14,6 +14,9 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
|
||||
case chunk.MessageTypeSetChunkSize:
|
||||
return &MsgSetChunkSize{}, nil
|
||||
|
||||
case chunk.MessageTypeAcknowledge:
|
||||
return &MsgAcknowledge{}, nil
|
||||
|
||||
case chunk.MessageTypeSetWindowAckSize:
|
||||
return &MsgSetWindowAckSize{}, nil
|
||||
|
||||
@@ -75,18 +78,13 @@ type Reader struct {
|
||||
}
|
||||
|
||||
// NewReader allocates a Reader.
|
||||
func NewReader(r *bufio.Reader) *Reader {
|
||||
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
|
||||
return &Reader{
|
||||
r: rawmessage.NewReader(r),
|
||||
r: rawmessage.NewReader(r, onAckNeeded),
|
||||
}
|
||||
}
|
||||
|
||||
// SetChunkSize sets the maximum chunk size.
|
||||
func (r *Reader) SetChunkSize(v int) {
|
||||
r.r.SetChunkSize(v)
|
||||
}
|
||||
|
||||
// Read reads a essage.
|
||||
// Read reads a Message.
|
||||
func (r *Reader) Read() (Message, error) {
|
||||
raw, err := r.r.Read()
|
||||
if err != nil {
|
||||
@@ -103,5 +101,13 @@ func (r *Reader) Read() (Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tmsg := msg.(type) {
|
||||
case *MsgSetChunkSize:
|
||||
r.r.SetChunkSize(tmsg.Value)
|
||||
|
||||
case *MsgSetWindowAckSize:
|
||||
r.r.SetWindowAckSize(tmsg.Value)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
46
internal/rtmp/message/readwriter.go
Normal file
46
internal/rtmp/message/readwriter.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
)
|
||||
|
||||
// ReadWriter is a message reader/writer.
|
||||
type ReadWriter struct {
|
||||
r *Reader
|
||||
w *Writer
|
||||
}
|
||||
|
||||
// NewReadWriter allocates a ReadWriter.
|
||||
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter {
|
||||
w := NewWriter(bc.Writer)
|
||||
|
||||
r := NewReader(bc.Reader, func(count uint32) error {
|
||||
return w.Write(&MsgAcknowledge{
|
||||
Value: (count),
|
||||
})
|
||||
})
|
||||
|
||||
return &ReadWriter{
|
||||
r: r,
|
||||
w: w,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads a message.
|
||||
func (rw *ReadWriter) Read() (Message, error) {
|
||||
msg, err := rw.r.Read()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tmsg, ok := msg.(*MsgAcknowledge); ok {
|
||||
rw.w.SetAcknowledgeValue(tmsg.Value)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Write writes a message.
|
||||
func (rw *ReadWriter) Write(msg Message) error {
|
||||
return rw.w.Write(msg)
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
|
||||
)
|
||||
|
||||
@@ -12,23 +11,36 @@ type Writer struct {
|
||||
}
|
||||
|
||||
// NewWriter allocates a Writer.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
func NewWriter(w *bytecounter.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: rawmessage.NewWriter(w),
|
||||
}
|
||||
}
|
||||
|
||||
// SetChunkSize sets the maximum chunk size.
|
||||
func (mw *Writer) SetChunkSize(v int) {
|
||||
mw.w.SetChunkSize(v)
|
||||
// SetAcknowledgeValue sets the value of the last received acknowledge.
|
||||
func (w *Writer) SetAcknowledgeValue(v uint32) {
|
||||
w.w.SetAcknowledgeValue(v)
|
||||
}
|
||||
|
||||
// Write writes a message.
|
||||
func (mw *Writer) Write(msg Message) error {
|
||||
func (w *Writer) Write(msg Message) error {
|
||||
raw, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mw.w.Write(raw)
|
||||
err = w.w.Write(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch tmsg := msg.(type) {
|
||||
case *MsgSetChunkSize:
|
||||
w.w.SetChunkSize(tmsg.Value)
|
||||
|
||||
case *MsgSetWindowAckSize:
|
||||
w.w.SetWindowAckSize(tmsg.Value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package rawmessage
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,32 @@ type readerChunkStream struct {
|
||||
curTimestampDelta *uint32
|
||||
}
|
||||
|
||||
func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) error {
|
||||
err := c.Read(rc.mr.r, chunkBodySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if an ack is needed
|
||||
if rc.mr.ackWindowSize != 0 {
|
||||
count := rc.mr.r.Count()
|
||||
diff := count - rc.mr.lastAckCount
|
||||
// TODO: handle overflow
|
||||
|
||||
if diff > (rc.mr.ackWindowSize) {
|
||||
err := rc.mr.onAckNeeded(count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rc.mr.lastAckCount += (rc.mr.ackWindowSize)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
|
||||
switch typ {
|
||||
case 0:
|
||||
if rc.curBody != nil {
|
||||
@@ -28,7 +53,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
}
|
||||
|
||||
var c0 chunk.Chunk0
|
||||
err := c0.Read(rc.mr.r, rc.mr.chunkSize)
|
||||
err := rc.readChunk(&c0, rc.mr.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -65,7 +90,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
}
|
||||
|
||||
var c1 chunk.Chunk1
|
||||
err := c1.Read(rc.mr.r, rc.mr.chunkSize)
|
||||
err := rc.readChunk(&c1, rc.mr.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -100,13 +125,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
|
||||
}
|
||||
|
||||
chunkBodyLen := int(*rc.curBodyLen)
|
||||
chunkBodyLen := (*rc.curBodyLen)
|
||||
if chunkBodyLen > rc.mr.chunkSize {
|
||||
chunkBodyLen = rc.mr.chunkSize
|
||||
}
|
||||
|
||||
var c2 chunk.Chunk2
|
||||
err := c2.Read(rc.mr.r, chunkBodyLen)
|
||||
err := rc.readChunk(&c2, chunkBodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,7 +141,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
v2 := c2.TimestampDelta
|
||||
rc.curTimestampDelta = &v2
|
||||
|
||||
if chunkBodyLen != len(c2.Body) {
|
||||
if chunkBodyLen != uint32(len(c2.Body)) {
|
||||
rc.curBody = &c2.Body
|
||||
return nil, errMoreChunksNeeded
|
||||
}
|
||||
@@ -134,13 +159,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
}
|
||||
|
||||
if rc.curBody != nil {
|
||||
chunkBodyLen := int(*rc.curBodyLen) - len(*rc.curBody)
|
||||
chunkBodyLen := (*rc.curBodyLen) - uint32(len(*rc.curBody))
|
||||
if chunkBodyLen > rc.mr.chunkSize {
|
||||
chunkBodyLen = rc.mr.chunkSize
|
||||
}
|
||||
|
||||
var c3 chunk.Chunk3
|
||||
err := c3.Read(rc.mr.r, chunkBodyLen)
|
||||
err := rc.readChunk(&c3, chunkBodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,13 +187,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
chunkBodyLen := int(*rc.curBodyLen)
|
||||
chunkBodyLen := (*rc.curBodyLen)
|
||||
if chunkBodyLen > rc.mr.chunkSize {
|
||||
chunkBodyLen = rc.mr.chunkSize
|
||||
}
|
||||
|
||||
var c3 chunk.Chunk3
|
||||
err := c3.Read(rc.mr.r, chunkBodyLen)
|
||||
err := rc.readChunk(&c3, chunkBodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -187,25 +212,35 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
|
||||
|
||||
// Reader is a raw message reader.
|
||||
type Reader struct {
|
||||
r *bufio.Reader
|
||||
chunkSize int
|
||||
r *bytecounter.Reader
|
||||
onAckNeeded func(uint32) error
|
||||
|
||||
chunkSize uint32
|
||||
ackWindowSize uint32
|
||||
lastAckCount uint32
|
||||
chunkStreams map[byte]*readerChunkStream
|
||||
}
|
||||
|
||||
// NewReader allocates a Reader.
|
||||
func NewReader(r *bufio.Reader) *Reader {
|
||||
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
onAckNeeded: onAckNeeded,
|
||||
chunkSize: 128,
|
||||
chunkStreams: make(map[byte]*readerChunkStream),
|
||||
}
|
||||
}
|
||||
|
||||
// SetChunkSize sets the maximum chunk size.
|
||||
func (r *Reader) SetChunkSize(v int) {
|
||||
func (r *Reader) SetChunkSize(v uint32) {
|
||||
r.chunkSize = v
|
||||
}
|
||||
|
||||
// SetWindowAckSize sets the window acknowledgement size.
|
||||
func (r *Reader) SetWindowAckSize(v uint32) {
|
||||
r.ackWindowSize = v
|
||||
}
|
||||
|
||||
// Read reads a Message.
|
||||
func (r *Reader) Read() (*Message, error) {
|
||||
for {
|
||||
@@ -225,7 +260,7 @@ func (r *Reader) Read() (*Message, error) {
|
||||
|
||||
r.r.UnreadByte()
|
||||
|
||||
msg, err := rc.read(typ)
|
||||
msg, err := rc.readMessage(typ)
|
||||
if err != nil {
|
||||
if err == errMoreChunksNeeded {
|
||||
continue
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package rawmessage
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
)
|
||||
|
||||
type writableChunk interface {
|
||||
@@ -21,7 +22,10 @@ type sequenceEntry struct {
|
||||
func TestReader(t *testing.T) {
|
||||
testSequence := func(t *testing.T, seq []sequenceEntry) {
|
||||
var buf bytes.Buffer
|
||||
r := NewReader(bufio.NewReader(&buf))
|
||||
bcr := bytecounter.NewReader(&buf)
|
||||
r := NewReader(bcr, func(count uint32) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
for _, entry := range seq {
|
||||
buf2, err := entry.chunk.Write()
|
||||
@@ -122,7 +126,10 @@ func TestReader(t *testing.T) {
|
||||
|
||||
t.Run("chunk0 + chunk3", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
r := NewReader(bufio.NewReader(&buf))
|
||||
bcr := bytecounter.NewReader(&buf)
|
||||
r := NewReader(bcr, func(count uint32) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
buf2, err := chunk.Chunk0{
|
||||
ChunkStreamID: 27,
|
||||
@@ -153,3 +160,36 @@ func TestReader(t *testing.T) {
|
||||
}, msg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReaderAcknowledge(t *testing.T) {
|
||||
onAckCalled := make(chan struct{})
|
||||
|
||||
var buf bytes.Buffer
|
||||
bcr := bytecounter.NewReader(&buf)
|
||||
r := NewReader(bcr, func(count uint32) error {
|
||||
close(onAckCalled)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.SetWindowAckSize(100)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
buf2, err := chunk.Chunk0{
|
||||
ChunkStreamID: 27,
|
||||
Timestamp: 18576,
|
||||
Type: chunk.MessageTypeSetPeerBandwidth,
|
||||
MessageStreamID: 3123,
|
||||
BodyLen: 64,
|
||||
Body: bytes.Repeat([]byte{0x03}, 64),
|
||||
}.Write()
|
||||
require.NoError(t, err)
|
||||
buf.Write(buf2)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := r.Read()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
<-onAckCalled
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package rawmessage
|
||||
|
||||
import (
|
||||
"io"
|
||||
"fmt"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
)
|
||||
|
||||
@@ -10,14 +11,38 @@ type writerChunkStream struct {
|
||||
mw *Writer
|
||||
lastMessageStreamID *uint32
|
||||
lastType *chunk.MessageType
|
||||
lastBodyLen *int
|
||||
lastBodyLen *uint32
|
||||
lastTimestamp *uint32
|
||||
lastTimestampDelta *uint32
|
||||
}
|
||||
|
||||
func (wc *writerChunkStream) write(msg *Message) error {
|
||||
bodyLen := len(msg.Body)
|
||||
pos := 0
|
||||
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
|
||||
buf, err := c.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if we received an acknowledge
|
||||
if wc.mw.ackWindowSize != 0 {
|
||||
diff := wc.mw.w.Count() - (wc.mw.ackValue)
|
||||
// TODO: handle overflow
|
||||
|
||||
if diff > (wc.mw.ackWindowSize * 3 / 2) {
|
||||
return fmt.Errorf("no acknowledge received within window")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wc *writerChunkStream) writeMessage(msg *Message) error {
|
||||
bodyLen := uint32(len(msg.Body))
|
||||
pos := uint32(0)
|
||||
firstChunk := true
|
||||
|
||||
var timestampDelta *uint32
|
||||
@@ -42,65 +67,45 @@ func (wc *writerChunkStream) write(msg *Message) error {
|
||||
|
||||
switch {
|
||||
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
|
||||
buf, err := chunk.Chunk0{
|
||||
err := wc.writeChunk(&chunk.Chunk0{
|
||||
ChunkStreamID: msg.ChunkStreamID,
|
||||
Timestamp: msg.Timestamp,
|
||||
Type: msg.Type,
|
||||
MessageStreamID: msg.MessageStreamID,
|
||||
BodyLen: uint32(bodyLen),
|
||||
BodyLen: (bodyLen),
|
||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||
}.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
|
||||
buf, err := chunk.Chunk1{
|
||||
err := wc.writeChunk(&chunk.Chunk1{
|
||||
ChunkStreamID: msg.ChunkStreamID,
|
||||
TimestampDelta: *timestampDelta,
|
||||
Type: msg.Type,
|
||||
BodyLen: uint32(bodyLen),
|
||||
BodyLen: (bodyLen),
|
||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||
}.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
|
||||
buf, err := chunk.Chunk2{
|
||||
err := wc.writeChunk(&chunk.Chunk2{
|
||||
ChunkStreamID: msg.ChunkStreamID,
|
||||
TimestampDelta: *timestampDelta,
|
||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||
}.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
buf, err := chunk.Chunk3{
|
||||
err := wc.writeChunk(&chunk.Chunk3{
|
||||
ChunkStreamID: msg.ChunkStreamID,
|
||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||
}.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -120,15 +125,10 @@ func (wc *writerChunkStream) write(msg *Message) error {
|
||||
wc.lastTimestampDelta = &v5
|
||||
}
|
||||
} else {
|
||||
buf, err := chunk.Chunk3{
|
||||
err := wc.writeChunk(&chunk.Chunk3{
|
||||
ChunkStreamID: msg.ChunkStreamID,
|
||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||
}.Write()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = wc.mw.w.Write(buf)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -144,13 +144,15 @@ func (wc *writerChunkStream) write(msg *Message) error {
|
||||
|
||||
// Writer is a raw message writer.
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
chunkSize int
|
||||
w *bytecounter.Writer
|
||||
chunkSize uint32
|
||||
ackWindowSize uint32
|
||||
ackValue uint32
|
||||
chunkStreams map[byte]*writerChunkStream
|
||||
}
|
||||
|
||||
// NewWriter allocates a Writer.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
func NewWriter(w *bytecounter.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
chunkSize: 128,
|
||||
@@ -159,10 +161,20 @@ func NewWriter(w io.Writer) *Writer {
|
||||
}
|
||||
|
||||
// SetChunkSize sets the maximum chunk size.
|
||||
func (w *Writer) SetChunkSize(v int) {
|
||||
func (w *Writer) SetChunkSize(v uint32) {
|
||||
w.chunkSize = v
|
||||
}
|
||||
|
||||
// SetWindowAckSize sets the window acknowledgement size.
|
||||
func (w *Writer) SetWindowAckSize(v uint32) {
|
||||
w.ackWindowSize = v
|
||||
}
|
||||
|
||||
// SetAcknowledgeValue sets the acknowledge sequence number.
|
||||
func (w *Writer) SetAcknowledgeValue(v uint32) {
|
||||
w.ackValue = v
|
||||
}
|
||||
|
||||
// Write writes a Message.
|
||||
func (w *Writer) Write(msg *Message) error {
|
||||
wc, ok := w.chunkStreams[msg.ChunkStreamID]
|
||||
@@ -171,5 +183,5 @@ func (w *Writer) Write(msg *Message) error {
|
||||
w.chunkStreams[msg.ChunkStreamID] = wc
|
||||
}
|
||||
|
||||
return wc.write(msg)
|
||||
return wc.writeMessage(msg)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package rawmessage
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
|
||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -12,8 +12,7 @@ import (
|
||||
func TestWriter(t *testing.T) {
|
||||
t.Run("chunk0 + chunk1", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
br := bufio.NewReader(&buf)
|
||||
w := NewWriter(&buf)
|
||||
w := NewWriter(bytecounter.NewWriter(&buf))
|
||||
|
||||
err := w.Write(&Message{
|
||||
ChunkStreamID: 27,
|
||||
@@ -25,7 +24,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c0 chunk.Chunk0
|
||||
err = c0.Read(br, 128)
|
||||
err = c0.Read(&buf, 128)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk0{
|
||||
ChunkStreamID: 27,
|
||||
@@ -46,7 +45,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c1 chunk.Chunk1
|
||||
err = c1.Read(br, 128)
|
||||
err = c1.Read(&buf, 128)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk1{
|
||||
ChunkStreamID: 27,
|
||||
@@ -59,8 +58,7 @@ func TestWriter(t *testing.T) {
|
||||
|
||||
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
br := bufio.NewReader(&buf)
|
||||
w := NewWriter(&buf)
|
||||
w := NewWriter(bytecounter.NewWriter(&buf))
|
||||
|
||||
err := w.Write(&Message{
|
||||
ChunkStreamID: 27,
|
||||
@@ -72,7 +70,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c0 chunk.Chunk0
|
||||
err = c0.Read(br, 128)
|
||||
err = c0.Read(&buf, 128)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk0{
|
||||
ChunkStreamID: 27,
|
||||
@@ -93,7 +91,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c2 chunk.Chunk2
|
||||
err = c2.Read(br, 64)
|
||||
err = c2.Read(&buf, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk2{
|
||||
ChunkStreamID: 27,
|
||||
@@ -111,7 +109,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c3 chunk.Chunk3
|
||||
err = c3.Read(br, 64)
|
||||
err = c3.Read(&buf, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk3{
|
||||
ChunkStreamID: 27,
|
||||
@@ -121,8 +119,7 @@ func TestWriter(t *testing.T) {
|
||||
|
||||
t.Run("chunk0 + chunk3", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
br := bufio.NewReader(&buf)
|
||||
w := NewWriter(&buf)
|
||||
w := NewWriter(bytecounter.NewWriter(&buf))
|
||||
|
||||
err := w.Write(&Message{
|
||||
ChunkStreamID: 27,
|
||||
@@ -134,7 +131,7 @@ func TestWriter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var c0 chunk.Chunk0
|
||||
err = c0.Read(br, 128)
|
||||
err = c0.Read(&buf, 128)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk0{
|
||||
ChunkStreamID: 27,
|
||||
@@ -146,7 +143,7 @@ func TestWriter(t *testing.T) {
|
||||
}, c0)
|
||||
|
||||
var c3 chunk.Chunk3
|
||||
err = c3.Read(br, 64)
|
||||
err = c3.Read(&buf, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chunk.Chunk3{
|
||||
ChunkStreamID: 27,
|
||||
@@ -154,3 +151,30 @@ func TestWriter(t *testing.T) {
|
||||
}, c3)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriterAcknowledge(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(bytecounter.NewWriter(&buf))
|
||||
|
||||
w.SetWindowAckSize(100)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
err := w.Write(&Message{
|
||||
ChunkStreamID: 27,
|
||||
Timestamp: 18576,
|
||||
Type: chunk.MessageTypeSetPeerBandwidth,
|
||||
MessageStreamID: 3123,
|
||||
Body: bytes.Repeat([]byte{0x03}, 64),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := w.Write(&Message{
|
||||
ChunkStreamID: 27,
|
||||
Timestamp: 18576,
|
||||
Type: chunk.MessageTypeSetPeerBandwidth,
|
||||
MessageStreamID: 3123,
|
||||
Body: bytes.Repeat([]byte{0x03}, 64),
|
||||
})
|
||||
require.EqualError(t, err, "no acknowledge received within window")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user