add intermediate layer between net.Conn and client / server

This commit is contained in:
aler9
2022-08-14 23:43:01 +02:00
parent a0a168d26c
commit 06bed24dd9
18 changed files with 1459 additions and 1561 deletions

View File

@@ -8,7 +8,6 @@ Examples are available at https://github.com/aler9/gortsplib/tree/master/example
package gortsplib
import (
"bufio"
"context"
"crypto/tls"
"fmt"
@@ -24,6 +23,7 @@ import (
"github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/ringbuffer"
@@ -256,8 +256,8 @@ type Client struct {
ctx context.Context
ctxCancel func()
state clientState
conn net.Conn
br *bufio.Reader
nconn net.Conn
conn *conn.Conn
session string
sender *auth.Sender
cseq int
@@ -581,11 +581,13 @@ func (c *Client) doClose() {
URL: c.baseURL,
}, true, false)
c.conn.Close()
c.nconn.Close()
c.nconn = nil
c.conn = nil
} else if c.conn != nil {
} else if c.nconn != nil {
c.connCloserStop()
c.conn.Close()
c.nconn.Close()
c.nconn = nil
c.conn = nil
}
@@ -756,7 +758,7 @@ func (c *Client) playRecordStart() {
// for some reason, SetReadDeadline() must always be called in the same
// goroutine, otherwise Read() freezes.
// therefore, we disable the deadline and perform a check with a ticker.
c.conn.SetReadDeadline(time.Time{})
c.nconn.SetReadDeadline(time.Time{})
// start reader
c.readerErr = make(chan error)
@@ -768,7 +770,7 @@ func (c *Client) runReader() {
if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast {
for {
var res base.Response
err := res.Read(c.br)
err := c.conn.ReadResponse(&res)
if err != nil {
return err
}
@@ -854,7 +856,7 @@ func (c *Client) runReader() {
var res base.Response
for {
what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br)
what, err := c.conn.ReadInterleavedFrameOrResponse(&frame, &res)
if err != nil {
return err
}
@@ -885,7 +887,7 @@ func (c *Client) runReader() {
func (c *Client) playRecordStop(isClosing bool) {
// stop reader
if c.readerErr != nil {
c.conn.SetReadDeadline(time.Now())
c.nconn.SetReadDeadline(time.Now())
<-c.readerErr
}
@@ -963,7 +965,7 @@ func (c *Client) connOpen() error {
return err
}
c.conn = func() net.Conn {
c.nconn = func() net.Conn {
if c.scheme == "rtsps" {
tlsConfig := c.TLSConfig
@@ -979,7 +981,8 @@ func (c *Client) connOpen() error {
return nconn
}()
c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize)
c.conn = conn.NewConn(c.nconn)
c.connCloserStart()
return nil
}
@@ -993,7 +996,7 @@ func (c *Client) connCloserStart() {
select {
case <-c.ctx.Done():
c.conn.Close()
c.nconn.Close()
case <-c.connCloserTerminate:
}
@@ -1007,7 +1010,7 @@ func (c *Client) connCloserStop() {
}
func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*base.Response, error) {
if c.conn == nil {
if c.nconn == nil {
err := c.connOpen()
if err != nil {
return nil, err
@@ -1042,10 +1045,8 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
c.OnRequest(req)
}
byts, _ := req.Marshal()
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
_, err := c.conn.Write(byts)
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
err := c.conn.WriteRequest(req)
if err != nil {
return nil, err
}
@@ -1053,19 +1054,19 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
var res base.Response
if !skipResponse {
c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
if allowFrames {
// read the response and ignore interleaved frames in between;
// interleaved frames are sent in two cases:
// * when the server is v4lrtspserver, before the PLAY response
// * when the stream is already playing
err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br)
err = c.conn.ReadResponseIgnoreFrames(&res)
if err != nil {
return nil, err
}
} else {
err = res.Read(c.br)
err = c.conn.ReadResponse(&res)
if err != nil {
return nil, err
}
@@ -1491,13 +1492,13 @@ func (c *Client) doSetup(
if thRes.Source != nil {
return *thRes.Source
}
return c.conn.RemoteAddr().(*net.TCPAddr).IP
return c.nconn.RemoteAddr().(*net.TCPAddr).IP
}()
if thRes.ServerPorts != nil {
ct.udpRTPListener.readPort = thRes.ServerPorts[0]
ct.udpRTPListener.writeAddr = &net.UDPAddr{
IP: c.conn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone,
IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[0],
}
}
@@ -1506,13 +1507,13 @@ func (c *Client) doSetup(
if thRes.Source != nil {
return *thRes.Source
}
return c.conn.RemoteAddr().(*net.TCPAddr).IP
return c.nconn.RemoteAddr().(*net.TCPAddr).IP
}()
if thRes.ServerPorts != nil {
ct.udpRTCPListener.readPort = thRes.ServerPorts[1]
ct.udpRTCPListener.writeAddr = &net.UDPAddr{
IP: c.conn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone,
IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[1],
}
}
@@ -1551,14 +1552,14 @@ func (c *Client) doSetup(
return nil, err
}
ct.udpRTPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTPListener.readPort = thRes.Ports[0]
ct.udpRTPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,
Port: thRes.Ports[0],
}
ct.udpRTCPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTCPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTCPListener.readPort = thRes.Ports[1]
ct.udpRTCPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,
@@ -1848,19 +1849,17 @@ func (c *Client) runWriter() {
writeFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP {
f := rtpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtpFrames[trackID]
fr.Payload = payload
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf[:n])
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.WriteInterleavedFrame(fr, buf)
} else {
f := rtcpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtcpFrames[trackID]
fr.Payload = payload
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf[:n])
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.WriteInterleavedFrame(fr, buf)
}
}
}

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
"strings"
@@ -13,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
)
@@ -78,17 +78,17 @@ func TestClientPublishSerial(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -97,22 +97,20 @@ func TestClientPublishSerial(t *testing.T) {
string(base.Record),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream/trackID=0"), req.URL)
@@ -149,24 +147,22 @@ func TestClientPublishSerial(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
// client -> server (RTP)
@@ -180,7 +176,7 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, testRTPPacket, pkt)
} else {
var f base.InterleavedFrame
err = f.Read(1024, br)
err = conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
var pkt rtp.Packet
@@ -196,23 +192,21 @@ func TestClientPublishSerial(t *testing.T) {
Port: th.ClientPorts[1],
})
} else {
byts, _ := base.InterleavedFrame{
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
}
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -294,16 +288,16 @@ func TestClientPublishParallel(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -312,21 +306,19 @@ func TestClientPublishParallel(t *testing.T) {
string(base.Record),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -350,33 +342,30 @@ func TestClientPublishParallel(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequestIgnoreFrames(br)
req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -442,16 +431,16 @@ func TestClientPublishPauseSerial(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -461,21 +450,19 @@ func TestClientPublishPauseSerial(t *testing.T) {
string(base.Pause),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -499,53 +486,48 @@ func TestClientPublishPauseSerial(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequestIgnoreFrames(br)
req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.Pause, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequestIgnoreFrames(br)
req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -601,16 +583,16 @@ func TestClientPublishPauseParallel(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -620,21 +602,19 @@ func TestClientPublishPauseParallel(t *testing.T) {
string(base.Pause),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -658,33 +638,30 @@ func TestClientPublishPauseParallel(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequestIgnoreFrames(br)
req, err = readRequestIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.Pause, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -745,17 +722,17 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -764,32 +741,29 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
string(base.Record),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusUnsupportedTransport,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -807,28 +781,26 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
var f base.InterleavedFrame
err = f.Read(2048, br)
err = conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
var pkt rtp.Packet
@@ -836,14 +808,13 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.NoError(t, err)
require.Equal(t, testRTPPacket, pkt)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -876,16 +847,16 @@ func TestClientPublishRTCPReport(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -894,21 +865,19 @@ func TestClientPublishRTCPReport(t *testing.T) {
string(base.Record),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -924,7 +893,7 @@ func TestClientPublishRTCPReport(t *testing.T) {
require.NoError(t, err)
defer l2.Close()
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": headers.Transport{
@@ -937,18 +906,16 @@ func TestClientPublishRTCPReport(t *testing.T) {
ServerPorts: &[2]int{34556, 34557},
}.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
buf := make([]byte, 2048)
@@ -975,14 +942,13 @@ func TestClientPublishRTCPReport(t *testing.T) {
close(reportReceived)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -1027,16 +993,16 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -1045,21 +1011,19 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
string(base.Record),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
@@ -1076,47 +1040,42 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) {
InterleavedIDs: inTH.InterleavedIDs,
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
byts, _ = base.InterleavedFrame{
conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: testRTPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
byts, _ = base.InterleavedFrame{
conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
"strings"
@@ -11,6 +10,7 @@ import (
"github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/url"
)
@@ -22,15 +22,15 @@ func mustParseURL(s string) *url.URL {
return u
}
func readRequest(br *bufio.Reader) (*base.Request, error) {
func readRequest(conn *conn.Conn) (*base.Request, error) {
var req base.Request
err := req.Read(br)
err := conn.ReadRequest(&req)
return &req, err
}
func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) {
func readRequestIgnoreFrames(conn *conn.Conn) (*base.Request, error) {
var req base.Request
err := req.ReadIgnoreFrames(2048, br)
err := conn.ReadRequestIgnoreFrames(&req)
return &req, err
}
@@ -44,14 +44,14 @@ func TestClientTLSSetServerName(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
defer nconn.Close()
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
tconn := tls.Server(conn, &tls.Config{
tnconn := tls.Server(nconn, &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
@@ -60,7 +60,7 @@ func TestClientTLSSetServerName(t *testing.T) {
},
})
err = tconn.Handshake()
err = tnconn.Handshake()
require.EqualError(t, err, "remote error: tls: bad certificate")
}()
@@ -91,16 +91,16 @@ func TestClientSession(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
br := bufio.NewReader(conn)
defer conn.Close()
conn := conn.NewConn(nconn)
defer nconn.Close()
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
@@ -108,11 +108,10 @@ func TestClientSession(t *testing.T) {
}, ", ")},
"Session": base.HeaderValue{"123456"},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
@@ -127,15 +126,14 @@ func TestClientSession(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
"Session": base.HeaderValue{"123456"},
},
Body: tracks.Marshal(false),
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -162,42 +160,40 @@ func TestClientAuth(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
br := bufio.NewReader(conn)
defer conn.Close()
conn := conn.NewConn(nconn)
defer nconn.Close()
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
v := auth.NewValidator("myuser", "mypass", nil)
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusUnauthorized,
Header: base.Header{
"WWW-Authenticate": v.Header(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
@@ -213,14 +209,13 @@ func TestClientAuth(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: tracks.Marshal(false),
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -247,27 +242,26 @@ func TestClientDescribeCharset(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
byts, _ := base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
}, ", ")},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
req, err = readRequest(br)
req, err = readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL)
@@ -278,15 +272,14 @@ func TestClientDescribeCharset(t *testing.T) {
PPS: []byte{0x01, 0x02, 0x03, 0x04},
}
byts, _ = base.Response{
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp; charset=utf-8"},
"Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"},
},
Body: Tracks{track1}.Marshal(false),
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
}()
@@ -349,12 +342,12 @@ func TestClientCloseDuringRequest(t *testing.T) {
go func() {
defer close(serverDone)
conn, err := l.Accept()
nconn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := readRequest(br)
req, err := readRequest(conn)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)

View File

@@ -1,12 +1,6 @@
package gortsplib
const (
tcpReadBufferSize = 4096
// this must fit an entire H264 NALU and a RTP header.
// with a 250 Mbps H264 video, the maximum NALU size is 2.2MB
tcpMaxFramePayloadSize = 3 * 1024 * 1024
// same size as GStreamer's rtspsrc
udpKernelReadBufferSize = 0x80000

View File

@@ -7,65 +7,10 @@ import (
)
const (
interleavedFrameMagicByte = 0x24
// InterleavedFrameMagicByte is the first byte of an interleaved frame.
InterleavedFrameMagicByte = 0x24
)
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(
frame *InterleavedFrame,
maxPayloadSize int,
req *Request,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
}
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(maxPayloadSize, br)
if err != nil {
return nil, err
}
return frame, err
}
err = req.Read(br)
if err != nil {
return nil, err
}
return req, nil
}
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(
frame *InterleavedFrame,
maxPayloadSize int,
res *Response,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
}
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(maxPayloadSize, br)
if err != nil {
return nil, err
}
return frame, err
}
err = res.Read(br)
if err != nil {
return nil, err
}
return res, nil
}
// InterleavedFrame is an interleaved frame, and allows to transfer binary data
// within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP.
type InterleavedFrame struct {
@@ -77,22 +22,19 @@ type InterleavedFrame struct {
}
// Read decodes an interleaved frame.
func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error {
func (f *InterleavedFrame) Read(br *bufio.Reader) error {
var header [4]byte
_, err := io.ReadFull(br, header[:])
if err != nil {
return err
}
if header[0] != interleavedFrameMagicByte {
if header[0] != InterleavedFrameMagicByte {
return fmt.Errorf("invalid magic byte (0x%.2x)", header[0])
}
// it's useless to check payloadLen since it's limited to 65535
payloadLen := int(uint16(header[2])<<8 | uint16(header[3]))
if payloadLen > maxPayloadSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
payloadLen, maxPayloadSize)
}
f.Channel = int(header[1])
f.Payload = make([]byte, payloadLen)

View File

@@ -37,7 +37,7 @@ func TestInterleavedFrameRead(t *testing.T) {
for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) {
err := f.Read(1024, bufio.NewReader(bytes.NewBuffer(ca.enc)))
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc)))
require.NoError(t, err)
require.Equal(t, ca.dec, f)
})
@@ -60,11 +60,6 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
[]byte{0x55, 0x00, 0x00, 0x00},
"invalid magic byte (0x55)",
},
{
"payload size too big",
[]byte{0x24, 0x00, 0x00, 0x08},
"payload size (8) greater than maximum allowed (5)",
},
{
"payload invalid",
[]byte{0x24, 0x00, 0x00, 0x05, 0x01, 0x02},
@@ -73,7 +68,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
err := f.Read(5, bufio.NewReader(bytes.NewBuffer(ca.byts)))
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.EqualError(t, err, ca.err)
})
}
@@ -88,106 +83,3 @@ func TestInterleavedFrameMarshal(t *testing.T) {
})
}
}
func TestReadInterleavedFrameOrRequest(t *testing.T) {
byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"Accept: application/sdp\r\n" +
"CSeq: 2\r\n" +
"\r\n")
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame
var req Request
br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err)
require.Equal(t, &req, out)
out, err = ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err)
require.Equal(t, &f, out)
}
func TestReadInterleavedFrameOrRequestErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
},
{
"invalid frame",
[]byte{0x24, 0x00},
"unexpected EOF",
},
{
"invalid request",
[]byte("DESCRIBE"),
"EOF",
},
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
var req Request
br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.EqualError(t, err, ca.err)
})
}
}
func TestReadInterleavedFrameOrResponse(t *testing.T) {
byts := []byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 1\r\n" +
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
"\r\n")
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame
var res Response
br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.NoError(t, err)
require.Equal(t, &res, out)
out, err = ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.NoError(t, err)
require.Equal(t, &f, out)
}
func TestReadInterleavedFrameOrResponseErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
},
{
"invalid frame",
[]byte{0x24, 0x00},
"unexpected EOF",
},
{
"invalid response",
[]byte("RTSP/1.0"),
"EOF",
},
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
var res Response
br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.EqualError(t, err, ca.err)
})
}
}

View File

@@ -100,23 +100,6 @@ func (req *Request) Read(rb *bufio.Reader) error {
return nil
}
// ReadIgnoreFrames reads a request and ignores any interleaved frame sent
// before the request.
func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
var f InterleavedFrame
for {
recv, err := ReadInterleavedFrameOrRequest(&f, maxPayloadSize, req, rb)
if err != nil {
return err
}
if _, ok := recv.(*Request); ok {
return nil
}
}
}
// MarshalSize returns the size of a Request.
func (req Request) MarshalSize() int {
n := 0

View File

@@ -238,29 +238,6 @@ func TestRequestMarshal(t *testing.T) {
}
}
func TestRequestReadIgnoreFrames(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}
byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+
"CSeq: 1\r\n"+
"Proxy-Require: gzipped-messages\r\n"+
"Require: implicit-play\r\n"+
"\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts))
var req Request
err := req.ReadIgnoreFrames(10, rb)
require.NoError(t, err)
}
func TestRequestReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts))
var req Request
err := req.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF")
}
func TestRequestString(t *testing.T) {
byts := []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"CSeq: 1\r\n" +

View File

@@ -184,23 +184,6 @@ func (res *Response) Read(rb *bufio.Reader) error {
return nil
}
// ReadIgnoreFrames reads a response and ignores any interleaved frame sent
// before the response.
func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
var f InterleavedFrame
for {
recv, err := ReadInterleavedFrameOrResponse(&f, maxPayloadSize, res, rb)
if err != nil {
return err
}
if _, ok := recv.(*Response); ok {
return nil
}
}
}
// MarshalSize returns the size of a Response.
func (res Response) MarshalSize() int {
n := 0

View File

@@ -212,28 +212,6 @@ func TestResponseMarshalAutoFillStatus(t *testing.T) {
require.Equal(t, byts, buf)
}
func TestResponseReadIgnoreFrames(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}
byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+
"CSeq: 1\r\n"+
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+
"\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts))
var res Response
err := res.ReadIgnoreFrames(10, rb)
require.NoError(t, err)
}
func TestResponseReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts))
var res Response
err := res.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF")
}
func TestResponseString(t *testing.T) {
byts := []byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 3\r\n" +

148
pkg/conn/conn.go Normal file
View File

@@ -0,0 +1,148 @@
package conn
import (
"bufio"
"io"
"github.com/aler9/gortsplib/pkg/base"
)
const (
readBufferSize = 4096
)
// Conn is a RTSP TCP connection.
type Conn struct {
w io.Writer
br *bufio.Reader
}
// NewConn allocates a Conn.
func NewConn(rw io.ReadWriter) *Conn {
return &Conn{
w: rw,
br: bufio.NewReaderSize(rw, readBufferSize),
}
}
// ReadResponse reads a Response.
func (c *Conn) ReadResponse(res *base.Response) error {
return res.Read(c.br)
}
// ReadRequest reads a Request.
func (c *Conn) ReadRequest(req *base.Request) error {
return req.Read(c.br)
}
// ReadInterleavedFrame reads a InterleavedFrame.
func (c *Conn) ReadInterleavedFrame(fr *base.InterleavedFrame) error {
return fr.Read(c.br)
}
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Request.
func (c *Conn) ReadInterleavedFrameOrRequest(
frame *base.InterleavedFrame,
req *base.Request,
) (interface{}, error) {
b, err := c.br.ReadByte()
if err != nil {
return nil, err
}
c.br.UnreadByte()
if b == base.InterleavedFrameMagicByte {
err := frame.Read(c.br)
if err != nil {
return nil, err
}
return frame, err
}
err = req.Read(c.br)
if err != nil {
return nil, err
}
return req, nil
}
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func (c *Conn) ReadInterleavedFrameOrResponse(
frame *base.InterleavedFrame,
res *base.Response,
) (interface{}, error) {
b, err := c.br.ReadByte()
if err != nil {
return nil, err
}
c.br.UnreadByte()
if b == base.InterleavedFrameMagicByte {
err := frame.Read(c.br)
if err != nil {
return nil, err
}
return frame, err
}
err = res.Read(c.br)
if err != nil {
return nil, err
}
return res, nil
}
// ReadRequestIgnoreFrames reads a Request and ignore frames in between.
func (c *Conn) ReadRequestIgnoreFrames(req *base.Request) error {
var f base.InterleavedFrame
for {
recv, err := c.ReadInterleavedFrameOrRequest(&f, req)
if err != nil {
return err
}
if _, ok := recv.(*base.Request); ok {
return nil
}
}
}
// ReadResponseIgnoreFrames reads a Response and ignore frames in between.
func (c *Conn) ReadResponseIgnoreFrames(res *base.Response) error {
var f base.InterleavedFrame
for {
recv, err := c.ReadInterleavedFrameOrResponse(&f, res)
if err != nil {
return err
}
if _, ok := recv.(*base.Response); ok {
return nil
}
}
}
// WriteRequest writes a request.
func (c *Conn) WriteRequest(req *base.Request) error {
buf, _ := req.Marshal()
_, err := c.w.Write(buf)
return err
}
// WriteResponse writes a response.
func (c *Conn) WriteResponse(res *base.Response) error {
buf, _ := res.Marshal()
_, err := c.w.Write(buf)
return err
}
// WriteInterleavedFrame writes an interleaved frame.
func (c *Conn) WriteInterleavedFrame(fr *base.InterleavedFrame, buf []byte) error {
n, _ := fr.MarshalTo(buf)
_, err := c.w.Write(buf[:n])
return err
}

159
pkg/conn/conn_test.go Normal file
View File

@@ -0,0 +1,159 @@
package conn
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
)
func TestReadInterleavedFrameOrRequest(t *testing.T) {
byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"Accept: application/sdp\r\n" +
"CSeq: 2\r\n" +
"\r\n")
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f base.InterleavedFrame
var req base.Request
conn := NewConn(bytes.NewBuffer(byts))
out, err := conn.ReadInterleavedFrameOrRequest(&f, &req)
require.NoError(t, err)
require.Equal(t, &req, out)
out, err = conn.ReadInterleavedFrameOrRequest(&f, &req)
require.NoError(t, err)
require.Equal(t, &f, out)
}
func TestReadInterleavedFrameOrRequestErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
},
{
"invalid frame",
[]byte{0x24, 0x00},
"unexpected EOF",
},
{
"invalid request",
[]byte("DESCRIBE"),
"EOF",
},
} {
t.Run(ca.name, func(t *testing.T) {
var f base.InterleavedFrame
var req base.Request
conn := NewConn(bytes.NewBuffer(ca.byts))
_, err := conn.ReadInterleavedFrameOrRequest(&f, &req)
require.EqualError(t, err, ca.err)
})
}
}
func TestReadInterleavedFrameOrResponse(t *testing.T) {
byts := []byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 1\r\n" +
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
"\r\n")
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f base.InterleavedFrame
var res base.Response
conn := NewConn(bytes.NewBuffer(byts))
out, err := conn.ReadInterleavedFrameOrResponse(&f, &res)
require.NoError(t, err)
require.Equal(t, &res, out)
out, err = conn.ReadInterleavedFrameOrResponse(&f, &res)
require.NoError(t, err)
require.Equal(t, &f, out)
}
func TestReadInterleavedFrameOrResponseErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
err string
}{
{
"empty",
[]byte{},
"EOF",
},
{
"invalid frame",
[]byte{0x24, 0x00},
"unexpected EOF",
},
{
"invalid response",
[]byte("RTSP/1.0"),
"EOF",
},
} {
t.Run(ca.name, func(t *testing.T) {
var f base.InterleavedFrame
var res base.Response
conn := NewConn(bytes.NewBuffer(ca.byts))
_, err := conn.ReadInterleavedFrameOrResponse(&f, &res)
require.EqualError(t, err, ca.err)
})
}
}
func TestReadRequestIgnoreFrames(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}
byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+
"CSeq: 1\r\n"+
"Proxy-Require: gzipped-messages\r\n"+
"Require: implicit-play\r\n"+
"\r\n")...)
conn := NewConn(bytes.NewBuffer(byts))
var req base.Request
err := conn.ReadRequestIgnoreFrames(&req)
require.NoError(t, err)
}
func TestReadRequestIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
conn := NewConn(bytes.NewBuffer(byts))
var req base.Request
err := conn.ReadRequestIgnoreFrames(&req)
require.EqualError(t, err, "EOF")
}
func TestReadResponseIgnoreFrames(t *testing.T) {
byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}
byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+
"CSeq: 1\r\n"+
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+
"\r\n")...)
conn := NewConn(bytes.NewBuffer(byts))
var res base.Response
err := conn.ReadResponseIgnoreFrames(&res)
require.NoError(t, err)
}
func TestReadResponseIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
conn := NewConn(bytes.NewBuffer(byts))
var res base.Response
err := conn.ReadResponseIgnoreFrames(&res)
require.EqualError(t, err, "EOF")
}

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
"testing"
@@ -13,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
)
@@ -113,13 +113,13 @@ func TestServerPublishErrorAnnounce(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, ca.err)
close(connClosed)
close(nconnClosed)
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
@@ -134,15 +134,15 @@ func TestServerPublishErrorAnnounce(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
_, err = writeReqReadRes(conn, br, ca.req)
_, err = writeReqReadRes(conn, ca.req)
require.NoError(t, err)
<-connClosed
<-nconnClosed
})
}
}
@@ -225,10 +225,10 @@ func TestServerPublishSetupPath(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -255,7 +255,7 @@ func TestServerPublishSetupPath(t *testing.T) {
byts, _ := sout.Marshal()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/" + ca.path),
Header: base.Header{
@@ -280,7 +280,7 @@ func TestServerPublishSetupPath(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL(ca.url),
Header: base.Header{
@@ -320,10 +320,10 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -334,7 +334,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -359,7 +359,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/test2stream/trackID=0"),
Header: base.Header{
@@ -400,10 +400,10 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -414,7 +414,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -439,7 +439,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -454,7 +454,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -501,10 +501,10 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track1 := &TrackH264{
PayloadType: 96,
@@ -521,7 +521,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
tracks := Tracks{track1, track2}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -546,7 +546,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -561,7 +561,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -583,18 +583,18 @@ func TestServerPublish(t *testing.T) {
"tls",
} {
t.Run(transport, func(t *testing.T) {
connOpened := make(chan struct{})
connClosed := make(chan struct{})
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) {
close(connOpened)
close(nconnOpened)
},
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) {
close(sessionOpened)
@@ -649,19 +649,19 @@ func TestServerPublish(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
defer nconn.Close()
conn = func() net.Conn {
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return conn
return nconn
}()
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
<-connOpened
<-nconnOpened
track := &TrackH264{
PayloadType: 96,
@@ -672,7 +672,7 @@ func TestServerPublish(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -716,7 +716,7 @@ func TestServerPublish(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -735,7 +735,7 @@ func TestServerPublish(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -754,7 +754,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -783,18 +783,16 @@ func TestServerPublish(t *testing.T) {
Port: th.ServerPorts[1],
})
} else {
byts, _ := base.InterleavedFrame{
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: testRTPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
byts, _ = base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
}
@@ -806,13 +804,13 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Teardown,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -825,8 +823,8 @@ func TestServerPublish(t *testing.T) {
<-sessionClosed
conn.Close()
<-connClosed
nconn.Close()
<-nconnClosed
})
}
}
@@ -862,10 +860,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -876,7 +874,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -901,7 +899,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -920,7 +918,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -931,11 +929,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
}
@@ -968,10 +965,10 @@ func TestServerPublishRTCPReport(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -982,7 +979,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1002,7 +999,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
require.NoError(t, err)
defer l2.Close()
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1032,7 +1029,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1105,13 +1102,13 @@ func TestServerPublishTimeout(t *testing.T) {
"tcp",
} {
t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
@@ -1145,10 +1142,10 @@ func TestServerPublishTimeout(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1159,7 +1156,7 @@ func TestServerPublishTimeout(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1190,7 +1187,7 @@ func TestServerPublishTimeout(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1209,7 +1206,7 @@ func TestServerPublishTimeout(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1223,7 +1220,7 @@ func TestServerPublishTimeout(t *testing.T) {
<-sessionClosed
if transport == "tcp" {
<-connClosed
<-nconnClosed
}
})
}
@@ -1235,13 +1232,13 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
"tcp",
} {
t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
@@ -1275,9 +1272,9 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1288,7 +1285,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1319,7 +1316,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1338,7 +1335,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1349,10 +1346,10 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
conn.Close()
nconn.Close()
<-sessionClosed
<-connClosed
<-nconnClosed
})
}
}
@@ -1395,10 +1392,10 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
sxID := ""
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1409,7 +1406,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1434,7 +1431,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1449,7 +1446,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1464,12 +1461,12 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
}()
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream/"),
Header: base.Header{

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
"strconv"
@@ -16,6 +15,7 @@ import (
"golang.org/x/net/ipv4"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/url"
)
@@ -118,10 +118,10 @@ func TestServerReadSetupPath(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
th := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
@@ -136,7 +136,7 @@ func TestServerReadSetupPath(t *testing.T) {
InterleavedIDs: &[2]int{ca.trackID * 2, (ca.trackID * 2) + 1},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL(ca.url),
Header: base.Header{
@@ -157,7 +157,7 @@ func TestServerReadSetupErrors(t *testing.T) {
"closed stream",
} {
t.Run(ca, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
track := &TrackH264{
PayloadType: 96,
@@ -185,7 +185,7 @@ func TestServerReadSetupErrors(t *testing.T) {
case "closed stream":
require.EqualError(t, ctx.Error, "stream is closed")
}
close(connClosed)
close(nconnClosed)
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
@@ -200,10 +200,10 @@ func TestServerReadSetupErrors(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
th := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
@@ -218,7 +218,7 @@ func TestServerReadSetupErrors(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -237,7 +237,7 @@ func TestServerReadSetupErrors(t *testing.T) {
require.NoError(t, err)
th.InterleavedIDs = &[2]int{2, 3}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/test12stream/trackID=1"),
Header: base.Header{
@@ -258,7 +258,7 @@ func TestServerReadSetupErrors(t *testing.T) {
require.NoError(t, err)
th.InterleavedIDs = &[2]int{2, 3}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -275,7 +275,7 @@ func TestServerReadSetupErrors(t *testing.T) {
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
<-connClosed
<-nconnClosed
})
}
}
@@ -288,8 +288,8 @@ func TestServerRead(t *testing.T) {
"multicast",
} {
t.Run(transport, func(t *testing.T) {
connOpened := make(chan struct{})
connClosed := make(chan struct{})
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{})
framesReceived := make(chan struct{})
@@ -310,10 +310,10 @@ func TestServerRead(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) {
close(connOpened)
close(nconnOpened)
},
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) {
close(sessionOpened)
@@ -385,18 +385,18 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", listenIP+":8554")
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
conn = func() net.Conn {
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return conn
return nconn
}()
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
<-connOpened
<-nconnOpened
inTH := &headers.Transport{
Mode: func() *headers.TransportMode {
@@ -424,7 +424,7 @@ func TestServerRead(t *testing.T) {
inTH.InterleavedIDs = &[2]int{4, 5}
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream/trackID=0"),
Header: base.Header{
@@ -498,7 +498,7 @@ func TestServerRead(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
@@ -519,7 +519,7 @@ func TestServerRead(t *testing.T) {
case "tcp", "tls":
var f base.InterleavedFrame
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
switch f.Channel {
@@ -549,7 +549,7 @@ func TestServerRead(t *testing.T) {
var f base.InterleavedFrame
for i := 0; i < 2; i++ {
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
switch f.Channel {
@@ -582,18 +582,17 @@ func TestServerRead(t *testing.T) {
<-framesReceived
default:
byts, _ := base.InterleavedFrame{
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 5,
Payload: testRTCPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
<-framesReceived
}
if transport == "udp" || transport == "multicast" {
// ping with OPTIONS
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
@@ -605,7 +604,7 @@ func TestServerRead(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode)
// ping with GET_PARAMETER
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
@@ -617,7 +616,7 @@ func TestServerRead(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode)
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Teardown,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
@@ -630,8 +629,8 @@ func TestServerRead(t *testing.T) {
<-sessionClosed
conn.Close()
<-connClosed
nconn.Close()
<-nconnClosed
})
}
}
@@ -669,10 +668,10 @@ func TestServerReadRTCPReport(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Mode: func() *headers.TransportMode {
@@ -687,7 +686,7 @@ func TestServerReadRTCPReport(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -710,7 +709,7 @@ func TestServerReadRTCPReport(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -737,7 +736,7 @@ func TestServerReadRTCPReport(t *testing.T) {
OctetCount: 8,
}, packets[0])
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Teardown,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -779,12 +778,12 @@ func TestServerReadVLCMulticast(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", listenIP+":8554")
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
defer conn.Close()
conn := conn.NewConn(nconn)
defer nconn.Close()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Describe,
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream?vlcmulticast"),
Header: base.Header{
@@ -858,12 +857,12 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -889,7 +888,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -900,8 +899,8 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame
err = fr.Read(2048, br)
var f base.InterleavedFrame
err = conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
}
@@ -937,12 +936,12 @@ func TestServerReadPlayPlay(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -968,7 +967,7 @@ func TestServerReadPlayPlay(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -979,7 +978,7 @@ func TestServerReadPlayPlay(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1053,12 +1052,12 @@ func TestServerReadPlayPausePlay(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1084,7 +1083,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1095,7 +1094,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Pause,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1106,7 +1105,7 @@ func TestServerReadPlayPausePlay(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1176,12 +1175,12 @@ func TestServerReadPlayPausePause(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1207,7 +1206,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1218,33 +1217,31 @@ func TestServerReadPlayPausePause(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := base.Request{
err = conn.WriteRequest(&base.Request{
Method: base.Pause,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
res, err = readResIgnoreFrames(br)
res, err = readResIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ = base.Request{
err = conn.WriteRequest(&base.Request{
Method: base.Pause,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session},
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
res, err = readResIgnoreFrames(br)
res, err = readResIgnoreFrames(conn)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
@@ -1308,10 +1305,10 @@ func TestServerReadTimeout(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Mode: func() *headers.TransportMode {
@@ -1333,7 +1330,7 @@ func TestServerReadTimeout(t *testing.T) {
inTH.Protocol = headers.TransportProtocolUDP
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1348,7 +1345,7 @@ func TestServerReadTimeout(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1370,7 +1367,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
"tcp",
} {
t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
sessionClosed := make(chan struct{})
track := &TrackH264{
@@ -1385,7 +1382,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
@@ -1420,10 +1417,10 @@ func TestServerReadWithoutTeardown(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
@@ -1444,7 +1441,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1459,7 +1456,7 @@ func TestServerReadWithoutTeardown(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1470,10 +1467,10 @@ func TestServerReadWithoutTeardown(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
conn.Close()
nconn.Close()
<-sessionClosed
<-connClosed
<-nconnClosed
})
}
}
@@ -1518,10 +1515,10 @@ func TestServerReadUDPChangeConn(t *testing.T) {
sxID := ""
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
@@ -1536,7 +1533,7 @@ func TestServerReadUDPChangeConn(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1551,7 +1548,7 @@ func TestServerReadUDPChangeConn(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1566,12 +1563,12 @@ func TestServerReadUDPChangeConn(t *testing.T) {
}()
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream/"),
Header: base.Header{
@@ -1626,10 +1623,10 @@ func TestServerReadPartialTracks(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
@@ -1644,7 +1641,7 @@ func TestServerReadPartialTracks(t *testing.T) {
InterleavedIDs: &[2]int{4, 5},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=1"),
Header: base.Header{
@@ -1659,7 +1656,7 @@ func TestServerReadPartialTracks(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1671,7 +1668,7 @@ func TestServerReadPartialTracks(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode)
var f base.InterleavedFrame
err = f.Read(2048, br)
err = conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 4, f.Channel)
require.Equal(t, testRTPPacketMarshaled, f.Payload)
@@ -1679,10 +1676,10 @@ func TestServerReadPartialTracks(t *testing.T) {
func TestServerReadAdditionalInfos(t *testing.T) {
getInfos := func() (*headers.RTPInfo, []*uint32) {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
ssrcs := make([]*uint32, 2)
@@ -1699,7 +1696,7 @@ func TestServerReadAdditionalInfos(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1732,7 +1729,7 @@ func TestServerReadAdditionalInfos(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=1"),
Header: base.Header{
@@ -1749,7 +1746,7 @@ func TestServerReadAdditionalInfos(t *testing.T) {
require.NoError(t, err)
ssrcs[1] = th.SSRC
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1913,10 +1910,10 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
defer s.Close()
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
@@ -1931,7 +1928,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1946,7 +1943,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1959,10 +1956,10 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
}()
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
@@ -1977,7 +1974,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"fmt"
"net"
"testing"
@@ -10,6 +9,7 @@ import (
"github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
)
@@ -67,24 +67,23 @@ NkxNic7oHgsZpIkZ8HK+QjAAWA==
-----END PRIVATE KEY-----
`)
func writeReqReadRes(conn net.Conn,
br *bufio.Reader,
func writeReqReadRes(
conn *conn.Conn,
req base.Request,
) (*base.Response, error) {
byts, _ := req.Marshal()
_, err := conn.Write(byts)
err := conn.WriteRequest(&req)
if err != nil {
return nil, err
}
var res base.Response
err = res.Read(br)
err = conn.ReadResponse(&res)
return &res, err
}
func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) {
func readResIgnoreFrames(conn *conn.Conn) (*base.Response, error) {
var res base.Response
err := res.ReadIgnoreFrames(2048, br)
err := conn.ReadResponseIgnoreFrames(&res)
return &res, err
}
@@ -232,7 +231,7 @@ func TestServerErrorInvalidUDPPorts(t *testing.T) {
}
func TestServerConnClose(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
@@ -241,7 +240,7 @@ func TestServerConnClose(t *testing.T) {
ctx.Conn.Close()
},
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
},
RTSPAddress: "localhost:8554",
@@ -251,11 +250,11 @@ func TestServerConnClose(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
defer nconn.Close()
<-connClosed
<-nconnClosed
}
func TestServerCSeq(t *testing.T) {
@@ -266,12 +265,12 @@ func TestServerCSeq(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
@@ -285,13 +284,13 @@ func TestServerCSeq(t *testing.T) {
}
func TestServerErrorCSeqMissing(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, "CSeq is missing")
close(connClosed)
close(nconnClosed)
},
},
RTSPAddress: "localhost:8554",
@@ -300,12 +299,12 @@ func TestServerErrorCSeqMissing(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{},
@@ -313,7 +312,7 @@ func TestServerErrorCSeqMissing(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
<-connClosed
<-nconnClosed
}
type testServerErrMethodNotImplemented struct {
@@ -349,15 +348,15 @@ func TestServerErrorMethodNotImplemented(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
var sx headers.Session
if ca == "inside session" {
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -389,7 +388,7 @@ func TestServerErrorMethodNotImplemented(t *testing.T) {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: headers,
@@ -404,7 +403,7 @@ func TestServerErrorMethodNotImplemented(t *testing.T) {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: headers,
@@ -450,12 +449,12 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn1, err := net.Dial("tcp", "localhost:8554")
nconn1, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn1.Close()
br1 := bufio.NewReader(conn1)
defer nconn1.Close()
conn1 := conn.NewConn(nconn1)
res, err := writeReqReadRes(conn1, br1, base.Request{
res, err := writeReqReadRes(conn1, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -481,7 +480,7 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn1, br1, base.Request{
res, err = writeReqReadRes(conn1, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -492,12 +491,12 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
conn2, err := net.Dial("tcp", "localhost:8554")
nconn2, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn2.Close()
br2 := bufio.NewReader(conn2)
defer nconn2.Close()
conn2 := conn.NewConn(nconn2)
res, err = writeReqReadRes(conn2, br2, base.Request{
res, err = writeReqReadRes(conn2, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -556,12 +555,12 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -587,7 +586,7 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -598,7 +597,7 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -668,15 +667,15 @@ func TestServerGetSetParameter(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
var sx headers.Session
if ca == "inside session" {
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -709,7 +708,7 @@ func TestServerGetSetParameter(t *testing.T) {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: headers,
@@ -725,7 +724,7 @@ func TestServerGetSetParameter(t *testing.T) {
headers["Session"] = base.HeaderValue{sx.Session}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: headers,
@@ -771,12 +770,12 @@ func TestServerErrorInvalidSession(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: method,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -815,11 +814,12 @@ func TestServerSessionClose(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
defer nconn.Close()
conn := conn.NewConn(nconn)
byts, _ := base.Request{
err = conn.WriteRequest(&base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -837,8 +837,7 @@ func TestServerSessionClose(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}.Marshal(),
},
}.Marshal()
_, err = conn.Write(byts)
})
require.NoError(t, err)
<-sessionClosed
@@ -884,11 +883,11 @@ func TestServerSessionAutoClose(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
_, err = writeReqReadRes(conn, br, base.Request{
_, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -909,7 +908,7 @@ func TestServerSessionAutoClose(t *testing.T) {
})
require.NoError(t, err)
conn.Close()
nconn.Close()
<-sessionClosed
})
@@ -919,7 +918,7 @@ func TestServerSessionAutoClose(t *testing.T) {
func TestServerErrorInvalidPath(t *testing.T) {
for _, ca := range []string{"inside session", "outside session"} {
t.Run(ca, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
track := &TrackH264{
PayloadType: 96,
@@ -934,7 +933,7 @@ func TestServerErrorInvalidPath(t *testing.T) {
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, "invalid path")
close(connClosed)
close(nconnClosed)
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
@@ -949,13 +948,13 @@ func TestServerErrorInvalidPath(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
if ca == "inside session" {
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -981,7 +980,7 @@ func TestServerErrorInvalidPath(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554"),
Header: base.Header{
@@ -992,7 +991,7 @@ func TestServerErrorInvalidPath(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusBadRequest, res.StatusCode)
} else {
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.SetParameter,
URL: mustParseURL("rtsp://localhost:8554"),
Header: base.Header{
@@ -1003,7 +1002,7 @@ func TestServerErrorInvalidPath(t *testing.T) {
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
<-connClosed
<-nconnClosed
})
}
}
@@ -1036,10 +1035,10 @@ func TestServerAuth(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1057,7 +1056,7 @@ func TestServerAuth(t *testing.T) {
Body: Tracks{track}.Marshal(false),
}
res, err := writeReqReadRes(conn, br, req)
res, err := writeReqReadRes(conn, req)
require.NoError(t, err)
require.Equal(t, base.StatusUnauthorized, res.StatusCode)
@@ -1065,7 +1064,7 @@ func TestServerAuth(t *testing.T) {
require.NoError(t, err)
sender.AddAuthorization(&req)
res, err = writeReqReadRes(conn, br, req)
res, err = writeReqReadRes(conn, req)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"context"
"crypto/tls"
"errors"
@@ -14,6 +13,7 @@ import (
"github.com/pion/rtcp"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/url"
)
@@ -32,13 +32,13 @@ type readReq struct {
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
s *Server
conn net.Conn
s *Server
nconn net.Conn
ctx context.Context
ctxCancel func()
remoteAddr *net.TCPAddr
br *bufio.Reader
conn *conn.Conn
session *ServerSession
readFunc func(readRequest chan readReq) error
@@ -55,7 +55,7 @@ func newServerConn(
) *ServerConn {
ctx, ctxCancel := context.WithCancel(s.ctx)
conn := func() net.Conn {
nconn = func() net.Conn {
if s.TLSConfig != nil {
return tls.Server(nconn, s.TLSConfig)
}
@@ -64,10 +64,10 @@ func newServerConn(
sc := &ServerConn{
s: s,
conn: conn,
nconn: nconn,
ctx: ctx,
ctxCancel: ctxCancel,
remoteAddr: conn.RemoteAddr().(*net.TCPAddr),
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
sessionRemove: make(chan *ServerSession),
done: make(chan struct{}),
}
@@ -88,7 +88,7 @@ func (sc *ServerConn) Close() error {
// NetConn returns the underlying net.Conn.
func (sc *ServerConn) NetConn() net.Conn {
return sc.conn
return sc.nconn
}
func (sc *ServerConn) ip() net.IP {
@@ -109,7 +109,7 @@ func (sc *ServerConn) run() {
})
}
sc.br = bufio.NewReaderSize(sc.conn, tcpReadBufferSize)
sc.conn = conn.NewConn(sc.nconn)
readRequest := make(chan readReq)
readErr := make(chan error)
@@ -120,7 +120,7 @@ func (sc *ServerConn) run() {
sc.ctxCancel()
sc.conn.Close()
sc.nconn.Close()
<-readDone
if sc.session != nil {
@@ -185,12 +185,12 @@ func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, re
func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error {
// reset deadline
sc.conn.SetReadDeadline(time.Time{})
sc.nconn.SetReadDeadline(time.Time{})
var req base.Request
for {
err := req.Read(sc.br)
err := sc.conn.ReadRequest(&req)
if err != nil {
return err
}
@@ -211,7 +211,7 @@ func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error {
func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
// reset deadline
sc.conn.SetReadDeadline(time.Time{})
sc.nconn.SetReadDeadline(time.Time{})
select {
case sc.session.startWriter <- struct{}{}:
@@ -299,10 +299,10 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
for {
if sc.session.state == ServerSessionStateRecord {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
}
what, err := base.ReadInterleavedFrameOrRequest(&frame, tcpMaxFramePayloadSize, &req, sc.br)
what, err := sc.conn.ReadInterleavedFrameOrRequest(&frame, &req)
if err != nil {
return err
}
@@ -532,10 +532,8 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
h.OnResponse(sc, res)
}
byts, _ := res.Marshal()
sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
sc.conn.Write(byts)
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
sc.conn.WriteResponse(res)
return err
}

View File

@@ -1163,19 +1163,17 @@ func (ss *ServerSession) runWriter() {
writeFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP {
f := rtpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtpFrames[trackID]
fr.Payload = payload
ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.Write(buf[:n])
ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.WriteInterleavedFrame(fr, buf)
} else {
f := rtcpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtcpFrames[trackID]
fr.Payload = payload
ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.Write(buf[:n])
ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout))
ss.tcpConn.conn.WriteInterleavedFrame(fr, buf)
}
}
}