add intermediate layer between net.Conn and client / server

This commit is contained in:
aler9
2022-08-14 23:43:01 +02:00
parent a0a168d26c
commit 06bed24dd9
18 changed files with 1459 additions and 1561 deletions

View File

@@ -1,7 +1,6 @@
package gortsplib
import (
"bufio"
"crypto/tls"
"net"
"testing"
@@ -13,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/headers"
)
@@ -113,13 +113,13 @@ func TestServerPublishErrorAnnounce(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, ca.err)
close(connClosed)
close(nconnClosed)
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
@@ -134,15 +134,15 @@ func TestServerPublishErrorAnnounce(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
_, err = writeReqReadRes(conn, br, ca.req)
_, err = writeReqReadRes(conn, ca.req)
require.NoError(t, err)
<-connClosed
<-nconnClosed
})
}
}
@@ -225,10 +225,10 @@ func TestServerPublishSetupPath(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -255,7 +255,7 @@ func TestServerPublishSetupPath(t *testing.T) {
byts, _ := sout.Marshal()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/" + ca.path),
Header: base.Header{
@@ -280,7 +280,7 @@ func TestServerPublishSetupPath(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL(ca.url),
Header: base.Header{
@@ -320,10 +320,10 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -334,7 +334,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -359,7 +359,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/test2stream/trackID=0"),
Header: base.Header{
@@ -400,10 +400,10 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -414,7 +414,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -439,7 +439,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -454,7 +454,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -501,10 +501,10 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track1 := &TrackH264{
PayloadType: 96,
@@ -521,7 +521,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
tracks := Tracks{track1, track2}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -546,7 +546,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
InterleavedIDs: &[2]int{0, 1},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -561,7 +561,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -583,18 +583,18 @@ func TestServerPublish(t *testing.T) {
"tls",
} {
t.Run(transport, func(t *testing.T) {
connOpened := make(chan struct{})
connClosed := make(chan struct{})
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) {
close(connOpened)
close(nconnOpened)
},
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) {
close(sessionOpened)
@@ -649,19 +649,19 @@ func TestServerPublish(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
defer nconn.Close()
conn = func() net.Conn {
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return conn
return nconn
}()
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
<-connOpened
<-nconnOpened
track := &TrackH264{
PayloadType: 96,
@@ -672,7 +672,7 @@ func TestServerPublish(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -716,7 +716,7 @@ func TestServerPublish(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -735,7 +735,7 @@ func TestServerPublish(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -754,7 +754,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -783,18 +783,16 @@ func TestServerPublish(t *testing.T) {
Port: th.ServerPorts[1],
})
} else {
byts, _ := base.InterleavedFrame{
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: testRTPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
byts, _ = base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
}
@@ -806,13 +804,13 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
err := f.Read(2048, br)
err := conn.ReadInterleavedFrame(&f)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Teardown,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -825,8 +823,8 @@ func TestServerPublish(t *testing.T) {
<-sessionClosed
conn.Close()
<-connClosed
nconn.Close()
<-nconnClosed
})
}
}
@@ -862,10 +860,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -876,7 +874,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -901,7 +899,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -920,7 +918,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -931,11 +929,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Marshal()
_, err = conn.Write(byts)
}, make([]byte, 1024))
require.NoError(t, err)
}
@@ -968,10 +965,10 @@ func TestServerPublishRTCPReport(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -982,7 +979,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1002,7 +999,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
require.NoError(t, err)
defer l2.Close()
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1032,7 +1029,7 @@ func TestServerPublishRTCPReport(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1105,13 +1102,13 @@ func TestServerPublishTimeout(t *testing.T) {
"tcp",
} {
t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
@@ -1145,10 +1142,10 @@ func TestServerPublishTimeout(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1159,7 +1156,7 @@ func TestServerPublishTimeout(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1190,7 +1187,7 @@ func TestServerPublishTimeout(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1209,7 +1206,7 @@ func TestServerPublishTimeout(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1223,7 +1220,7 @@ func TestServerPublishTimeout(t *testing.T) {
<-sessionClosed
if transport == "tcp" {
<-connClosed
<-nconnClosed
}
})
}
@@ -1235,13 +1232,13 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
"tcp",
} {
t.Run(transport, func(t *testing.T) {
connClosed := make(chan struct{})
nconnClosed := make(chan struct{})
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
close(connClosed)
close(nconnClosed)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
@@ -1275,9 +1272,9 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1288,7 +1285,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1319,7 +1316,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1338,7 +1335,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1349,10 +1346,10 @@ func TestServerPublishWithoutTeardown(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
conn.Close()
nconn.Close()
<-sessionClosed
<-connClosed
<-nconnClosed
})
}
}
@@ -1395,10 +1392,10 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
sxID := ""
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
track := &TrackH264{
PayloadType: 96,
@@ -1409,7 +1406,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
tracks := Tracks{track}
tracks.setControls()
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1434,7 +1431,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
ClientPorts: &[2]int{35466, 35467},
}
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
@@ -1449,7 +1446,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
@@ -1464,12 +1461,12 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
}()
func() {
conn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
defer nconn.Close()
conn := conn.NewConn(nconn)
res, err := writeReqReadRes(conn, br, base.Request{
res, err := writeReqReadRes(conn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream/"),
Header: base.Header{