mirror of
https://github.com/aler9/rtsp-simple-server
synced 2025-10-18 21:44:42 +08:00
rtmp: make chunk writes atomic
This commit is contained in:
@@ -42,25 +42,20 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the chunk.
|
// Write writes the chunk.
|
||||||
func (c Chunk0) Write(w io.Writer) error {
|
func (c Chunk0) Write() ([]byte, error) {
|
||||||
header := make([]byte, 12)
|
buf := make([]byte, 12+len(c.Body))
|
||||||
header[0] = c.ChunkStreamID
|
buf[0] = c.ChunkStreamID
|
||||||
header[1] = byte(c.Timestamp >> 16)
|
buf[1] = byte(c.Timestamp >> 16)
|
||||||
header[2] = byte(c.Timestamp >> 8)
|
buf[2] = byte(c.Timestamp >> 8)
|
||||||
header[3] = byte(c.Timestamp)
|
buf[3] = byte(c.Timestamp)
|
||||||
header[4] = byte(c.BodyLen >> 16)
|
buf[4] = byte(c.BodyLen >> 16)
|
||||||
header[5] = byte(c.BodyLen >> 8)
|
buf[5] = byte(c.BodyLen >> 8)
|
||||||
header[6] = byte(c.BodyLen)
|
buf[6] = byte(c.BodyLen)
|
||||||
header[7] = byte(c.Type)
|
buf[7] = byte(c.Type)
|
||||||
header[8] = byte(c.MessageStreamID >> 24)
|
buf[8] = byte(c.MessageStreamID >> 24)
|
||||||
header[9] = byte(c.MessageStreamID >> 16)
|
buf[9] = byte(c.MessageStreamID >> 16)
|
||||||
header[10] = byte(c.MessageStreamID >> 8)
|
buf[10] = byte(c.MessageStreamID >> 8)
|
||||||
header[11] = byte(c.MessageStreamID)
|
buf[11] = byte(c.MessageStreamID)
|
||||||
_, err := w.Write(header)
|
copy(buf[12:], c.Body)
|
||||||
if err != nil {
|
return buf, nil
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.Write(c.Body)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
@@ -29,8 +29,7 @@ func TestChunk0Read(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChunk0Write(t *testing.T) {
|
func TestChunk0Write(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
buf, err := chunk0dec.Write()
|
||||||
err := chunk0dec.Write(&buf)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, chunk0enc, buf.Bytes())
|
require.Equal(t, chunk0enc, buf)
|
||||||
}
|
}
|
||||||
|
@@ -42,21 +42,16 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the chunk.
|
// Write writes the chunk.
|
||||||
func (c Chunk1) Write(w io.Writer) error {
|
func (c Chunk1) Write() ([]byte, error) {
|
||||||
header := make([]byte, 8)
|
buf := make([]byte, 8+len(c.Body))
|
||||||
header[0] = 1<<6 | c.ChunkStreamID
|
buf[0] = 1<<6 | c.ChunkStreamID
|
||||||
header[1] = byte(c.TimestampDelta >> 16)
|
buf[1] = byte(c.TimestampDelta >> 16)
|
||||||
header[2] = byte(c.TimestampDelta >> 8)
|
buf[2] = byte(c.TimestampDelta >> 8)
|
||||||
header[3] = byte(c.TimestampDelta)
|
buf[3] = byte(c.TimestampDelta)
|
||||||
header[4] = byte(c.BodyLen >> 16)
|
buf[4] = byte(c.BodyLen >> 16)
|
||||||
header[5] = byte(c.BodyLen >> 8)
|
buf[5] = byte(c.BodyLen >> 8)
|
||||||
header[6] = byte(c.BodyLen)
|
buf[6] = byte(c.BodyLen)
|
||||||
header[7] = byte(c.Type)
|
buf[7] = byte(c.Type)
|
||||||
_, err := w.Write(header)
|
copy(buf[8:], c.Body)
|
||||||
if err != nil {
|
return buf, nil
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.Write(c.Body)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
@@ -28,8 +28,7 @@ func TestChunk1Read(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChunk1Write(t *testing.T) {
|
func TestChunk1Write(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
buf, err := chunk1dec.Write()
|
||||||
err := chunk1dec.Write(&buf)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, chunk1enc, buf.Bytes())
|
require.Equal(t, chunk1enc, buf)
|
||||||
}
|
}
|
||||||
|
@@ -31,17 +31,12 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the chunk.
|
// Write writes the chunk.
|
||||||
func (c Chunk2) Write(w io.Writer) error {
|
func (c Chunk2) Write() ([]byte, error) {
|
||||||
header := make([]byte, 4)
|
buf := make([]byte, 4+len(c.Body))
|
||||||
header[0] = 2<<6 | c.ChunkStreamID
|
buf[0] = 2<<6 | c.ChunkStreamID
|
||||||
header[1] = byte(c.TimestampDelta >> 16)
|
buf[1] = byte(c.TimestampDelta >> 16)
|
||||||
header[2] = byte(c.TimestampDelta >> 8)
|
buf[2] = byte(c.TimestampDelta >> 8)
|
||||||
header[3] = byte(c.TimestampDelta)
|
buf[3] = byte(c.TimestampDelta)
|
||||||
_, err := w.Write(header)
|
copy(buf[4:], c.Body)
|
||||||
if err != nil {
|
return buf, nil
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.Write(c.Body)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
@@ -25,8 +25,7 @@ func TestChunk2Read(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChunk2Write(t *testing.T) {
|
func TestChunk2Write(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
buf, err := chunk2dec.Write()
|
||||||
err := chunk2dec.Write(&buf)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, chunk2enc, buf.Bytes())
|
require.Equal(t, chunk2enc, buf)
|
||||||
}
|
}
|
||||||
|
@@ -31,14 +31,9 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the chunk.
|
// Write writes the chunk.
|
||||||
func (c Chunk3) Write(w io.Writer) error {
|
func (c Chunk3) Write() ([]byte, error) {
|
||||||
header := make([]byte, 1)
|
buf := make([]byte, 1+len(c.Body))
|
||||||
header[0] = 3<<6 | c.ChunkStreamID
|
buf[0] = 3<<6 | c.ChunkStreamID
|
||||||
_, err := w.Write(header)
|
copy(buf[1:], c.Body)
|
||||||
if err != nil {
|
return buf, nil
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.Write(c.Body)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
@@ -24,8 +24,7 @@ func TestChunk3Read(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChunk3Write(t *testing.T) {
|
func TestChunk3Write(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
buf, err := chunk3dec.Write()
|
||||||
err := chunk3dec.Write(&buf)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, chunk3enc, buf.Bytes())
|
require.Equal(t, chunk3enc, buf)
|
||||||
}
|
}
|
||||||
|
@@ -3,7 +3,6 @@ package rawmessage
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
|
||||||
@@ -11,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type writableChunk interface {
|
type writableChunk interface {
|
||||||
Write(w io.Writer) error
|
Write() ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type sequenceEntry struct {
|
type sequenceEntry struct {
|
||||||
@@ -25,8 +24,9 @@ func TestReader(t *testing.T) {
|
|||||||
r := NewReader(bufio.NewReader(&buf))
|
r := NewReader(bufio.NewReader(&buf))
|
||||||
|
|
||||||
for _, entry := range seq {
|
for _, entry := range seq {
|
||||||
err := entry.chunk.Write(&buf)
|
buf2, err := entry.chunk.Write()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
buf.Write(buf2)
|
||||||
msg, err := r.Read()
|
msg, err := r.Read()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, entry.msg, msg)
|
require.Equal(t, entry.msg, msg)
|
||||||
@@ -124,21 +124,23 @@ func TestReader(t *testing.T) {
|
|||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
r := NewReader(bufio.NewReader(&buf))
|
r := NewReader(bufio.NewReader(&buf))
|
||||||
|
|
||||||
err := chunk.Chunk0{
|
buf2, err := chunk.Chunk0{
|
||||||
ChunkStreamID: 27,
|
ChunkStreamID: 27,
|
||||||
Timestamp: 18576,
|
Timestamp: 18576,
|
||||||
Type: chunk.MessageTypeSetPeerBandwidth,
|
Type: chunk.MessageTypeSetPeerBandwidth,
|
||||||
MessageStreamID: 3123,
|
MessageStreamID: 3123,
|
||||||
BodyLen: 192,
|
BodyLen: 192,
|
||||||
Body: bytes.Repeat([]byte{0x03}, 128),
|
Body: bytes.Repeat([]byte{0x03}, 128),
|
||||||
}.Write(&buf)
|
}.Write()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
buf.Write(buf2)
|
||||||
|
|
||||||
err = chunk.Chunk3{
|
buf2, err = chunk.Chunk3{
|
||||||
ChunkStreamID: 27,
|
ChunkStreamID: 27,
|
||||||
Body: bytes.Repeat([]byte{0x03}, 64),
|
Body: bytes.Repeat([]byte{0x03}, 64),
|
||||||
}.Write(&buf)
|
}.Write()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
buf.Write(buf2)
|
||||||
|
|
||||||
msg, err := r.Read()
|
msg, err := r.Read()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@@ -42,45 +42,65 @@ func (wc *writerChunkStream) write(msg *Message) error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
|
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
|
||||||
err := chunk.Chunk0{
|
buf, err := chunk.Chunk0{
|
||||||
ChunkStreamID: msg.ChunkStreamID,
|
ChunkStreamID: msg.ChunkStreamID,
|
||||||
Timestamp: msg.Timestamp,
|
Timestamp: msg.Timestamp,
|
||||||
Type: msg.Type,
|
Type: msg.Type,
|
||||||
MessageStreamID: msg.MessageStreamID,
|
MessageStreamID: msg.MessageStreamID,
|
||||||
BodyLen: uint32(bodyLen),
|
BodyLen: uint32(bodyLen),
|
||||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||||
}.Write(wc.mw.w)
|
}.Write()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = wc.mw.w.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
|
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
|
||||||
err := chunk.Chunk1{
|
buf, err := chunk.Chunk1{
|
||||||
ChunkStreamID: msg.ChunkStreamID,
|
ChunkStreamID: msg.ChunkStreamID,
|
||||||
TimestampDelta: *timestampDelta,
|
TimestampDelta: *timestampDelta,
|
||||||
Type: msg.Type,
|
Type: msg.Type,
|
||||||
BodyLen: uint32(bodyLen),
|
BodyLen: uint32(bodyLen),
|
||||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||||
}.Write(wc.mw.w)
|
}.Write()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = wc.mw.w.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
|
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
|
||||||
err := chunk.Chunk2{
|
buf, err := chunk.Chunk2{
|
||||||
ChunkStreamID: msg.ChunkStreamID,
|
ChunkStreamID: msg.ChunkStreamID,
|
||||||
TimestampDelta: *timestampDelta,
|
TimestampDelta: *timestampDelta,
|
||||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||||
}.Write(wc.mw.w)
|
}.Write()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = wc.mw.w.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
err := chunk.Chunk3{
|
buf, err := chunk.Chunk3{
|
||||||
ChunkStreamID: msg.ChunkStreamID,
|
ChunkStreamID: msg.ChunkStreamID,
|
||||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||||
}.Write(wc.mw.w)
|
}.Write()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = wc.mw.w.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -100,10 +120,15 @@ func (wc *writerChunkStream) write(msg *Message) error {
|
|||||||
wc.lastTimestampDelta = &v5
|
wc.lastTimestampDelta = &v5
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := chunk.Chunk3{
|
buf, err := chunk.Chunk3{
|
||||||
ChunkStreamID: msg.ChunkStreamID,
|
ChunkStreamID: msg.ChunkStreamID,
|
||||||
Body: msg.Body[pos : pos+chunkBodyLen],
|
Body: msg.Body[pos : pos+chunkBodyLen],
|
||||||
}.Write(wc.mw.w)
|
}.Write()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = wc.mw.w.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user