diff --git a/client.go b/client.go index 0879322d..5ccf445b 100644 --- a/client.go +++ b/client.go @@ -889,7 +889,8 @@ func (c *Client) startTransportRoutines() { c.tcpBuffer = make([]byte, c.MaxPacketSize+4) } - if c.state == clientStatePlay { + // always enable keepalives unless we are recording with TCP + if c.state == clientStatePlay || *c.effectiveTransport != TransportTCP { c.keepAliveTimer = time.NewTimer(c.keepAlivePeriod) } diff --git a/client_play_test.go b/client_play_test.go index 28864b68..97a1ff41 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -432,7 +432,8 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err2) for i := 0; i < 2; i++ { - // server -> client (RTP) + // server -> client RTP packet + switch transport { case "udp": _, err2 = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ @@ -456,7 +457,8 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err2) } - // client -> server (RTCP) + // client -> server RTCP packet + switch transport { case "udp", "multicast": // skip firewall opening @@ -2874,7 +2876,6 @@ func TestClientPlayDifferentSource(t *testing.T) { }) require.NoError(t, err2) - // server -> client (RTP) _, err2 = l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], diff --git a/client_record_test.go b/client_record_test.go index e5408ac7..24e9023d 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -235,6 +235,10 @@ func TestClientRecord(t *testing.T) { StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), + "Session": headers.Session{ + Session: "ABCDE", + Timeout: uintPtr(1), + }.Marshal(), }, }) require.NoError(t, err2) @@ -249,30 +253,49 @@ func TestClientRecord(t *testing.T) { }) require.NoError(t, err2) - // client -> server (RTP) + var pl []byte + + // client -> server RTP packet + if transport == "udp" { buf := make([]byte, 2048) var n int n, _, err2 = l1.ReadFrom(buf) require.NoError(t, err2) - - var pkt rtp.Packet - err2 = pkt.Unmarshal(buf[:n]) - require.NoError(t, err2) - require.Equal(t, testRTPPacket, pkt) + pl = buf[:n] } else { var f *base.InterleavedFrame f, err2 = conn.ReadInterleavedFrame() require.NoError(t, err2) require.Equal(t, 0, f.Channel) - - var pkt rtp.Packet - err2 = pkt.Unmarshal(f.Payload) - require.NoError(t, err2) - require.Equal(t, testRTPPacket, pkt) + pl = f.Payload } - // server -> client (RTCP) + var pkt rtp.Packet + err2 = pkt.Unmarshal(pl) + require.NoError(t, err2) + require.Equal(t, testRTPPacket, pkt) + + // client -> server keepalive (UDP only) + + if transport == "udp" { + recv := make(chan struct{}) + go func() { + defer close(recv) + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + }() + + select { + case <-recv: + case <-time.After(2 * time.Second): + t.Errorf("should not happen") + } + } + + // server -> client RTCP packet + if transport == "udp" { _, err2 = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), diff --git a/server_play_test.go b/server_play_test.go index e87ba4d0..7a3cf7bb 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -848,6 +848,7 @@ func TestServerPlay(t *testing.T) { doPlay(t, conn, "rtsp://"+listenIP+":8554/teststream", session) // server -> client (direct) + switch transport { case "udp": buf := make([]byte, 2048) @@ -874,6 +875,7 @@ func TestServerPlay(t *testing.T) { } // server -> client (through stream) + if transport == "udp" || transport == "multicast" { buf := make([]byte, 2048) var n int @@ -904,7 +906,8 @@ func TestServerPlay(t *testing.T) { } } - // client -> server (RTCP) + // client -> server RTCP packet + switch transport { case "udp": _, err = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ diff --git a/server_record_test.go b/server_record_test.go index 98885171..2ab9b89f 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -728,6 +728,7 @@ func TestServerRecord(t *testing.T) { } // server -> client (direct) + if transport == "udp" { buf := make([]byte, 2048) var n int @@ -742,7 +743,8 @@ func TestServerRecord(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, f.Payload) } - // client -> server + // client -> server RTP+RTCP packets + if transport == "udp" { _, err = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -771,7 +773,8 @@ func TestServerRecord(t *testing.T) { } for i := 0; i < 2; i++ { - // server -> client (RTCP) + // server -> client RTCP packet + if transport == "udp" { buf := make([]byte, 2048) n, _, err := l2s[i].ReadFrom(buf)