add intermediate layer between net.Conn and client / server

This commit is contained in:
aler9
2022-08-14 23:43:01 +02:00
parent a0a168d26c
commit 06bed24dd9
18 changed files with 1459 additions and 1561 deletions

View File

@@ -8,7 +8,6 @@ Examples are available at https://github.com/aler9/gortsplib/tree/master/example
package gortsplib
import (
"bufio"
"context"
"crypto/tls"
"fmt"
@@ -24,6 +23,7 @@ import (
"github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/ringbuffer"
@@ -256,8 +256,8 @@ type Client struct {
ctx context.Context
ctxCancel func()
state clientState
conn net.Conn
br *bufio.Reader
nconn net.Conn
conn *conn.Conn
session string
sender *auth.Sender
cseq int
@@ -581,11 +581,13 @@ func (c *Client) doClose() {
URL: c.baseURL,
}, true, false)
c.conn.Close()
c.nconn.Close()
c.nconn = nil
c.conn = nil
} else if c.conn != nil {
} else if c.nconn != nil {
c.connCloserStop()
c.conn.Close()
c.nconn.Close()
c.nconn = nil
c.conn = nil
}
@@ -756,7 +758,7 @@ func (c *Client) playRecordStart() {
// for some reason, SetReadDeadline() must always be called in the same
// goroutine, otherwise Read() freezes.
// therefore, we disable the deadline and perform a check with a ticker.
c.conn.SetReadDeadline(time.Time{})
c.nconn.SetReadDeadline(time.Time{})
// start reader
c.readerErr = make(chan error)
@@ -768,7 +770,7 @@ func (c *Client) runReader() {
if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast {
for {
var res base.Response
err := res.Read(c.br)
err := c.conn.ReadResponse(&res)
if err != nil {
return err
}
@@ -854,7 +856,7 @@ func (c *Client) runReader() {
var res base.Response
for {
what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br)
what, err := c.conn.ReadInterleavedFrameOrResponse(&frame, &res)
if err != nil {
return err
}
@@ -885,7 +887,7 @@ func (c *Client) runReader() {
func (c *Client) playRecordStop(isClosing bool) {
// stop reader
if c.readerErr != nil {
c.conn.SetReadDeadline(time.Now())
c.nconn.SetReadDeadline(time.Now())
<-c.readerErr
}
@@ -963,7 +965,7 @@ func (c *Client) connOpen() error {
return err
}
c.conn = func() net.Conn {
c.nconn = func() net.Conn {
if c.scheme == "rtsps" {
tlsConfig := c.TLSConfig
@@ -979,7 +981,8 @@ func (c *Client) connOpen() error {
return nconn
}()
c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize)
c.conn = conn.NewConn(c.nconn)
c.connCloserStart()
return nil
}
@@ -993,7 +996,7 @@ func (c *Client) connCloserStart() {
select {
case <-c.ctx.Done():
c.conn.Close()
c.nconn.Close()
case <-c.connCloserTerminate:
}
@@ -1007,7 +1010,7 @@ func (c *Client) connCloserStop() {
}
func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*base.Response, error) {
if c.conn == nil {
if c.nconn == nil {
err := c.connOpen()
if err != nil {
return nil, err
@@ -1042,10 +1045,8 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
c.OnRequest(req)
}
byts, _ := req.Marshal()
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
_, err := c.conn.Write(byts)
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
err := c.conn.WriteRequest(req)
if err != nil {
return nil, err
}
@@ -1053,19 +1054,19 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
var res base.Response
if !skipResponse {
c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
if allowFrames {
// read the response and ignore interleaved frames in between;
// interleaved frames are sent in two cases:
// * when the server is v4lrtspserver, before the PLAY response
// * when the stream is already playing
err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br)
err = c.conn.ReadResponseIgnoreFrames(&res)
if err != nil {
return nil, err
}
} else {
err = res.Read(c.br)
err = c.conn.ReadResponse(&res)
if err != nil {
return nil, err
}
@@ -1491,13 +1492,13 @@ func (c *Client) doSetup(
if thRes.Source != nil {
return *thRes.Source
}
return c.conn.RemoteAddr().(*net.TCPAddr).IP
return c.nconn.RemoteAddr().(*net.TCPAddr).IP
}()
if thRes.ServerPorts != nil {
ct.udpRTPListener.readPort = thRes.ServerPorts[0]
ct.udpRTPListener.writeAddr = &net.UDPAddr{
IP: c.conn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone,
IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[0],
}
}
@@ -1506,13 +1507,13 @@ func (c *Client) doSetup(
if thRes.Source != nil {
return *thRes.Source
}
return c.conn.RemoteAddr().(*net.TCPAddr).IP
return c.nconn.RemoteAddr().(*net.TCPAddr).IP
}()
if thRes.ServerPorts != nil {
ct.udpRTCPListener.readPort = thRes.ServerPorts[1]
ct.udpRTCPListener.writeAddr = &net.UDPAddr{
IP: c.conn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone,
IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[1],
}
}
@@ -1551,14 +1552,14 @@ func (c *Client) doSetup(
return nil, err
}
ct.udpRTPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTPListener.readPort = thRes.Ports[0]
ct.udpRTPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,
Port: thRes.Ports[0],
}
ct.udpRTCPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTCPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
ct.udpRTCPListener.readPort = thRes.Ports[1]
ct.udpRTCPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,
@@ -1848,19 +1849,17 @@ func (c *Client) runWriter() {
writeFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP {
f := rtpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtpFrames[trackID]
fr.Payload = payload
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf[:n])
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.WriteInterleavedFrame(fr, buf)
} else {
f := rtcpFrames[trackID]
f.Payload = payload
n, _ := f.MarshalTo(buf)
fr := rtcpFrames[trackID]
fr.Payload = payload
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.Write(buf[:n])
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
c.conn.WriteInterleavedFrame(fr, buf)
}
}
}