support servers which don't provide the server-port field (#21)

This commit is contained in:
aler9
2021-01-21 09:02:15 +01:00
parent b9dfe1b310
commit 30b43a9ead
2 changed files with 124 additions and 104 deletions

View File

@@ -149,114 +149,126 @@ func TestClientDialRead(t *testing.T) {
} }
} }
func TestClientDialReadZeroServerPorts(t *testing.T) { func TestClientDialReadNoServerPorts(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554") for _, ca := range []string{
require.NoError(t, err) "zero",
defer l.Close() "no",
} {
t.Run(ca, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{}) serverDone := make(chan struct{})
defer func() { <-serverDone }() defer func() { <-serverDone }()
go func() { go func() {
defer close(serverDone) defer close(serverDone)
conn, err := l.Accept() conn, err := l.Accept()
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
var req base.Request var req base.Request
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Options, req.Method) require.Equal(t, base.Options, req.Method)
err = base.Response{ err = base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{ "Public": base.HeaderValue{strings.Join([]string{
string(base.Describe), string(base.Describe),
string(base.Setup), string(base.Setup),
string(base.Play), string(base.Play),
}, ", ")}, }, ", ")},
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err) require.NoError(t, err)
sdp := Tracks{track}.Write() sdp := Tracks{track}.Write()
err = base.Response{ err = base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"}, "Content-Type": base.HeaderValue{"application/sdp"},
}, },
Body: sdp, Body: sdp,
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
th, err := headers.ReadTransport(req.Header["Transport"]) th, err := headers.ReadTransport(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Response{ err = base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": headers.Transport{ "Transport": headers.Transport{
Protocol: StreamProtocolUDP, Protocol: StreamProtocolUDP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
return &v return &v
}(), }(),
ClientPorts: th.ClientPorts, ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{0, 0}, ServerPorts: func() *[2]int {
}.Write(), if ca == "zero" {
}, return &[2]int{0, 0}
}.Write(bconn.Writer) }
require.NoError(t, err) return nil
}(),
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader) err = req.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Play, req.Method) require.Equal(t, base.Play, req.Method)
err = base.Response{ err = base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
l1, err := net.ListenPacket("udp", "localhost:0") l1, err := net.ListenPacket("udp", "localhost:0")
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1.Close()
l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{ l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0], Port: th.ClientPorts[0],
})
}()
conf := ClientConf{
AnyPortEnable: true,
}
conn, err := conf.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err)
frameRecv := make(chan struct{})
done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) {
close(frameRecv)
})
<-frameRecv
conn.Close()
<-done
}) })
}()
conf := ClientConf{
AnyPortEnable: true,
} }
conn, err := conf.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err)
frameRecv := make(chan struct{})
done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) {
close(frameRecv)
})
<-frameRecv
conn.Close()
<-done
} }
func TestClientDialReadAutomaticProtocol(t *testing.T) { func TestClientDialReadAutomaticProtocol(t *testing.T) {

View File

@@ -575,21 +575,25 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track,
} }
if proto == StreamProtocolUDP { if proto == StreamProtocolUDP {
if thRes.ServerPorts == nil { if thRes.ServerPorts != nil {
rtpListener.close() if (thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] != 0) ||
rtcpListener.close() (thRes.ServerPorts[0] != 0 && thRes.ServerPorts[1] == 0) {
return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server") rtpListener.close()
} rtcpListener.close()
return nil, fmt.Errorf("server ports must be both zero or both not zero")
if (thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] != 0) || }
(thRes.ServerPorts[0] != 0 && thRes.ServerPorts[1] == 0) {
rtpListener.close()
rtcpListener.close()
return nil, fmt.Errorf("server ports must be both zero or both not zero")
} }
if !c.conf.AnyPortEnable { if !c.conf.AnyPortEnable {
if thRes.ServerPorts == nil {
rtpListener.close()
rtcpListener.close()
return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server")
}
if thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] == 0 { if thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] == 0 {
rtpListener.close()
rtcpListener.close()
return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server") return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server")
} }
} }
@@ -623,14 +627,18 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track,
if proto == StreamProtocolUDP { if proto == StreamProtocolUDP {
rtpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone rtpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone
rtpListener.remotePort = (*thRes.ServerPorts)[0] if thRes.ServerPorts != nil {
rtpListener.remotePort = (*thRes.ServerPorts)[0]
}
rtpListener.trackID = track.ID rtpListener.trackID = track.ID
rtpListener.streamType = StreamTypeRTP rtpListener.streamType = StreamTypeRTP
c.udpRTPListeners[track.ID] = rtpListener c.udpRTPListeners[track.ID] = rtpListener
rtcpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone rtcpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone
rtcpListener.remotePort = (*thRes.ServerPorts)[1] if thRes.ServerPorts != nil {
rtcpListener.remotePort = (*thRes.ServerPorts)[1]
}
rtcpListener.trackID = track.ID rtcpListener.trackID = track.ID
rtcpListener.streamType = StreamTypeRTCP rtcpListener.streamType = StreamTypeRTCP
c.udpRTCPListeners[track.ID] = rtcpListener c.udpRTCPListeners[track.ID] = rtcpListener