diff --git a/tcp_mux_multi_test.go b/tcp_mux_multi_test.go index a73252c..7fb0a09 100644 --- a/tcp_mux_multi_test.go +++ b/tcp_mux_multi_test.go @@ -7,6 +7,7 @@ package ice import ( + "errors" "io" "net" "testing" @@ -126,3 +127,142 @@ func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } + +func TestMultiTCPMux_GetConnByUfrag_NoMuxes(t *testing.T) { + multi := NewMultiTCPMuxDefault() // no muxes + + pc, err := multi.GetConnByUfrag("ufrag", false, net.IP{127, 0, 0, 1}) + require.Nil(t, pc) + require.ErrorIs(t, err, errNoTCPMuxAvailable) +} + +func TestMultiTCPMux_GetConnByUfrag_FromAnyMux(t *testing.T) { + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + l1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + defer func() { + _ = l1.Close() + }() + + mux1 := NewTCPMuxDefault(TCPMuxParams{ + Listener: l1, + Logger: logger, + ReadBufferSize: 8, + }) + + l2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + defer func() { + _ = l2.Close() + }() + + mux2 := NewTCPMuxDefault(TCPMuxParams{ + Listener: l2, + Logger: logger, + ReadBufferSize: 8, + }) + + multi := NewMultiTCPMuxDefault(mux1, mux2) + defer func() { _ = multi.Close() }() + + pc, err := multi.GetConnByUfrag("myufrag", false, net.IP{127, 0, 0, 1}) + require.NoError(t, err) + require.NotNil(t, pc) + + pcAddr, ok := pc.LocalAddr().(*net.TCPAddr) + require.True(t, ok, "packet conn addr should be *net.TCPAddr") + + m1Addr, ok := mux1.LocalAddr().(*net.TCPAddr) + require.True(t, ok, "mux1 local addr should be *net.TCPAddr") + + m2Addr, ok := mux2.LocalAddr().(*net.TCPAddr) + require.True(t, ok, "mux2 local addr should be *net.TCPAddr") + + isFromMux1 := pcAddr.Port == m1Addr.Port && pcAddr.IP.Equal(m1Addr.IP) + isFromMux2 := pcAddr.Port == m2Addr.Port && pcAddr.IP.Equal(m2Addr.IP) + require.True(t, isFromMux1 || isFromMux2, "conn must come from one of the underlying muxes") +} + +func TestMultiTCPMux_GetAllConns_NoMuxes(t *testing.T) { + multi := NewMultiTCPMuxDefault() // no underlying TCPMux instances + + conns, err := multi.GetAllConns("ufrag", false, net.IP{127, 0, 0, 1}) + + require.Nil(t, conns) + require.ErrorIs(t, err, errNoTCPMuxAvailable) +} + +var ( + errTCPMuxCloseBoom = errors.New("tcp mux close boom") + errTCPMuxCloseFirst = errors.New("first tcp mux close failed") + errTCPMuxCloseSecond = errors.New("second tcp mux close failed") +) + +type closeErrTCPMux struct { + TCPMux + ret error +} + +func (w *closeErrTCPMux) Close() error { + _ = w.TCPMux.Close() + + return w.ret +} + +func TestMultiTCPMux_Close_PropagatesError_FromWrappedMux(t *testing.T) { + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + // first mux: normal close (nil) + l1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + mux1 := NewTCPMuxDefault(TCPMuxParams{ + Listener: l1, + Logger: logger, + ReadBufferSize: 8, + }) + + // second mux: Close() returns injected error + l2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + mux2Real := NewTCPMuxDefault(TCPMuxParams{ + Listener: l2, + Logger: logger, + ReadBufferSize: 8, + }) + mux2 := &closeErrTCPMux{TCPMux: mux2Real, ret: errTCPMuxCloseBoom} + + multi := NewMultiTCPMuxDefault(mux1, mux2) + + got := multi.Close() + require.ErrorIs(t, got, errTCPMuxCloseBoom) +} + +func TestMultiTCPMux_Close_LastErrorWins_FromWrappedMuxes(t *testing.T) { + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + // first mux: error1 + la, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + mux1Real := NewTCPMuxDefault(TCPMuxParams{ + Listener: la, + Logger: logger, + ReadBufferSize: 8, + }) + mux1 := &closeErrTCPMux{TCPMux: mux1Real, ret: errTCPMuxCloseFirst} + + // second mux: error2 (last error should be returned) + lb, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}, Port: 0}) + require.NoError(t, err) + mux2Real := NewTCPMuxDefault(TCPMuxParams{ + Listener: lb, + Logger: logger, + ReadBufferSize: 8, + }) + mux2 := &closeErrTCPMux{TCPMux: mux2Real, ret: errTCPMuxCloseSecond} + + multi := NewMultiTCPMuxDefault(mux1, mux2) + + got := multi.Close() + require.ErrorIs(t, got, errTCPMuxCloseSecond) +}