allow writing primitives to static buffers

This commit is contained in:
aler9
2022-05-11 14:52:20 +02:00
parent ee6d7a87a3
commit c1b10a80be
19 changed files with 662 additions and 786 deletions

View File

@@ -9,7 +9,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
@@ -1059,11 +1058,10 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
c.OnRequest(req) c.OnRequest(req)
} }
var buf bytes.Buffer byts, _ := req.Write()
req.Write(&buf)
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
_, err := c.conn.Write(buf.Bytes()) _, err := c.conn.Write(byts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1877,25 +1875,23 @@ func (c *Client) runWriter() {
rtcpFrames[trackID] = &base.InterleavedFrame{Channel: cct.tcpChannel + 1} rtcpFrames[trackID] = &base.InterleavedFrame{Channel: cct.tcpChannel + 1}
} }
var buf bytes.Buffer buf := make([]byte, maxPacketSize+4)
writeFunc = func(trackID int, isRTP bool, payload []byte) { writeFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP { if isRTP {
f := rtpFrames[trackID] f := rtpFrames[trackID]
f.Payload = payload f.Payload = payload
buf.Reset() n, _ := f.WriteTo(buf)
f.Write(&buf)
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf.Bytes()) c.conn.Write(buf[:n])
} else { } else {
f := rtcpFrames[trackID] f := rtcpFrames[trackID]
f.Payload = payload f.Payload = payload
buf.Reset() n, _ := f.WriteTo(buf)
f.Write(&buf)
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf.Bytes()) c.conn.Write(buf[:n])
} }
} }
} }

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"strings" "strings"
@@ -83,15 +82,13 @@ func TestClientPublishSerial(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -100,8 +97,8 @@ func TestClientPublishSerial(t *testing.T) {
string(base.Record), string(base.Record),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -109,11 +106,10 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -153,14 +149,13 @@ func TestClientPublishSerial(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs th.InterleavedIDs = inTH.InterleavedIDs
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -168,11 +163,10 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
// client -> server (RTP) // client -> server (RTP)
@@ -202,12 +196,11 @@ func TestClientPublishSerial(t *testing.T) {
Port: th.ClientPorts[1], Port: th.ClientPorts[1],
}) })
} else { } else {
bb.Reset() byts, _ := base.InterleavedFrame{
base.InterleavedFrame{
Channel: 1, Channel: 1,
Payload: testRTCPPacketMarshaled, Payload: testRTCPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -216,11 +209,10 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -303,14 +295,12 @@ func TestClientPublishParallel(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -319,19 +309,18 @@ func TestClientPublishParallel(t *testing.T) {
string(base.Record), string(base.Record),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -358,36 +347,33 @@ func TestClientPublishParallel(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs th.InterleavedIDs = inTH.InterleavedIDs
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequestIgnoreFrames(br) req, err = readRequestIgnoreFrames(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -454,14 +440,12 @@ func TestClientPublishPauseSerial(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -471,19 +455,18 @@ func TestClientPublishPauseSerial(t *testing.T) {
string(base.Pause), string(base.Pause),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -510,58 +493,53 @@ func TestClientPublishPauseSerial(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs th.InterleavedIDs = inTH.InterleavedIDs
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequestIgnoreFrames(br) req, err = readRequestIgnoreFrames(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Pause, req.Method) require.Equal(t, base.Pause, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequestIgnoreFrames(br) req, err = readRequestIgnoreFrames(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -618,14 +596,12 @@ func TestClientPublishPauseParallel(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -635,19 +611,18 @@ func TestClientPublishPauseParallel(t *testing.T) {
string(base.Pause), string(base.Pause),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -674,36 +649,33 @@ func TestClientPublishPauseParallel(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs th.InterleavedIDs = inTH.InterleavedIDs
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequestIgnoreFrames(br) req, err = readRequestIgnoreFrames(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Pause, req.Method) require.Equal(t, base.Pause, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -765,15 +737,13 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -782,8 +752,8 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
string(base.Record), string(base.Record),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -791,22 +761,20 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -827,14 +795,13 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
InterleavedIDs: &[2]int{0, 1}, InterleavedIDs: &[2]int{0, 1},
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -842,11 +809,10 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
var f base.InterleavedFrame var f base.InterleavedFrame
@@ -862,11 +828,10 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -900,14 +865,12 @@ func TestClientPublishRTCPReport(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -916,19 +879,18 @@ func TestClientPublishRTCPReport(t *testing.T) {
string(base.Record), string(base.Record),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -947,8 +909,7 @@ func TestClientPublishRTCPReport(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer l2.Close() defer l2.Close()
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": headers.Transport{ "Transport": headers.Transport{
@@ -961,19 +922,18 @@ func TestClientPublishRTCPReport(t *testing.T) {
ServerPorts: &[2]int{34556, 34557}, ServerPorts: &[2]int{34556, 34557},
}.Write(), }.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
buf := make([]byte, 2048) buf := make([]byte, 2048)
@@ -1004,11 +964,10 @@ func TestClientPublishRTCPReport(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -1054,14 +1013,12 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -1070,19 +1027,18 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
string(base.Record), string(base.Record),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Announce, req.Method) require.Equal(t, base.Announce, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -1102,52 +1058,47 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
InterleavedIDs: inTH.InterleavedIDs, InterleavedIDs: inTH.InterleavedIDs,
} }
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Write(), "Transport": th.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Record, req.Method) require.Equal(t, base.Record, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
bb.Reset() byts, _ = base.InterleavedFrame{
base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
bb.Reset() byts, _ = base.InterleavedFrame{
base.InterleavedFrame{
Channel: 1, Channel: 1,
Payload: testRTCPPacketMarshaled, Payload: testRTCPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"strings" "strings"
@@ -94,15 +93,13 @@ func TestClientSession(t *testing.T) {
conn, err := l.Accept() conn, err := l.Accept()
require.NoError(t, err) require.NoError(t, err)
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
defer conn.Close() defer conn.Close()
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
@@ -110,8 +107,8 @@ func TestClientSession(t *testing.T) {
}, ", ")}, }, ", ")},
"Session": base.HeaderValue{"123456"}, "Session": base.HeaderValue{"123456"},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -126,16 +123,15 @@ func TestClientSession(t *testing.T) {
tracks := Tracks{track} tracks := Tracks{track}
tracks.setControls() tracks.setControls()
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"}, "Content-Type": base.HeaderValue{"application/sdp"},
"Session": base.HeaderValue{"123456"}, "Session": base.HeaderValue{"123456"},
}, },
Body: tracks.Write(false), Body: tracks.Write(false),
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -165,23 +161,21 @@ func TestClientAuth(t *testing.T) {
conn, err := l.Accept() conn, err := l.Accept()
require.NoError(t, err) require.NoError(t, err)
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
defer conn.Close() defer conn.Close()
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
string(base.Describe), string(base.Describe),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -190,14 +184,13 @@ func TestClientAuth(t *testing.T) {
v := auth.NewValidator("myuser", "mypass", nil) v := auth.NewValidator("myuser", "mypass", nil)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,
Header: base.Header{ Header: base.Header{
"WWW-Authenticate": v.Header(), "WWW-Authenticate": v.Header(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -213,15 +206,14 @@ func TestClientAuth(t *testing.T) {
tracks := Tracks{track} tracks := Tracks{track}
tracks.setControls() tracks.setControls()
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"}, "Content-Type": base.HeaderValue{"application/sdp"},
}, },
Body: tracks.Write(false), Body: tracks.Write(false),
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()
@@ -252,22 +244,20 @@ func TestClientDescribeCharset(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
req, err := readRequest(br) req, err := readRequest(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
bb.Reset() byts, _ := base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
string(base.Describe), string(base.Describe),
}, ", ")}, }, ", ")},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
req, err = readRequest(br) req, err = readRequest(br)
@@ -278,16 +268,15 @@ func TestClientDescribeCharset(t *testing.T) {
track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err) require.NoError(t, err)
bb.Reset() byts, _ = base.Response{
base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp; charset=utf-8"}, "Content-Type": base.HeaderValue{"application/sdp; charset=utf-8"},
"Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"},
}, },
Body: Tracks{track1}.Write(false), Body: Tracks{track1}.Write(false),
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
}() }()

View File

@@ -35,11 +35,16 @@ func (b *body) read(header Header, rb *bufio.Reader) error {
return nil return nil
} }
func (b body) write(w io.Writer) error { func (b body) writeSize() int {
if len(b) == 0 { return len(b)
return nil }
}
func (b body) writeTo(buf []byte) int {
_, err := w.Write(b) return copy(buf, b)
return err }
func (b body) write() []byte {
buf := make([]byte, b.writeSize())
b.writeTo(buf)
return buf
} }

View File

@@ -20,11 +20,6 @@ var casesBody = []struct {
}, },
[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04},
}, },
{
"nil",
Header{},
[]byte(nil),
},
} }
func TestBodyRead(t *testing.T) { func TestBodyRead(t *testing.T) {
@@ -81,9 +76,8 @@ func TestBodyReadErrors(t *testing.T) {
func TestBodyWrite(t *testing.T) { func TestBodyWrite(t *testing.T) {
for _, ca := range casesBody { for _, ca := range casesBody {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer buf := body(ca.byts).write()
body(ca.byts).write(&buf) require.Equal(t, ca.byts, buf)
require.Equal(t, ca.byts, buf.Bytes())
}) })
} }
} }

View File

@@ -3,7 +3,6 @@ package base
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"io"
"net/http" "net/http"
"sort" "sort"
"strings" "strings"
@@ -98,7 +97,7 @@ func (h *Header) read(rb *bufio.Reader) error {
return nil return nil
} }
func (h Header) write(w io.Writer) error { func (h Header) writeSize() int {
// sort headers by key // sort headers by key
// in order to obtain deterministic results // in order to obtain deterministic results
keys := make([]string, len(h)) keys := make([]string, len(h))
@@ -107,15 +106,43 @@ func (h Header) write(w io.Writer) error {
} }
sort.Strings(keys) sort.Strings(keys)
n := 0
for _, key := range keys { for _, key := range keys {
for _, val := range h[key] { for _, val := range h[key] {
_, err := w.Write([]byte(key + ": " + val + "\r\n")) n += len([]byte(key + ": " + val + "\r\n"))
if err != nil {
return err
}
} }
} }
_, err := w.Write([]byte("\r\n")) n += 2
return err
return n
}
func (h Header) writeTo(buf []byte) int {
// sort headers by key
// in order to obtain deterministic results
keys := make([]string, len(h))
for key := range h {
keys = append(keys, key)
}
sort.Strings(keys)
pos := 0
for _, key := range keys {
for _, val := range h[key] {
pos += copy(buf[pos:], []byte(key+": "+val+"\r\n"))
}
}
pos += copy(buf[pos:], []byte("\r\n"))
return pos
}
func (h Header) write() []byte {
buf := make([]byte, h.writeSize())
h.writeTo(buf)
return buf
} }

View File

@@ -176,9 +176,8 @@ func TestHeaderReadErrors(t *testing.T) {
func TestHeaderWrite(t *testing.T) { func TestHeaderWrite(t *testing.T) {
for _, ca := range casesHeader { for _, ca := range casesHeader {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer buf := ca.header.write()
ca.header.write(&buf) require.Equal(t, ca.enc, buf)
require.Equal(t, ca.enc, buf.Bytes())
}) })
} }
} }

View File

@@ -105,16 +105,28 @@ func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error {
return nil return nil
} }
// Write writes an InterleavedFrame into a buffered writer. // WriteSize returns the size of an InterleavedFrame.
func (f InterleavedFrame) Write(w io.Writer) error { func (f InterleavedFrame) WriteSize() int {
buf := []byte{0x24, byte(f.Channel), 0x00, 0x00} return 4 + len(f.Payload)
binary.BigEndian.PutUint16(buf[2:], uint16(len(f.Payload))) }
_, err := w.Write(buf) // WriteTo writes an InterleavedFrame.
if err != nil { func (f InterleavedFrame) WriteTo(buf []byte) (int, error) {
return err pos := 0
}
pos += copy(buf[pos:], []byte{0x24, byte(f.Channel)})
_, err = w.Write(f.Payload)
return err binary.BigEndian.PutUint16(buf[pos:], uint16(len(f.Payload)))
pos += 2
pos += copy(buf[pos:], f.Payload)
return pos, nil
}
// Write writes an InterleavedFrame.
func (f InterleavedFrame) Write() ([]byte, error) {
buf := make([]byte, f.WriteSize())
_, err := f.WriteTo(buf)
return buf, err
} }

View File

@@ -82,9 +82,9 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
func TestInterleavedFrameWrite(t *testing.T) { func TestInterleavedFrameWrite(t *testing.T) {
for _, ca := range casesInterleavedFrame { for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer buf, err := ca.dec.Write()
ca.dec.Write(&buf) require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes()) require.Equal(t, ca.enc, buf)
}) })
} }
} }

View File

@@ -3,9 +3,7 @@ package base
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io"
"strconv" "strconv"
) )
@@ -117,29 +115,51 @@ func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error
} }
} }
// Write writes a request. // WriteSize returns the size of a Request.
func (req Request) Write(w io.Writer) error { func (req Request) WriteSize() int {
n := 0
urStr := req.URL.CloneWithoutCredentials().String() urStr := req.URL.CloneWithoutCredentials().String()
_, err := w.Write([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n")) n += len([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n"))
if err != nil {
return err
}
if len(req.Body) != 0 { if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)} req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}
} }
err = req.Header.write(w) n += req.Header.writeSize()
if err != nil {
return err n += body(req.Body).writeSize()
return n
}
// WriteTo writes a Request.
func (req Request) WriteTo(buf []byte) (int, error) {
pos := 0
urStr := req.URL.CloneWithoutCredentials().String()
pos += copy(buf[pos:], []byte(string(req.Method)+" "+urStr+" "+rtspProtocol10+"\r\n"))
if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}
} }
return body(req.Body).write(w) pos += req.Header.writeTo(buf[pos:])
pos += body(req.Body).writeTo(buf[pos:])
return pos, nil
}
// Write writes a Request.
func (req Request) Write() ([]byte, error) {
buf := make([]byte, req.WriteSize())
_, err := req.WriteTo(buf)
return buf, err
} }
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (req Request) String() string { func (req Request) String() string {
var buf bytes.Buffer buf, _ := req.Write()
req.Write(&buf) return string(buf)
return buf.String()
} }

View File

@@ -221,9 +221,9 @@ func TestRequestReadErrors(t *testing.T) {
func TestRequestWrite(t *testing.T) { func TestRequestWrite(t *testing.T) {
for _, ca := range casesRequest { for _, ca := range casesRequest {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer buf, err := ca.req.Write()
ca.req.Write(&buf) require.NoError(t, err)
require.Equal(t, ca.byts, buf.Bytes()) require.Equal(t, ca.byts, buf)
}) })
} }
} }

View File

@@ -2,9 +2,7 @@ package base
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io"
"strconv" "strconv"
) )
@@ -203,36 +201,65 @@ func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) erro
} }
} }
// Write writes a Response. // WriteSize returns the size of a Response.
func (res Response) Write(w io.Writer) error { func (res Response) WriteSize() int {
n := 0
if res.StatusMessage == "" { if res.StatusMessage == "" {
if status, ok := statusMessages[res.StatusCode]; ok { if status, ok := statusMessages[res.StatusCode]; ok {
res.StatusMessage = status res.StatusMessage = status
} }
} }
_, err := w.Write([]byte(rtspProtocol10 + " " + n += len([]byte(rtspProtocol10 + " " +
strconv.FormatInt(int64(res.StatusCode), 10) + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " +
res.StatusMessage + "\r\n")) res.StatusMessage + "\r\n"))
if err != nil {
return err
}
if len(res.Body) != 0 { if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)} res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}
} }
err = res.Header.write(w) n += res.Header.writeSize()
if err != nil {
return err n += body(res.Body).writeSize()
return n
}
// WriteTo writes a Response.
func (res Response) WriteTo(buf []byte) (int, error) {
if res.StatusMessage == "" {
if status, ok := statusMessages[res.StatusCode]; ok {
res.StatusMessage = status
}
} }
return body(res.Body).write(w) pos := 0
pos += copy(buf[pos:], []byte(rtspProtocol10+" "+
strconv.FormatInt(int64(res.StatusCode), 10)+" "+
res.StatusMessage+"\r\n"))
if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}
}
pos += res.Header.writeTo(buf[pos:])
pos += body(res.Body).writeTo(buf[pos:])
return pos, nil
}
// Write writes a Response.
func (res Response) Write() ([]byte, error) {
buf := make([]byte, res.WriteSize())
_, err := res.WriteTo(buf)
return buf, err
} }
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (res Response) String() string { func (res Response) String() string {
var buf bytes.Buffer buf, _ := res.Write()
res.Write(&buf) return string(buf)
return buf.String()
} }

View File

@@ -178,9 +178,9 @@ func TestResponseReadErrors(t *testing.T) {
func TestResponseWrite(t *testing.T) { func TestResponseWrite(t *testing.T) {
for _, c := range casesResponse { for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
var buf bytes.Buffer buf, err := c.res.Write()
c.res.Write(&buf) require.NoError(t, err)
require.Equal(t, c.byts, buf.Bytes()) require.Equal(t, c.byts, buf)
}) })
} }
} }
@@ -207,9 +207,9 @@ func TestResponseWriteAutoFillStatus(t *testing.T) {
"\r\n", "\r\n",
) )
var buf bytes.Buffer buf, err := res.Write()
res.Write(&buf) require.NoError(t, err)
require.Equal(t, byts, buf.Bytes()) require.Equal(t, byts, buf)
} }
func TestResponseReadIgnoreFrames(t *testing.T) { func TestResponseReadIgnoreFrames(t *testing.T) {

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"testing" "testing"
@@ -644,7 +643,6 @@ func TestServerPublish(t *testing.T) {
return conn return conn
}() }()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
<-connOpened <-connOpened
@@ -765,20 +763,18 @@ func TestServerPublish(t *testing.T) {
Port: th.ServerPorts[1], Port: th.ServerPorts[1],
}) })
} else { } else {
bb.Reset() byts, _ := base.InterleavedFrame{
base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
bb.Reset() byts, _ = base.InterleavedFrame{
base.InterleavedFrame{
Channel: 1, Channel: 1,
Payload: testRTCPPacketMarshaled, Payload: testRTCPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -857,7 +853,6 @@ func TestServerPublishOversizedPacket(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err) require.NoError(t, err)
@@ -917,12 +912,11 @@ func TestServerPublishOversizedPacket(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := oversizedPacketRTPIn.Marshal() byts, _ := oversizedPacketRTPIn.Marshal()
bb.Reset() byts, _ = base.InterleavedFrame{
base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: byts, Payload: byts,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
<-packetRecv <-packetRecv
@@ -963,7 +957,6 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err) require.NoError(t, err)
@@ -1026,12 +1019,11 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
bb.Reset() byts, _ := base.InterleavedFrame{
base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"strconv" "strconv"
@@ -359,7 +358,6 @@ func TestServerRead(t *testing.T) {
return conn return conn
}() }()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
<-connOpened <-connOpened
@@ -547,12 +545,11 @@ func TestServerRead(t *testing.T) {
<-framesReceived <-framesReceived
default: default:
bb.Reset() byts, _ := base.InterleavedFrame{
base.InterleavedFrame{
Channel: 5, Channel: 5,
Payload: testRTCPPacketMarshaled, Payload: testRTCPPacketMarshaled,
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
<-framesReceived <-framesReceived
} }
@@ -1128,7 +1125,6 @@ func TestServerReadPlayPausePause(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
var bb bytes.Buffer
res, err := writeReqReadRes(conn, br, base.Request{ res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Setup, Method: base.Setup,
@@ -1167,32 +1163,30 @@ func TestServerReadPlayPausePause(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
bb.Reset() byts, _ := base.Request{
base.Request{
Method: base.Pause, Method: base.Pause,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session}, "Session": base.HeaderValue{sx.Session},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
res, err = readResIgnoreFrames(br) res, err = readResIgnoreFrames(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
bb.Reset() byts, _ = base.Request{
base.Request{
Method: base.Pause, Method: base.Pause,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session}, "Session": base.HeaderValue{sx.Session},
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
res, err = readResIgnoreFrames(br) res, err = readResIgnoreFrames(br)

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@@ -24,9 +23,8 @@ func writeReqReadRes(conn net.Conn,
br *bufio.Reader, br *bufio.Reader,
req base.Request, req base.Request,
) (*base.Response, error) { ) (*base.Response, error) {
var bb bytes.Buffer byts, _ := req.Write()
req.Write(&bb) _, err := conn.Write(byts)
_, err := conn.Write(bb.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1026,8 +1024,7 @@ func TestServerSessionClose(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
var bb bytes.Buffer byts, _ := base.Request{
base.Request{
Method: base.Setup, Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{ Header: base.Header{
@@ -1045,8 +1042,8 @@ func TestServerSessionClose(t *testing.T) {
InterleavedIDs: &[2]int{0, 1}, InterleavedIDs: &[2]int{0, 1},
}.Write(), }.Write(),
}, },
}.Write(&bb) }.Write()
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(byts)
require.NoError(t, err) require.NoError(t, err)
<-sessionClosed <-sessionClosed

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@@ -589,11 +588,10 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
h.OnResponse(sc, res) h.OnResponse(sc, res)
} }
var buf bytes.Buffer byts, _ := res.Write()
res.Write(&buf)
sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
sc.conn.Write(buf.Bytes()) sc.conn.Write(byts)
return err return err
} }

View File

@@ -1,7 +1,6 @@
package gortsplib package gortsplib
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net" "net"
@@ -1182,25 +1181,23 @@ func (ss *ServerSession) runWriter() {
rtcpFrames[trackID] = &base.InterleavedFrame{Channel: sst.tcpChannel + 1} rtcpFrames[trackID] = &base.InterleavedFrame{Channel: sst.tcpChannel + 1}
} }
var buf bytes.Buffer buf := make([]byte, maxPacketSize+4)
writeFunc = func(trackID int, isRTP bool, payload []byte) { writeFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP { if isRTP {
f := rtpFrames[trackID] f := rtpFrames[trackID]
f.Payload = payload f.Payload = payload
buf.Reset() n, _ := f.WriteTo(buf)
f.Write(&buf)
ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.Write(buf.Bytes()) ss.tcpConn.conn.Write(buf[:n])
} else { } else {
f := rtcpFrames[trackID] f := rtcpFrames[trackID]
f.Payload = payload f.Payload = payload
buf.Reset() n, _ := f.WriteTo(buf)
f.Write(&buf)
ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.Write(buf.Bytes()) ss.tcpConn.conn.Write(buf[:n])
} }
} }
} }