diff --git a/active_tcp.go b/active_tcp.go index 2fd466b..66c330f 100644 --- a/active_tcp.go +++ b/active_tcp.go @@ -18,6 +18,7 @@ import ( type activeTCPConn struct { readBuffer, writeBuffer *packetio.Buffer localAddr, remoteAddr atomic.Value + conn atomic.Value // stores net.Conn closed atomic.Bool } @@ -55,6 +56,7 @@ func newActiveTCPConn( return } + a.conn.Store(conn) a.remoteAddr.Store(conn.RemoteAddr()) go func() { @@ -125,6 +127,9 @@ func (a *activeTCPConn) Close() error { a.closed.Store(true) _ = a.readBuffer.Close() _ = a.writeBuffer.Close() + if c, ok := a.conn.Load().(net.Conn); ok { + _ = c.Close() + } return nil } @@ -150,9 +155,38 @@ func (a *activeTCPConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } -func (a *activeTCPConn) SetDeadline(time.Time) error { return io.EOF } -func (a *activeTCPConn) SetReadDeadline(time.Time) error { return io.EOF } -func (a *activeTCPConn) SetWriteDeadline(time.Time) error { return io.EOF } +func (a *activeTCPConn) SetDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetDeadline(t) + } + + return io.EOF +} + +func (a *activeTCPConn) SetReadDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetReadDeadline(t) + } + + return io.EOF +} + +func (a *activeTCPConn) SetWriteDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetWriteDeadline(t) + } + + return io.EOF +} func getTCPAddrOnInterface(address string) (*net.TCPAddr, error) { addr, err := net.ResolveTCPAddr("tcp", address) diff --git a/active_tcp_test.go b/active_tcp_test.go index 63665d5..cb74f10 100644 --- a/active_tcp_test.go +++ b/active_tcp_test.go @@ -441,3 +441,56 @@ func TestActiveTCPConn_SetDeadlines_ReturnEOF(t *testing.T) { err = a.SetWriteDeadline(time.Now()) require.ErrorIs(t, err, io.EOF) } + +func TestActiveTCPConn_SetDeadlines_WhenConnected(t *testing.T) { + defer test.CheckRoutines(t)() + + ln, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx + if err != nil { + t.Skipf("tcp listen not permitted in this environment: %v", err) + } + defer func() { _ = ln.Close() }() + + remote := netip.MustParseAddrPort(ln.Addr().String()) + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + active := newActiveTCPConn(ctx, "127.0.0.1:0", remote, logger) + require.NotNil(t, active) + + acceptCh := make(chan net.Conn, 1) + go func() { + conn, acceptErr := ln.Accept() + if acceptErr == nil { + acceptCh <- conn + } + }() + + require.Eventually(t, func() bool { + return active.conn.Load() != nil || active.closed.Load() + }, 2*time.Second, 10*time.Millisecond) + + connVal := active.conn.Load() + if connVal == nil { + t.Skip("tcp dial not permitted in this environment") + } + clientConn, ok := connVal.(net.Conn) + require.True(t, ok) + + readDeadline := time.Now().Add(50 * time.Millisecond) + writeDeadline := readDeadline.Add(50 * time.Millisecond) + allDeadline := writeDeadline.Add(50 * time.Millisecond) + + require.NoError(t, active.SetReadDeadline(readDeadline)) + require.NoError(t, active.SetWriteDeadline(writeDeadline)) + require.NoError(t, active.SetDeadline(allDeadline)) + + _ = active.Close() + _ = clientConn.Close() + select { + case srvConn := <-acceptCh: + _ = srvConn.Close() + default: + } +} diff --git a/candidate_base.go b/candidate_base.go index 40aeab5..a0a493b 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -87,6 +87,16 @@ func (c *candidateBase) Value(any) any { return nil } +// setWriteDeadline is used by upper layers to push write deadlines down to the +// underlying packet connection. +func (c *candidateBase) setWriteDeadline(t time.Time) error { + if c.conn == nil { + return nil + } + + return c.conn.SetWriteDeadline(t) +} + // ID returns Candidate ID. func (c *candidateBase) ID() string { return c.id diff --git a/tcp_packet_conn.go b/tcp_packet_conn.go index 7540783..29e8033 100644 --- a/tcp_packet_conn.go +++ b/tcp_packet_conn.go @@ -332,16 +332,46 @@ func (t *tcpPacketConn) LocalAddr() net.Addr { return t.params.LocalAddr } -func (t *tcpPacketConn) SetDeadline(time.Time) error { - return nil +func (t *tcpPacketConn) SetDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err } -func (t *tcpPacketConn) SetReadDeadline(time.Time) error { - return nil +func (t *tcpPacketConn) SetReadDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetReadDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err } -func (t *tcpPacketConn) SetWriteDeadline(time.Time) error { - return nil +func (t *tcpPacketConn) SetWriteDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetWriteDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err } func (t *tcpPacketConn) CloseChannel() <-chan struct{} { diff --git a/tcp_packet_conn_test.go b/tcp_packet_conn_test.go index 6c1761a..ec879e7 100644 --- a/tcp_packet_conn_test.go +++ b/tcp_packet_conn_test.go @@ -20,6 +20,37 @@ import ( "github.com/stretchr/testify/require" ) +type deadlineConn struct { + readDeadline time.Time + writeDeadline time.Time + deadline time.Time + lAddr net.Addr + rAddr net.Addr +} + +func (d *deadlineConn) Read([]byte) (int, error) { return 0, io.EOF } +func (d *deadlineConn) Write([]byte) (int, error) { return 0, io.EOF } +func (d *deadlineConn) Close() error { return nil } +func (d *deadlineConn) LocalAddr() net.Addr { return d.lAddr } +func (d *deadlineConn) RemoteAddr() net.Addr { return d.rAddr } +func (d *deadlineConn) SetDeadline(t time.Time) error { + d.deadline = t + + return nil +} + +func (d *deadlineConn) SetReadDeadline(t time.Time) error { + d.readDeadline = t + + return nil +} + +func (d *deadlineConn) SetWriteDeadline(t time.Time) error { + d.writeDeadline = t + + return nil +} + func TestBufferedConn_Write_ErrorAfterClose(t *testing.T) { defer test.CheckRoutines(t)() @@ -216,6 +247,7 @@ func TestTCPPacketConn_WriteTo_ErrorBranch_WithProvidedMock(t *testing.T) { func TestTCPPacketConn_SetDeadlines(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("ice") addr := &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 12345} + remoteAddr := &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 23456} tpc := newTCPPacketConn(tcpPacketParams{ ReadBuffer: 8, @@ -224,8 +256,22 @@ func TestTCPPacketConn_SetDeadlines(t *testing.T) { WriteBuffer: 0, AliveDuration: 0, }) - require.NoError(t, tpc.SetReadDeadline(time.Now().Add(200*time.Millisecond))) - require.NoError(t, tpc.SetWriteDeadline(time.Now().Add(200*time.Millisecond))) + observer := &deadlineConn{lAddr: addr, rAddr: remoteAddr} + tpc.mu.Lock() + tpc.conns[observer.RemoteAddr().String()] = observer + tpc.mu.Unlock() + + readDeadline := time.Now().Add(200 * time.Millisecond) + writeDeadline := readDeadline.Add(200 * time.Millisecond) + combinedDeadline := writeDeadline.Add(200 * time.Millisecond) + + require.NoError(t, tpc.SetReadDeadline(readDeadline)) + require.NoError(t, tpc.SetWriteDeadline(writeDeadline)) + require.NoError(t, tpc.SetDeadline(combinedDeadline)) + + require.Equal(t, readDeadline, observer.readDeadline) + require.Equal(t, writeDeadline, observer.writeDeadline) + require.Equal(t, combinedDeadline, observer.deadline) require.NoError(t, tpc.Close()) require.NoError(t, tpc.SetReadDeadline(time.Now().Add(200*time.Millisecond))) diff --git a/transport.go b/transport.go index 31e1286..5425a0f 100644 --- a/transport.go +++ b/transport.go @@ -139,17 +139,33 @@ func (c *Conn) RemoteAddr() net.Addr { return pair.Remote.addr() } -// SetDeadline is a stub. -func (c *Conn) SetDeadline(time.Time) error { - return nil +// SetDeadline sets both read and write deadlines on the underlying ICE connection. +func (c *Conn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + + return c.SetWriteDeadline(t) } -// SetReadDeadline is a stub. -func (c *Conn) SetReadDeadline(time.Time) error { - return nil +// SetReadDeadline sets the read deadline on the packet buffer used for application data. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.agent.buf.SetReadDeadline(t) } -// SetWriteDeadline is a stub. -func (c *Conn) SetWriteDeadline(time.Time) error { +// SetWriteDeadline sets the write deadline on the currently selected local candidate connection. +// The deadline applies to the selected candidate pair and will affect all traffic over that pair. +func (c *Conn) SetWriteDeadline(t time.Time) error { + pair := c.agent.getSelectedPair() + if pair == nil || pair.Local == nil { + return nil + } + + if d, ok := pair.Local.(interface { + setWriteDeadline(time.Time) error + }); ok { + return d.setWriteDeadline(t) + } + return nil } diff --git a/transport_test.go b/transport_test.go index b16f54b..e14eea1 100644 --- a/transport_test.go +++ b/transport_test.go @@ -14,11 +14,51 @@ import ( "testing" "time" + "github.com/pion/ice/v4/internal/taskloop" "github.com/pion/stun/v3" + "github.com/pion/transport/v3/packetio" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) +type deadlineCandidate struct { + candidateBase +} + +type deadlinePacketConn struct { + writeDeadline time.Time +} + +func (d *deadlinePacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { + return 0, nil, nil +} + +func (d *deadlinePacketConn) WriteTo([]byte, net.Addr) (n int, err error) { + return 0, nil +} + +func (d *deadlinePacketConn) Close() error { + return nil +} + +func (d *deadlinePacketConn) LocalAddr() net.Addr { + return nil +} + +func (d *deadlinePacketConn) SetDeadline(time.Time) error { + return nil +} + +func (d *deadlinePacketConn) SetReadDeadline(time.Time) error { + return nil +} + +func (d *deadlinePacketConn) SetWriteDeadline(t time.Time) error { + d.writeDeadline = t + + return nil +} + func TestStressDuplex(t *testing.T) { // Check for leaking routines defer test.CheckRoutines(t)() @@ -102,6 +142,38 @@ func TestReadClosed(t *testing.T) { require.Error(t, err) } +func TestConnDeadlines(t *testing.T) { + defer test.CheckRoutines(t)() + + loop := taskloop.New(func() {}) + defer loop.Close() + + buf := packetio.NewBuffer() + pc := &deadlinePacketConn{} + candidate := &deadlineCandidate{} + candidate.conn = pc + + agent := &Agent{ + buf: buf, + loop: loop, + } + agent.selectedPair.Store(&CandidatePair{Local: candidate}) + + conn := &Conn{agent: agent} + + writeDeadline := time.Now().Add(100 * time.Millisecond) + require.NoError(t, conn.SetWriteDeadline(writeDeadline)) + require.WithinDuration(t, writeDeadline, pc.writeDeadline, time.Millisecond) + + readDeadline := time.Now().Add(-1 * time.Millisecond) + require.NoError(t, conn.SetDeadline(readDeadline)) + + _, err := conn.Read(make([]byte, 1)) + var netErr interface{ Timeout() bool } + require.ErrorAs(t, err, &netErr) + require.True(t, netErr.Timeout()) +} + func stressDuplex(t *testing.T) { t.Helper()