rtmp: implement acknowledge mechanism

This commit is contained in:
aler9
2022-06-08 20:47:36 +02:00
parent ee2908081e
commit 2601ca5661
23 changed files with 473 additions and 177 deletions

View File

@@ -0,0 +1,37 @@
package bytecounter
import (
"bufio"
"io"
)
type readerInner struct {
r io.Reader
count uint32
}
func (r *readerInner) Read(p []byte) (int, error) {
n, err := r.r.Read(p)
r.count += uint32(n)
return n, err
}
// Reader allows to count read bytes.
type Reader struct {
ri *readerInner
*bufio.Reader
}
// NewReader allocates a Reader.
func NewReader(r io.Reader) *Reader {
ri := &readerInner{r: r}
return &Reader{
ri: ri,
Reader: bufio.NewReader(ri),
}
}
// Count returns read bytes.
func (r Reader) Count() uint32 {
return r.ri.count
}

View File

@@ -0,0 +1,19 @@
package bytecounter
import (
"io"
)
// ReadWriter allows to count read and written bytes.
type ReadWriter struct {
*Reader
*Writer
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(rw io.ReadWriter) *ReadWriter {
return &ReadWriter{
Reader: NewReader(rw),
Writer: NewWriter(rw),
}
}

View File

@@ -0,0 +1,30 @@
package bytecounter
import (
"io"
)
// Writer allows to count written bytes.
type Writer struct {
w io.Writer
count uint32
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
}
}
// Write implements io.Writer.
func (w *Writer) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.count += uint32(n)
return n, err
}
// Count returns written bytes.
func (w Writer) Count() uint32 {
return w.count
}

View File

@@ -0,0 +1,11 @@
package chunk
import (
"io"
)
// Chunk is a chunk.
type Chunk interface {
Read(io.Reader, uint32) error
Write() ([]byte, error)
}

View File

@@ -18,7 +18,7 @@ type Chunk0 struct {
}
// Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 12)
_, err := r.Read(header)
if err != nil {
@@ -31,7 +31,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
c.Type = MessageType(header[7])
c.MessageStreamID = uint32(header[8])<<24 | uint32(header[9])<<16 | uint32(header[10])<<8 | uint32(header[11])
chunkBodyLen := int(c.BodyLen)
chunkBodyLen := c.BodyLen
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
}

View File

@@ -19,7 +19,7 @@ type Chunk1 struct {
}
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 8)
_, err := r.Read(header)
if err != nil {
@@ -31,7 +31,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
c.BodyLen = uint32(header[4])<<16 | uint32(header[5])<<8 | uint32(header[6])
c.Type = MessageType(header[7])
chunkBodyLen := int(c.BodyLen)
chunkBodyLen := (c.BodyLen)
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
}

View File

@@ -15,7 +15,7 @@ type Chunk2 struct {
}
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 4)
_, err := r.Read(header)
if err != nil {

View File

@@ -16,7 +16,7 @@ type Chunk3 struct {
}
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1)
_, err := r.Read(header)
if err != nil {

View File

@@ -1,7 +1,6 @@
package rtmp
import (
"bufio"
"net"
"net/url"
"strings"
@@ -13,6 +12,7 @@ import (
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
@@ -114,7 +114,7 @@ func TestReadTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(conn)
@@ -126,27 +126,26 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(br)
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(br, false)
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)
mr := message.NewReader(br)
mrw := message.NewReadWriter(bc)
// C->S connect
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
@@ -166,14 +165,14 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
msg, err := mr.Read()
msg, err := mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetWindowAckSize{
Value: 2500000,
}, msg)
// S->C set peer bandwidth
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 2500000,
@@ -181,16 +180,14 @@ func TestReadTracks(t *testing.T) {
}, msg)
// S->C set chunk size
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetChunkSize{
Value: 65536,
}, msg)
mr.SetChunkSize(65536)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@@ -211,15 +208,13 @@ func TestReadTracks(t *testing.T) {
}, msg)
// C->S set chunk size
err = mw.Write(&message.MsgSetChunkSize{
err = mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
require.NoError(t, err)
mw.SetChunkSize(65536)
// C->S releaseStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
@@ -231,7 +226,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S FCPublish
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
@@ -243,7 +238,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S createStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
@@ -254,7 +249,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@@ -267,7 +262,7 @@ func TestReadTracks(t *testing.T) {
}, msg)
// C->S publish
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 1,
Payload: []interface{}{
@@ -281,7 +276,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@@ -301,7 +296,7 @@ func TestReadTracks(t *testing.T) {
switch ca {
case "standard":
// C->S metadata
err = mw.Write(&message.MsgDataAMF0{
err = mrw.Write(&message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 1,
Payload: []interface{}{
@@ -341,7 +336,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@@ -357,7 +352,7 @@ func TestReadTracks(t *testing.T) {
ChannelCount: 2,
}.Encode()
require.NoError(t, err)
err = mw.Write(&message.MsgAudio{
err = mrw.Write(&message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 1,
Rate: flvio.SOUND_44Khz,
@@ -370,7 +365,7 @@ func TestReadTracks(t *testing.T) {
case "metadata without codec id":
// C->S metadata
err = mw.Write(&message.MsgDataAMF0{
err = mrw.Write(&message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 1,
Payload: []interface{}{
@@ -406,7 +401,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@@ -428,7 +423,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@@ -479,7 +474,7 @@ func TestWriteTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(conn)
@@ -491,27 +486,26 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(br)
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(br, false)
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)
mr := message.NewReader(br)
mrw := message.NewReadWriter(bc)
// C->S connect
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
@@ -531,14 +525,14 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
msg, err := mr.Read()
msg, err := mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetWindowAckSize{
Value: 2500000,
}, msg)
// S->C set peer bandwidth
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 2500000,
@@ -546,16 +540,14 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C set chunk size
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetChunkSize{
Value: 65536,
}, msg)
mr.SetChunkSize(65536)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@@ -576,21 +568,19 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// C->S window acknowledgement size
err = mw.Write(&message.MsgSetWindowAckSize{
err = mrw.Write(&message.MsgSetWindowAckSize{
Value: 2500000,
})
require.NoError(t, err)
// C->S set chunk size
err = mw.Write(&message.MsgSetChunkSize{
err = mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
require.NoError(t, err)
mw.SetChunkSize(65536)
// C->S createStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
@@ -601,7 +591,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@@ -614,7 +604,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// C->S getStreamLength
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"getStreamLength",
@@ -626,7 +616,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// C->S play
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"play",
@@ -639,21 +629,21 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C event "stream is recorded"
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgUserControlStreamIsRecorded{
StreamID: 1,
}, msg)
// S->C event "stream begin 1"
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgUserControlStreamBegin{
StreamID: 1,
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@@ -671,7 +661,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@@ -689,7 +679,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@@ -707,7 +697,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@@ -725,7 +715,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onMetadata
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgDataAMF0{
ChunkStreamID: 4,
@@ -742,7 +732,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C H264 decoder config
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgVideo{
ChunkStreamID: 6,
@@ -760,7 +750,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C AAC decoder config
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgAudio{
ChunkStreamID: 4,

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"fmt"
"io"
)
@@ -14,7 +13,7 @@ const (
type C0S0 struct{}
// Read reads a C0S0.
func (C0S0) Read(r *bufio.Reader) error {
func (C0S0) Read(r io.Reader) error {
buf := make([]byte, 1)
_, err := io.ReadFull(r, buf)
if err != nil {

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@@ -14,7 +13,7 @@ var c0s0dec = C0S0{}
func TestC0S0Read(t *testing.T) {
var c0s0 C0S0
err := c0s0.Read(bufio.NewReader(bytes.NewReader(c0s0enc)))
err := c0s0.Read((bytes.NewReader(c0s0enc)))
require.NoError(t, err)
require.Equal(t, c0s0dec, c0s0)
}

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
@@ -79,7 +78,7 @@ type C1S1 struct {
}
// Read reads a C1S1.
func (c *C1S1) Read(r *bufio.Reader, isC1 bool) error {
func (c *C1S1) Read(r io.Reader, isC1 bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@@ -44,7 +43,7 @@ func TestC1S1Read(t *testing.T) {
)
var c1s1 C1S1
err := c1s1.Read(bufio.NewReader(bytes.NewReader(c1s1enc)), true)
err := c1s1.Read((bytes.NewReader(c1s1enc)), true)
require.NoError(t, err)
require.Equal(t, c1s1dec, c1s1)
}

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/binary"
@@ -18,7 +17,7 @@ type C2S2 struct {
}
// Read reads a C2S2.
func (c *C2S2) Read(r *bufio.Reader) error {
func (c *C2S2) Read(r io.Reader) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@@ -43,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read(bufio.NewReader(bytes.NewReader(c2s2enc)))
err := c2s2.Read((bytes.NewReader(c2s2enc)))
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}

View File

@@ -0,0 +1,40 @@
package message
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
// MsgAcknowledge is an acknowledgement message.
type MsgAcknowledge struct {
Value uint32
}
// Unmarshal implements Message.
func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
if raw.ChunkStreamID != ControlChunkStreamID {
return fmt.Errorf("unexpected chunk stream ID")
}
if len(raw.Body) != 4 {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
return nil
}
// Marshal implements Message.
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeAcknowledge,
Body: body,
}, nil
}

View File

@@ -1,10 +1,10 @@
package message
import (
"bufio"
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
@@ -14,6 +14,9 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
case chunk.MessageTypeSetChunkSize:
return &MsgSetChunkSize{}, nil
case chunk.MessageTypeAcknowledge:
return &MsgAcknowledge{}, nil
case chunk.MessageTypeSetWindowAckSize:
return &MsgSetWindowAckSize{}, nil
@@ -75,18 +78,13 @@ type Reader struct {
}
// NewReader allocates a Reader.
func NewReader(r *bufio.Reader) *Reader {
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
return &Reader{
r: rawmessage.NewReader(r),
r: rawmessage.NewReader(r, onAckNeeded),
}
}
// SetChunkSize sets the maximum chunk size.
func (r *Reader) SetChunkSize(v int) {
r.r.SetChunkSize(v)
}
// Read reads a essage.
// Read reads a Message.
func (r *Reader) Read() (Message, error) {
raw, err := r.r.Read()
if err != nil {
@@ -103,5 +101,13 @@ func (r *Reader) Read() (Message, error) {
return nil, err
}
switch tmsg := msg.(type) {
case *MsgSetChunkSize:
r.r.SetChunkSize(tmsg.Value)
case *MsgSetWindowAckSize:
r.r.SetWindowAckSize(tmsg.Value)
}
return msg, nil
}

View File

@@ -0,0 +1,46 @@
package message
import (
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
// ReadWriter is a message reader/writer.
type ReadWriter struct {
r *Reader
w *Writer
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter {
w := NewWriter(bc.Writer)
r := NewReader(bc.Reader, func(count uint32) error {
return w.Write(&MsgAcknowledge{
Value: (count),
})
})
return &ReadWriter{
r: r,
w: w,
}
}
// Read reads a message.
func (rw *ReadWriter) Read() (Message, error) {
msg, err := rw.r.Read()
if err != nil {
return nil, err
}
if tmsg, ok := msg.(*MsgAcknowledge); ok {
rw.w.SetAcknowledgeValue(tmsg.Value)
}
return msg, nil
}
// Write writes a message.
func (rw *ReadWriter) Write(msg Message) error {
return rw.w.Write(msg)
}

View File

@@ -1,8 +1,7 @@
package message
import (
"io"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
@@ -12,23 +11,36 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
func NewWriter(w *bytecounter.Writer) *Writer {
return &Writer{
w: rawmessage.NewWriter(w),
}
}
// SetChunkSize sets the maximum chunk size.
func (mw *Writer) SetChunkSize(v int) {
mw.w.SetChunkSize(v)
// SetAcknowledgeValue sets the value of the last received acknowledge.
func (w *Writer) SetAcknowledgeValue(v uint32) {
w.w.SetAcknowledgeValue(v)
}
// Write writes a message.
func (mw *Writer) Write(msg Message) error {
func (w *Writer) Write(msg Message) error {
raw, err := msg.Marshal()
if err != nil {
return err
}
return mw.w.Write(raw)
err = w.w.Write(raw)
if err != nil {
return err
}
switch tmsg := msg.(type) {
case *MsgSetChunkSize:
w.w.SetChunkSize(tmsg.Value)
case *MsgSetWindowAckSize:
w.w.SetWindowAckSize(tmsg.Value)
}
return nil
}

View File

@@ -1,10 +1,10 @@
package rawmessage
import (
"bufio"
"errors"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
@@ -20,7 +20,32 @@ type readerChunkStream struct {
curTimestampDelta *uint32
}
func (rc *readerChunkStream) read(typ byte) (*Message, error) {
func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) error {
err := c.Read(rc.mr.r, chunkBodySize)
if err != nil {
return err
}
// check if an ack is needed
if rc.mr.ackWindowSize != 0 {
count := rc.mr.r.Count()
diff := count - rc.mr.lastAckCount
// TODO: handle overflow
if diff > (rc.mr.ackWindowSize) {
err := rc.mr.onAckNeeded(count)
if err != nil {
return err
}
rc.mr.lastAckCount += (rc.mr.ackWindowSize)
}
}
return nil
}
func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
switch typ {
case 0:
if rc.curBody != nil {
@@ -28,7 +53,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
var c0 chunk.Chunk0
err := c0.Read(rc.mr.r, rc.mr.chunkSize)
err := rc.readChunk(&c0, rc.mr.chunkSize)
if err != nil {
return nil, err
}
@@ -65,7 +90,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
var c1 chunk.Chunk1
err := c1.Read(rc.mr.r, rc.mr.chunkSize)
err := rc.readChunk(&c1, rc.mr.chunkSize)
if err != nil {
return nil, err
}
@@ -100,13 +125,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
}
chunkBodyLen := int(*rc.curBodyLen)
chunkBodyLen := (*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c2 chunk.Chunk2
err := c2.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c2, chunkBodyLen)
if err != nil {
return nil, err
}
@@ -116,7 +141,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
v2 := c2.TimestampDelta
rc.curTimestampDelta = &v2
if chunkBodyLen != len(c2.Body) {
if chunkBodyLen != uint32(len(c2.Body)) {
rc.curBody = &c2.Body
return nil, errMoreChunksNeeded
}
@@ -134,13 +159,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
if rc.curBody != nil {
chunkBodyLen := int(*rc.curBodyLen) - len(*rc.curBody)
chunkBodyLen := (*rc.curBodyLen) - uint32(len(*rc.curBody))
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c3 chunk.Chunk3
err := c3.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c3, chunkBodyLen)
if err != nil {
return nil, err
}
@@ -162,13 +187,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}, nil
}
chunkBodyLen := int(*rc.curBodyLen)
chunkBodyLen := (*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c3 chunk.Chunk3
err := c3.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c3, chunkBodyLen)
if err != nil {
return nil, err
}
@@ -187,25 +212,35 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
// Reader is a raw message reader.
type Reader struct {
r *bufio.Reader
chunkSize int
r *bytecounter.Reader
onAckNeeded func(uint32) error
chunkSize uint32
ackWindowSize uint32
lastAckCount uint32
chunkStreams map[byte]*readerChunkStream
}
// NewReader allocates a Reader.
func NewReader(r *bufio.Reader) *Reader {
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
return &Reader{
r: r,
onAckNeeded: onAckNeeded,
chunkSize: 128,
chunkStreams: make(map[byte]*readerChunkStream),
}
}
// SetChunkSize sets the maximum chunk size.
func (r *Reader) SetChunkSize(v int) {
func (r *Reader) SetChunkSize(v uint32) {
r.chunkSize = v
}
// SetWindowAckSize sets the window acknowledgement size.
func (r *Reader) SetWindowAckSize(v uint32) {
r.ackWindowSize = v
}
// Read reads a Message.
func (r *Reader) Read() (*Message, error) {
for {
@@ -225,7 +260,7 @@ func (r *Reader) Read() (*Message, error) {
r.r.UnreadByte()
msg, err := rc.read(typ)
msg, err := rc.readMessage(typ)
if err != nil {
if err == errMoreChunksNeeded {
continue

View File

@@ -1,12 +1,13 @@
package rawmessage
import (
"bufio"
"bytes"
"testing"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
type writableChunk interface {
@@ -21,7 +22,10 @@ type sequenceEntry struct {
func TestReader(t *testing.T) {
testSequence := func(t *testing.T, seq []sequenceEntry) {
var buf bytes.Buffer
r := NewReader(bufio.NewReader(&buf))
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range seq {
buf2, err := entry.chunk.Write()
@@ -122,7 +126,10 @@ func TestReader(t *testing.T) {
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
r := NewReader(bufio.NewReader(&buf))
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
@@ -153,3 +160,36 @@ func TestReader(t *testing.T) {
}, msg)
})
}
func TestReaderAcknowledge(t *testing.T) {
onAckCalled := make(chan struct{})
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
close(onAckCalled)
return nil
})
r.SetWindowAckSize(100)
for i := 0; i < 2; i++ {
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}.Write()
require.NoError(t, err)
buf.Write(buf2)
}
for i := 0; i < 2; i++ {
_, err := r.Read()
require.NoError(t, err)
}
<-onAckCalled
}

View File

@@ -1,8 +1,9 @@
package rawmessage
import (
"io"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
@@ -10,14 +11,38 @@ type writerChunkStream struct {
mw *Writer
lastMessageStreamID *uint32
lastType *chunk.MessageType
lastBodyLen *int
lastBodyLen *uint32
lastTimestamp *uint32
lastTimestampDelta *uint32
}
func (wc *writerChunkStream) write(msg *Message) error {
bodyLen := len(msg.Body)
pos := 0
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
buf, err := c.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil {
return err
}
// check if we received an acknowledge
if wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - (wc.mw.ackValue)
// TODO: handle overflow
if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window")
}
}
return nil
}
func (wc *writerChunkStream) writeMessage(msg *Message) error {
bodyLen := uint32(len(msg.Body))
pos := uint32(0)
firstChunk := true
var timestampDelta *uint32
@@ -42,65 +67,45 @@ func (wc *writerChunkStream) write(msg *Message) error {
switch {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
buf, err := chunk.Chunk0{
err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp,
Type: msg.Type,
MessageStreamID: msg.MessageStreamID,
BodyLen: uint32(bodyLen),
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
buf, err := chunk.Chunk1{
err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
Type: msg.Type,
BodyLen: uint32(bodyLen),
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
buf, err := chunk.Chunk2{
err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
default:
buf, err := chunk.Chunk3{
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
@@ -120,15 +125,10 @@ func (wc *writerChunkStream) write(msg *Message) error {
wc.lastTimestampDelta = &v5
}
} else {
buf, err := chunk.Chunk3{
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
@@ -144,13 +144,15 @@ func (wc *writerChunkStream) write(msg *Message) error {
// Writer is a raw message writer.
type Writer struct {
w io.Writer
chunkSize int
w *bytecounter.Writer
chunkSize uint32
ackWindowSize uint32
ackValue uint32
chunkStreams map[byte]*writerChunkStream
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
func NewWriter(w *bytecounter.Writer) *Writer {
return &Writer{
w: w,
chunkSize: 128,
@@ -159,10 +161,20 @@ func NewWriter(w io.Writer) *Writer {
}
// SetChunkSize sets the maximum chunk size.
func (w *Writer) SetChunkSize(v int) {
func (w *Writer) SetChunkSize(v uint32) {
w.chunkSize = v
}
// SetWindowAckSize sets the window acknowledgement size.
func (w *Writer) SetWindowAckSize(v uint32) {
w.ackWindowSize = v
}
// SetAcknowledgeValue sets the acknowledge sequence number.
func (w *Writer) SetAcknowledgeValue(v uint32) {
w.ackValue = v
}
// Write writes a Message.
func (w *Writer) Write(msg *Message) error {
wc, ok := w.chunkStreams[msg.ChunkStreamID]
@@ -171,5 +183,5 @@ func (w *Writer) Write(msg *Message) error {
w.chunkStreams[msg.ChunkStreamID] = wc
}
return wc.write(msg)
return wc.writeMessage(msg)
}

View File

@@ -1,10 +1,10 @@
package rawmessage
import (
"bufio"
"bytes"
"testing"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/stretchr/testify/require"
)
@@ -12,8 +12,7 @@ import (
func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk1", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@@ -25,7 +24,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@@ -46,7 +45,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c1 chunk.Chunk1
err = c1.Read(br, 128)
err = c1.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk1{
ChunkStreamID: 27,
@@ -59,8 +58,7 @@ func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@@ -72,7 +70,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@@ -93,7 +91,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c2 chunk.Chunk2
err = c2.Read(br, 64)
err = c2.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk2{
ChunkStreamID: 27,
@@ -111,7 +109,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c3 chunk.Chunk3
err = c3.Read(br, 64)
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
@@ -121,8 +119,7 @@ func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@@ -134,7 +131,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@@ -146,7 +143,7 @@ func TestWriter(t *testing.T) {
}, c0)
var c3 chunk.Chunk3
err = c3.Read(br, 64)
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
@@ -154,3 +151,30 @@ func TestWriter(t *testing.T) {
}, c3)
})
}
func TestWriterAcknowledge(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
w.SetWindowAckSize(100)
for i := 0; i < 2; i++ {
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
}
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.EqualError(t, err, "no acknowledge received within window")
}