mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-10-27 02:01:39 +08:00
websocket: Don't limit message sizes in the websocket reader (#2193)
* Don't limit message sizes in the websocket reader * Remove fmt.println * Update p2p/transport/websocket/listener.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> * Update p2p/transport/websocket/listener.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
66
p2p/test/websocket/websocket_test.go
Normal file
66
p2p/test/websocket/websocket_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package websocket_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p"
|
||||||
|
"github.com/libp2p/go-libp2p/core/network"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReadLimit(t *testing.T) {
|
||||||
|
h1, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer h1.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
h2, err := libp2p.New(libp2p.NoListenAddrs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer h2.Close()
|
||||||
|
|
||||||
|
err = h2.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 256<<10)
|
||||||
|
buf2 := make([]byte, 256<<10)
|
||||||
|
copyBuf := make([]byte, 8<<10)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
// TODO perf would be perfect here, but not yet merged.
|
||||||
|
h1.SetStreamHandler("/big-blocks", func(s network.Stream) {
|
||||||
|
defer s.Close()
|
||||||
|
_, err := io.CopyBuffer(io.Discard, s, copyBuf)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = s.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errCh <- nil
|
||||||
|
})
|
||||||
|
|
||||||
|
allocs := testing.AllocsPerRun(100, func() {
|
||||||
|
s, err := h2.NewStream(ctx, h1.ID(), "/big-blocks")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
_, err = s.Write(buf2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, s.CloseWrite())
|
||||||
|
|
||||||
|
_, err = io.ReadFull(s, buf2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.Read([]byte{0})
|
||||||
|
require.ErrorIs(t, err, io.EOF)
|
||||||
|
require.NoError(t, <-errCh)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Make sure we aren't doing some crazy allocs when transferring big blocks
|
||||||
|
require.Less(t, allocs, 8*1024.0)
|
||||||
|
}
|
||||||
@@ -31,7 +31,6 @@ func (c conn) Read(b []byte) (int, error) {
|
|||||||
if err == nil && n == 0 && c.readAttempts < maxReadAttempts {
|
if err == nil && n == 0 && c.readAttempts < maxReadAttempts {
|
||||||
c.readAttempts++
|
c.readAttempts++
|
||||||
// Nothing happened, let's read again. We reached the end of the frame
|
// Nothing happened, let's read again. We reached the end of the frame
|
||||||
// we have
|
|
||||||
// (https://github.com/nhooyr/websocket/blob/master/netconn.go#L118).
|
// (https://github.com/nhooyr/websocket/blob/master/netconn.go#L118).
|
||||||
// The next read will block until we get
|
// The next read will block until we get
|
||||||
// the next frame. We limit here to avoid looping in case of a bunch of
|
// the next frame. We limit here to avoid looping in case of a bunch of
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -109,6 +110,10 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set an arbitrarily large read limit since we don't actually want to limit the message size here.
|
||||||
|
// See https://github.com/nhooyr/websocket/issues/382 for details.
|
||||||
|
c.SetReadLimit(math.MaxInt64 - 1) // -1 because the library adds a byte for the fin frame
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case l.incoming <- conn{
|
case l.incoming <- conn{
|
||||||
Conn: ws.NetConn(context.Background(), c, ws.MessageBinary),
|
Conn: ws.NetConn(context.Background(), c, ws.MessageBinary),
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -250,6 +251,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
|
|||||||
return nil, fmt.Errorf("failed to get local address")
|
return nil, fmt.Errorf("failed to get local address")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set an arbitrarily large read limit since we don't actually want to limit the message size here.
|
||||||
|
wscon.SetReadLimit(math.MaxInt64 - 1) // -1 because the library adds a byte for the fin frame
|
||||||
mnc, err := manet.WrapNetConn(
|
mnc, err := manet.WrapNetConn(
|
||||||
conn{
|
conn{
|
||||||
Conn: ws.NetConn(context.Background(), wscon, ws.MessageBinary),
|
Conn: ws.NetConn(context.Background(), wscon, ws.MessageBinary),
|
||||||
|
|||||||
Reference in New Issue
Block a user