mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00
Implement error codes spec (#2927)
This commit is contained in:
@@ -2,15 +2,60 @@ package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
ic "github.com/libp2p/go-libp2p/core/crypto"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/protocol"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
type ConnErrorCode uint32
|
||||
|
||||
type ConnError struct {
|
||||
Remote bool
|
||||
ErrorCode ConnErrorCode
|
||||
TransportError error
|
||||
}
|
||||
|
||||
func (c *ConnError) Error() string {
|
||||
side := "local"
|
||||
if c.Remote {
|
||||
side = "remote"
|
||||
}
|
||||
if c.TransportError != nil {
|
||||
return fmt.Sprintf("connection closed (%s): code: 0x%x: transport error: %s", side, c.ErrorCode, c.TransportError)
|
||||
}
|
||||
return fmt.Sprintf("connection closed (%s): code: 0x%x", side, c.ErrorCode)
|
||||
}
|
||||
|
||||
func (c *ConnError) Is(target error) bool {
|
||||
if tce, ok := target.(*ConnError); ok {
|
||||
return tce.ErrorCode == c.ErrorCode && tce.Remote == c.Remote
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ConnError) Unwrap() []error {
|
||||
return []error{ErrReset, c.TransportError}
|
||||
}
|
||||
|
||||
const (
|
||||
ConnNoError ConnErrorCode = 0
|
||||
ConnProtocolNegotiationFailed ConnErrorCode = 0x1000
|
||||
ConnResourceLimitExceeded ConnErrorCode = 0x1001
|
||||
ConnRateLimited ConnErrorCode = 0x1002
|
||||
ConnProtocolViolation ConnErrorCode = 0x1003
|
||||
ConnSupplanted ConnErrorCode = 0x1004
|
||||
ConnGarbageCollected ConnErrorCode = 0x1005
|
||||
ConnShutdown ConnErrorCode = 0x1006
|
||||
ConnGated ConnErrorCode = 0x1007
|
||||
ConnCodeOutOfRange ConnErrorCode = 0x1008
|
||||
)
|
||||
|
||||
// Conn is a connection to a remote peer. It multiplexes streams.
|
||||
// Usually there is no need to use a Conn directly, but it may
|
||||
// be useful to get information about the peer on the other side:
|
||||
@@ -24,6 +69,11 @@ type Conn interface {
|
||||
ConnStat
|
||||
ConnScoper
|
||||
|
||||
// CloseWithError closes the connection with errCode. The errCode is sent to the
|
||||
// peer on a best effort basis. For transports that do not support sending error
|
||||
// codes on connection close, the behavior is identical to calling Close.
|
||||
CloseWithError(errCode ConnErrorCode) error
|
||||
|
||||
// ID returns an identifier that uniquely identifies this Conn within this
|
||||
// host, during this run. Connection IDs may repeat across restarts.
|
||||
ID() string
|
||||
|
@@ -3,6 +3,7 @@ package network
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
@@ -11,6 +12,49 @@ import (
|
||||
// ErrReset is returned when reading or writing on a reset stream.
|
||||
var ErrReset = errors.New("stream reset")
|
||||
|
||||
type StreamErrorCode uint32
|
||||
|
||||
type StreamError struct {
|
||||
ErrorCode StreamErrorCode
|
||||
Remote bool
|
||||
TransportError error
|
||||
}
|
||||
|
||||
func (s *StreamError) Error() string {
|
||||
side := "local"
|
||||
if s.Remote {
|
||||
side = "remote"
|
||||
}
|
||||
if s.TransportError != nil {
|
||||
return fmt.Sprintf("stream reset (%s): code: 0x%x: transport error: %s", side, s.ErrorCode, s.TransportError)
|
||||
}
|
||||
return fmt.Sprintf("stream reset (%s): code: 0x%x", side, s.ErrorCode)
|
||||
}
|
||||
|
||||
func (s *StreamError) Is(target error) bool {
|
||||
if tse, ok := target.(*StreamError); ok {
|
||||
return tse.ErrorCode == s.ErrorCode && tse.Remote == s.Remote
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *StreamError) Unwrap() []error {
|
||||
return []error{ErrReset, s.TransportError}
|
||||
}
|
||||
|
||||
const (
|
||||
StreamNoError StreamErrorCode = 0
|
||||
StreamProtocolNegotiationFailed StreamErrorCode = 0x1001
|
||||
StreamResourceLimitExceeded StreamErrorCode = 0x1002
|
||||
StreamRateLimited StreamErrorCode = 0x1003
|
||||
StreamProtocolViolation StreamErrorCode = 0x1004
|
||||
StreamSupplanted StreamErrorCode = 0x1005
|
||||
StreamGarbageCollected StreamErrorCode = 0x1006
|
||||
StreamShutdown StreamErrorCode = 0x1007
|
||||
StreamGated StreamErrorCode = 0x1008
|
||||
StreamCodeOutOfRange StreamErrorCode = 0x1009
|
||||
)
|
||||
|
||||
// MuxedStream is a bidirectional io pipe within a connection.
|
||||
type MuxedStream interface {
|
||||
io.Reader
|
||||
@@ -56,6 +100,11 @@ type MuxedStream interface {
|
||||
// side to hang up and go away.
|
||||
Reset() error
|
||||
|
||||
// ResetWithError aborts both ends of the stream with `errCode`. `errCode` is sent
|
||||
// to the peer on a best effort basis. For transports that do not support sending
|
||||
// error codes to remote peer, the behavior is identical to calling Reset
|
||||
ResetWithError(errCode StreamErrorCode) error
|
||||
|
||||
SetDeadline(time.Time) error
|
||||
SetReadDeadline(time.Time) error
|
||||
SetWriteDeadline(time.Time) error
|
||||
@@ -75,6 +124,10 @@ type MuxedConn interface {
|
||||
// Close closes the stream muxer and the the underlying net.Conn.
|
||||
io.Closer
|
||||
|
||||
// CloseWithError closes the connection with errCode. The errCode is sent
|
||||
// to the peer.
|
||||
CloseWithError(errCode ConnErrorCode) error
|
||||
|
||||
// IsClosed returns whether a connection is fully closed, so it can
|
||||
// be garbage collected.
|
||||
IsClosed() bool
|
||||
|
@@ -27,4 +27,8 @@ type Stream interface {
|
||||
|
||||
// Scope returns the user's view of this stream's resource scope
|
||||
Scope() StreamScope
|
||||
|
||||
// ResetWithError closes both ends of the stream with errCode. The errCode is sent
|
||||
// to the peer.
|
||||
ResetWithError(errCode StreamErrorCode) error
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@@ -30,7 +30,7 @@ require (
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.2.2
|
||||
github.com/libp2p/go-reuseport v0.4.0
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0
|
||||
github.com/libp2p/zeroconf/v2 v2.2.0
|
||||
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd
|
||||
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b
|
||||
|
4
go.sum
4
go.sum
@@ -193,8 +193,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP
|
||||
github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE=
|
||||
github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s=
|
||||
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU=
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o=
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po=
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU=
|
||||
github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q=
|
||||
github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs=
|
||||
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
||||
|
@@ -464,7 +464,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
|
||||
} else {
|
||||
log.Debugf("protocol mux failed: %s (took %s, id:%s, remote peer:%s, remote addr:%v)", err, took, s.ID(), s.Conn().RemotePeer(), s.Conn().RemoteMultiaddr())
|
||||
}
|
||||
s.Reset()
|
||||
s.ResetWithError(network.StreamProtocolNegotiationFailed)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -478,7 +478,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
|
||||
|
||||
if err := s.SetProtocol(protoID); err != nil {
|
||||
log.Debugf("error setting stream protocol: %s", err)
|
||||
s.Reset()
|
||||
s.ResetWithError(network.StreamResourceLimitExceeded)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -717,7 +717,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
|
||||
}
|
||||
defer func() {
|
||||
if strErr != nil && s != nil {
|
||||
s.Reset()
|
||||
s.ResetWithError(network.StreamProtocolNegotiationFailed)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -761,13 +761,14 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
|
||||
return nil, fmt.Errorf("failed to negotiate protocol: %w", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
s.Reset()
|
||||
s.ResetWithError(network.StreamProtocolNegotiationFailed)
|
||||
// wait for `SelectOneOf` to error out because of resetting the stream.
|
||||
<-errCh
|
||||
return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err())
|
||||
}
|
||||
|
||||
if err := s.SetProtocol(selected); err != nil {
|
||||
s.ResetWithError(network.StreamResourceLimitExceeded)
|
||||
return nil, err
|
||||
}
|
||||
_ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
@@ -462,7 +463,7 @@ func SubtestStreamReset(t *testing.T, tr network.Multiplexer) {
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
_, err = s.Write([]byte("foo"))
|
||||
if err != network.ErrReset {
|
||||
if !errors.Is(err, network.ErrReset) {
|
||||
t.Error("should have been stream reset")
|
||||
}
|
||||
s.Close()
|
||||
|
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
|
||||
"github.com/libp2p/go-yamux/v4"
|
||||
"github.com/libp2p/go-yamux/v5"
|
||||
)
|
||||
|
||||
// conn implements mux.MuxedConn over yamux.Session.
|
||||
@@ -23,6 +23,10 @@ func (c *conn) Close() error {
|
||||
return c.yamux().Close()
|
||||
}
|
||||
|
||||
func (c *conn) CloseWithError(errCode network.ConnErrorCode) error {
|
||||
return c.yamux().CloseWithError(uint32(errCode))
|
||||
}
|
||||
|
||||
// IsClosed checks if yamux.Session is in closed state.
|
||||
func (c *conn) IsClosed() bool {
|
||||
return c.yamux().IsClosed()
|
||||
@@ -32,7 +36,7 @@ func (c *conn) IsClosed() bool {
|
||||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
|
||||
s, err := c.yamux().OpenStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, parseError(err)
|
||||
}
|
||||
|
||||
return (*stream)(s), nil
|
||||
@@ -41,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
|
||||
// AcceptStream accepts a stream opened by the other side.
|
||||
func (c *conn) AcceptStream() (network.MuxedStream, error) {
|
||||
s, err := c.yamux().AcceptStream()
|
||||
return (*stream)(s), err
|
||||
return (*stream)(s), parseError(err)
|
||||
}
|
||||
|
||||
func (c *conn) yamux() *yamux.Session {
|
||||
|
@@ -1,11 +1,13 @@
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
|
||||
"github.com/libp2p/go-yamux/v4"
|
||||
"github.com/libp2p/go-yamux/v5"
|
||||
)
|
||||
|
||||
// stream implements mux.MuxedStream over yamux.Stream.
|
||||
@@ -13,22 +15,32 @@ type stream yamux.Stream
|
||||
|
||||
var _ network.MuxedStream = &stream{}
|
||||
|
||||
func parseError(err error) error {
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
se := &yamux.StreamError{}
|
||||
if errors.As(err, &se) {
|
||||
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode), TransportError: err}
|
||||
}
|
||||
ce := &yamux.GoAwayError{}
|
||||
if errors.As(err, &ce) {
|
||||
return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode), TransportError: err}
|
||||
}
|
||||
if errors.Is(err, yamux.ErrStreamReset) {
|
||||
return fmt.Errorf("%w: %w", network.ErrReset, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (n int, err error) {
|
||||
n, err = s.yamux().Read(b)
|
||||
if err == yamux.ErrStreamReset {
|
||||
err = network.ErrReset
|
||||
}
|
||||
|
||||
return n, err
|
||||
return n, parseError(err)
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (n int, err error) {
|
||||
n, err = s.yamux().Write(b)
|
||||
if err == yamux.ErrStreamReset {
|
||||
err = network.ErrReset
|
||||
}
|
||||
|
||||
return n, err
|
||||
return n, parseError(err)
|
||||
}
|
||||
|
||||
func (s *stream) Close() error {
|
||||
@@ -39,6 +51,10 @@ func (s *stream) Reset() error {
|
||||
return s.yamux().Reset()
|
||||
}
|
||||
|
||||
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
|
||||
return s.yamux().ResetWithError(uint32(errCode))
|
||||
}
|
||||
|
||||
func (s *stream) CloseRead() error {
|
||||
return s.yamux().CloseRead()
|
||||
}
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
|
||||
"github.com/libp2p/go-yamux/v4"
|
||||
"github.com/libp2p/go-yamux/v5"
|
||||
)
|
||||
|
||||
var DefaultTransport *Transport
|
||||
|
@@ -175,7 +175,8 @@ func (cm *BasicConnMgr) memoryEmergency() {
|
||||
// Trim connections without paying attention to the silence period.
|
||||
for _, c := range cm.getConnsToCloseEmergency(target) {
|
||||
log.Infow("low on memory. closing conn", "peer", c.RemotePeer())
|
||||
c.Close()
|
||||
|
||||
c.CloseWithError(network.ConnGarbageCollected)
|
||||
}
|
||||
|
||||
// finally, update the last trim time.
|
||||
@@ -388,7 +389,7 @@ func (cm *BasicConnMgr) trim() {
|
||||
// do the actual trim.
|
||||
for _, c := range cm.getConnsToClose() {
|
||||
log.Debugw("closing conn", "peer", c.RemotePeer())
|
||||
c.Close()
|
||||
c.CloseWithError(network.ConnGarbageCollected)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -11,8 +11,11 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/peerstore"
|
||||
tu "github.com/libp2p/go-libp2p/core/test"
|
||||
|
||||
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -33,6 +36,14 @@ func (c *tconn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tconn) CloseWithError(code network.ConnErrorCode) error {
|
||||
atomic.StoreUint32(&c.closed, 1)
|
||||
if c.disconnectNotify != nil {
|
||||
c.disconnectNotify(nil, c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tconn) isClosed() bool {
|
||||
return atomic.LoadUint32(&c.closed) == 1
|
||||
}
|
||||
@@ -794,6 +805,7 @@ type mockConn struct {
|
||||
}
|
||||
|
||||
func (m mockConn) Close() error { panic("implement me") }
|
||||
func (m mockConn) CloseWithError(errCode network.ConnErrorCode) error { panic("implement me") }
|
||||
func (m mockConn) LocalPeer() peer.ID { panic("implement me") }
|
||||
func (m mockConn) RemotePeer() peer.ID { panic("implement me") }
|
||||
func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") }
|
||||
@@ -986,3 +998,79 @@ type testLimitGetter struct {
|
||||
func (g testLimitGetter) GetConnLimit() int {
|
||||
return g.limit
|
||||
}
|
||||
|
||||
func TestErrorCode(t *testing.T) {
|
||||
sw1, sw2, sw3 := swarmt.GenSwarm(t), swarmt.GenSwarm(t), swarmt.GenSwarm(t)
|
||||
defer sw1.Close()
|
||||
defer sw2.Close()
|
||||
defer sw3.Close()
|
||||
|
||||
cm, err := NewConnManager(1, 1, WithGracePeriod(0), WithSilencePeriod(10))
|
||||
require.NoError(t, err)
|
||||
defer cm.Close()
|
||||
|
||||
sw1.Peerstore().AddAddrs(sw2.LocalPeer(), sw2.ListenAddresses(), peerstore.PermanentAddrTTL)
|
||||
sw1.Peerstore().AddAddrs(sw3.LocalPeer(), sw3.ListenAddresses(), peerstore.PermanentAddrTTL)
|
||||
|
||||
c12, err := sw1.DialPeer(context.Background(), sw2.LocalPeer())
|
||||
require.NoError(t, err)
|
||||
|
||||
var c21 network.Conn
|
||||
require.Eventually(t, func() bool {
|
||||
conns := sw2.ConnsToPeer(sw1.LocalPeer())
|
||||
if len(conns) == 0 {
|
||||
return false
|
||||
}
|
||||
c21 = conns[0]
|
||||
return true
|
||||
}, 10*time.Second, 100*time.Millisecond)
|
||||
|
||||
c13, err := sw1.DialPeer(context.Background(), sw3.LocalPeer())
|
||||
require.NoError(t, err)
|
||||
|
||||
var c31 network.Conn
|
||||
require.Eventually(t, func() bool {
|
||||
conns := sw3.ConnsToPeer(sw1.LocalPeer())
|
||||
if len(conns) == 0 {
|
||||
return false
|
||||
}
|
||||
c31 = conns[0]
|
||||
return true
|
||||
}, 10*time.Second, 100*time.Millisecond)
|
||||
|
||||
not := cm.Notifee()
|
||||
not.Connected(sw1, c12)
|
||||
not.Connected(sw1, c13)
|
||||
|
||||
cm.TrimOpenConns(context.Background())
|
||||
|
||||
require.True(t, c12.IsClosed() || c13.IsClosed())
|
||||
var c, cr network.Conn
|
||||
if c12.IsClosed() {
|
||||
c = c12
|
||||
require.Eventually(t, func() bool {
|
||||
conns := sw2.ConnsToPeer(sw1.LocalPeer())
|
||||
if len(conns) == 0 {
|
||||
cr = c21
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
} else {
|
||||
c = c13
|
||||
require.Eventually(t, func() bool {
|
||||
conns := sw3.ConnsToPeer(sw1.LocalPeer())
|
||||
if len(conns) == 0 {
|
||||
cr = c31
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
_, err = c.NewStream(context.Background())
|
||||
require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: false})
|
||||
|
||||
_, err = cr.NewStream(context.Background())
|
||||
require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: true})
|
||||
}
|
||||
|
@@ -185,3 +185,7 @@ func (c *conn) Stat() network.ConnStats {
|
||||
func (c *conn) Scope() network.ConnScope {
|
||||
return &network.NullScope{}
|
||||
}
|
||||
|
||||
func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
@@ -144,6 +144,24 @@ func (s *stream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetWithError resets the stream. It ignores the provided error code.
|
||||
// TODO: Implement error code support.
|
||||
func (s *stream) ResetWithError(_ network.StreamErrorCode) error {
|
||||
// Cancel any pending reads/writes with an error.
|
||||
|
||||
s.write.CloseWithError(network.ErrReset)
|
||||
s.read.CloseWithError(network.ErrReset)
|
||||
|
||||
select {
|
||||
case s.reset <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-s.closed
|
||||
|
||||
// No meaningful error case here.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) teardown() {
|
||||
// at this point, no streams are writing.
|
||||
s.conn.removeStream(s)
|
||||
|
@@ -385,8 +385,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
|
||||
// If we do this in the Upgrader, we will not be able to do this.
|
||||
if s.gater != nil {
|
||||
if allow, _ := s.gater.InterceptUpgraded(c); !allow {
|
||||
// TODO Send disconnect with reason here
|
||||
err := tc.Close()
|
||||
err := tc.CloseWithError(network.ConnGated)
|
||||
if err != nil {
|
||||
log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err)
|
||||
}
|
||||
@@ -845,6 +844,14 @@ func (c *connWithMetrics) Close() error {
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error {
|
||||
c.once.Do(func() {
|
||||
c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
|
||||
c.closeErr = c.CapableConn.CloseWithError(errCode)
|
||||
})
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *connWithMetrics) Stat() network.ConnStats {
|
||||
if cs, ok := c.CapableConn.(network.ConnStat); ok {
|
||||
return cs.Stat()
|
||||
|
@@ -58,11 +58,20 @@ func (c *Conn) ID() string {
|
||||
// open notifications must finish before we can fire off the close
|
||||
// notifications).
|
||||
func (c *Conn) Close() error {
|
||||
c.closeOnce.Do(c.doClose)
|
||||
c.closeOnce.Do(func() {
|
||||
c.doClose(0)
|
||||
})
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *Conn) doClose() {
|
||||
func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.doClose(errCode)
|
||||
})
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *Conn) doClose(errCode network.ConnErrorCode) {
|
||||
c.swarm.removeConn(c)
|
||||
|
||||
// Prevent new streams from opening.
|
||||
@@ -71,7 +80,11 @@ func (c *Conn) doClose() {
|
||||
c.streams.m = nil
|
||||
c.streams.Unlock()
|
||||
|
||||
c.err = c.conn.Close()
|
||||
if errCode != 0 {
|
||||
c.err = c.conn.CloseWithError(errCode)
|
||||
} else {
|
||||
c.err = c.conn.Close()
|
||||
}
|
||||
|
||||
// Send the connectedness event after closing the connection.
|
||||
// This ensures that both remote connection close and local connection
|
||||
@@ -121,7 +134,7 @@ func (c *Conn) start() {
|
||||
}
|
||||
scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound)
|
||||
if err != nil {
|
||||
ts.Reset()
|
||||
ts.ResetWithError(network.StreamResourceLimitExceeded)
|
||||
continue
|
||||
}
|
||||
c.swarm.refs.Add(1)
|
||||
|
@@ -91,6 +91,12 @@ func (s *Stream) Reset() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error {
|
||||
err := s.stream.ResetWithError(errCode)
|
||||
s.closeAndRemoveStream()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Stream) closeAndRemoveStream() {
|
||||
s.closeMx.Lock()
|
||||
defer s.closeMx.Unlock()
|
||||
|
@@ -538,7 +538,7 @@ func TestResourceManagerAcceptStream(t *testing.T) {
|
||||
if err == nil {
|
||||
_, err = str.Read([]byte{0})
|
||||
}
|
||||
require.EqualError(t, err, "stream reset")
|
||||
require.ErrorContains(t, err, "stream reset")
|
||||
}
|
||||
|
||||
func TestListenCloseCount(t *testing.T) {
|
||||
|
@@ -63,3 +63,8 @@ func (t *transportConn) ConnState() network.ConnectionState {
|
||||
UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error {
|
||||
defer t.scope.Done()
|
||||
return t.MuxedConn.CloseWithError(errCode)
|
||||
}
|
||||
|
@@ -162,7 +162,7 @@ func (l *listener) handleIncoming() {
|
||||
// if we stop accepting connections for some reason,
|
||||
// we'll eventually close all the open ones
|
||||
// instead of hanging onto them.
|
||||
conn.Close()
|
||||
conn.CloseWithError(network.ConnRateLimited)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
@@ -267,12 +268,12 @@ func TestRelayLimitTime(t *testing.T) {
|
||||
if n > 0 {
|
||||
t.Fatalf("expected to write 0 bytes, wrote %d", n)
|
||||
}
|
||||
if err != network.ErrReset {
|
||||
if !errors.Is(err, network.ErrReset) {
|
||||
t.Fatalf("expected reset, but got %s", err)
|
||||
}
|
||||
|
||||
err = <-rch
|
||||
if err != network.ErrReset {
|
||||
if !errors.Is(err, network.ErrReset) {
|
||||
t.Fatalf("expected reset, but got %s", err)
|
||||
}
|
||||
}
|
||||
@@ -300,7 +301,7 @@ func TestRelayLimitData(t *testing.T) {
|
||||
}
|
||||
|
||||
n, err := s.Read(buf)
|
||||
if err != network.ErrReset {
|
||||
if !errors.Is(err, network.ErrReset) {
|
||||
t.Fatalf("expected reset but got %s", err)
|
||||
}
|
||||
rch <- n
|
||||
|
@@ -36,6 +36,7 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -866,3 +867,170 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
assertStreamErrors := func(s network.Stream, expectedError error) {
|
||||
buf := make([]byte, 10)
|
||||
_, err := s.Read(buf)
|
||||
require.ErrorIs(t, err, expectedError)
|
||||
|
||||
_, err = s.Write(buf)
|
||||
require.ErrorIs(t, err, expectedError)
|
||||
}
|
||||
|
||||
for _, tc := range transportsToTest {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
server := tc.HostGenerator(t, TransportTestCaseOpts{})
|
||||
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
|
||||
|
||||
// setup stream handler
|
||||
remoteStreamQ := make(chan network.Stream)
|
||||
server.SetStreamHandler("/test", func(s network.Stream) {
|
||||
b := make([]byte, 10)
|
||||
n, err := s.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = s.Write(b[:n])
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
remoteStreamQ <- s
|
||||
})
|
||||
|
||||
// pingPong writes and reads "hello" on the stream
|
||||
pingPong := func(s network.Stream) {
|
||||
buf := []byte("hello")
|
||||
_, err := s.Write(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, buf, []byte("hello"))
|
||||
}
|
||||
|
||||
t.Run("StreamResetWithError", func(t *testing.T) {
|
||||
if tc.Name == "WebTransport" {
|
||||
t.Skipf("skipping: %s, not implemented", tc.Name)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s, err := client.NewStream(ctx, server.ID(), "/test")
|
||||
require.NoError(t, err)
|
||||
pingPong(s)
|
||||
|
||||
remoteStream := <-remoteStreamQ
|
||||
defer remoteStream.Reset()
|
||||
|
||||
err = s.ResetWithError(42)
|
||||
require.NoError(t, err)
|
||||
assertStreamErrors(s, &network.StreamError{
|
||||
ErrorCode: 42,
|
||||
Remote: false,
|
||||
})
|
||||
|
||||
assertStreamErrors(remoteStream, &network.StreamError{
|
||||
ErrorCode: 42,
|
||||
Remote: true,
|
||||
})
|
||||
})
|
||||
t.Run("StreamResetWithErrorByRemote", func(t *testing.T) {
|
||||
if tc.Name == "WebTransport" {
|
||||
t.Skipf("skipping: %s, not implemented", tc.Name)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s, err := client.NewStream(ctx, server.ID(), "/test")
|
||||
require.NoError(t, err)
|
||||
pingPong(s)
|
||||
|
||||
remoteStream := <-remoteStreamQ
|
||||
|
||||
err = remoteStream.ResetWithError(42)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertStreamErrors(s, &network.StreamError{
|
||||
ErrorCode: 42,
|
||||
Remote: true,
|
||||
})
|
||||
|
||||
assertStreamErrors(remoteStream, &network.StreamError{
|
||||
ErrorCode: 42,
|
||||
Remote: false,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("StreamResetByConnCloseWithError", func(t *testing.T) {
|
||||
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
|
||||
t.Skipf("skipping: %s, not implemented", tc.Name)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s, err := client.NewStream(ctx, server.ID(), "/test")
|
||||
require.NoError(t, err)
|
||||
pingPong(s)
|
||||
|
||||
remoteStream := <-remoteStreamQ
|
||||
defer remoteStream.Reset()
|
||||
|
||||
err = s.Conn().CloseWithError(42)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertStreamErrors(s, &network.ConnError{
|
||||
ErrorCode: 42,
|
||||
Remote: false,
|
||||
})
|
||||
|
||||
assertStreamErrors(remoteStream, &network.ConnError{
|
||||
ErrorCode: 42,
|
||||
Remote: true,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("NewStreamErrorByConnCloseWithError", func(t *testing.T) {
|
||||
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
|
||||
t.Skipf("skipping: %s, not implemented", tc.Name)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s, err := client.NewStream(ctx, server.ID(), "/test")
|
||||
require.NoError(t, err)
|
||||
pingPong(s)
|
||||
|
||||
err = s.Conn().CloseWithError(42)
|
||||
require.NoError(t, err)
|
||||
|
||||
remoteStream := <-remoteStreamQ
|
||||
defer remoteStream.Reset()
|
||||
|
||||
localErr := &network.ConnError{
|
||||
ErrorCode: 42,
|
||||
Remote: false,
|
||||
}
|
||||
|
||||
remoteErr := &network.ConnError{
|
||||
ErrorCode: 42,
|
||||
Remote: true,
|
||||
}
|
||||
|
||||
// assert these first to ensure that remote has closed the connection
|
||||
assertStreamErrors(remoteStream, remoteErr)
|
||||
|
||||
_, err = s.Conn().NewStream(ctx)
|
||||
require.ErrorIs(t, err, localErr)
|
||||
|
||||
_, err = remoteStream.Conn().NewStream(ctx)
|
||||
require.ErrorIs(t, err, remoteErr)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -34,6 +34,13 @@ func (c *conn) Close() error {
|
||||
return c.closeWithError(0, "")
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection
|
||||
// It must be called even if the peer closed the connection in order for
|
||||
// garbage collection to properly work in this package.
|
||||
func (c *conn) CloseWithError(errCode network.ConnErrorCode) error {
|
||||
return c.closeWithError(quic.ApplicationErrorCode(errCode), "")
|
||||
}
|
||||
|
||||
func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error {
|
||||
c.transport.removeConn(c.quicConn)
|
||||
err := c.quicConn.CloseWithError(errCode, errString)
|
||||
@@ -53,13 +60,19 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
|
||||
// OpenStream creates a new stream.
|
||||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
|
||||
qstr, err := c.quicConn.OpenStreamSync(ctx)
|
||||
return &stream{Stream: qstr}, err
|
||||
if err != nil {
|
||||
return nil, parseStreamError(err)
|
||||
}
|
||||
return &stream{Stream: qstr}, nil
|
||||
}
|
||||
|
||||
// AcceptStream accepts a stream opened by the other side.
|
||||
func (c *conn) AcceptStream() (network.MuxedStream, error) {
|
||||
qstr, err := c.quicConn.AcceptStream(context.Background())
|
||||
return &stream{Stream: qstr}, err
|
||||
if err != nil {
|
||||
return nil, parseStreamError(err)
|
||||
}
|
||||
return &stream{Stream: qstr}, nil
|
||||
}
|
||||
|
||||
// LocalPeer returns our peer ID
|
||||
|
@@ -270,6 +270,9 @@ func TestStreams(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
testStreams(t, tc)
|
||||
})
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
testStreamsErrorCode(t, tc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) {
|
||||
require.Equal(t, data, []byte("foobar"))
|
||||
}
|
||||
|
||||
func testStreamsErrorCode(t *testing.T, tc *connTestCase) {
|
||||
serverID, serverKey := createPeer(t)
|
||||
_, clientKey := createPeer(t)
|
||||
|
||||
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer serverTransport.(io.Closer).Close()
|
||||
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||||
defer ln.Close()
|
||||
|
||||
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientTransport.(io.Closer).Close()
|
||||
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
serverConn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
|
||||
str, err := conn.OpenStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
err = str.ResetWithError(42)
|
||||
require.NoError(t, err)
|
||||
|
||||
sstr, err := serverConn.AcceptStream()
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(sstr)
|
||||
require.Error(t, err)
|
||||
se := &network.StreamError{}
|
||||
if errors.As(err, &se) {
|
||||
require.Equal(t, se.ErrorCode, network.StreamErrorCode(42))
|
||||
require.True(t, se.Remote)
|
||||
} else {
|
||||
t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandshakeFailPeerIDMismatch(t *testing.T) {
|
||||
for _, tc := range connTestCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
|
@@ -11,7 +11,6 @@ import (
|
||||
tpt "github.com/libp2p/go-libp2p/core/transport"
|
||||
p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
@@ -54,12 +53,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
|
||||
c, err := l.wrapConn(qconn)
|
||||
if err != nil {
|
||||
log.Debugf("failed to setup connection: %s", err)
|
||||
qconn.CloseWithError(1, "")
|
||||
qconn.CloseWithError(quic.ApplicationErrorCode(network.ConnResourceLimitExceeded), "")
|
||||
continue
|
||||
}
|
||||
l.transport.addConn(qconn, c)
|
||||
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) {
|
||||
c.closeWithError(errorCodeConnectionGating, "connection gated")
|
||||
c.closeWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated")
|
||||
continue
|
||||
}
|
||||
|
||||
|
@@ -159,10 +159,11 @@ func TestCleanupConnWhenBlocked(t *testing.T) {
|
||||
s.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
b := [1]byte{}
|
||||
_, err = s.Read(b[:])
|
||||
if err != nil && errors.As(err, &quicErr) {
|
||||
connError := &network.ConnError{}
|
||||
if err != nil && errors.As(err, &connError) {
|
||||
// We hit our expected application error
|
||||
return
|
||||
}
|
||||
|
||||
t.Fatalf("expected application error, got %v", err)
|
||||
t.Fatalf("expected network.ConnError, got %v", err)
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package libp2pquic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
|
||||
@@ -18,24 +19,49 @@ type stream struct {
|
||||
|
||||
var _ network.MuxedStream = &stream{}
|
||||
|
||||
func (s *stream) Read(b []byte) (n int, err error) {
|
||||
var streamErr *quic.StreamError
|
||||
|
||||
n, err = s.Stream.Read(b)
|
||||
if err != nil && errors.As(err, &streamErr) {
|
||||
err = network.ErrReset
|
||||
func parseStreamError(err error) error {
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
return n, err
|
||||
se := &quic.StreamError{}
|
||||
if errors.As(err, &se) {
|
||||
var code network.StreamErrorCode
|
||||
if se.ErrorCode > math.MaxUint32 {
|
||||
code = network.StreamCodeOutOfRange
|
||||
} else {
|
||||
code = network.StreamErrorCode(se.ErrorCode)
|
||||
}
|
||||
err = &network.StreamError{
|
||||
ErrorCode: code,
|
||||
Remote: se.Remote,
|
||||
TransportError: se,
|
||||
}
|
||||
}
|
||||
ae := &quic.ApplicationError{}
|
||||
if errors.As(err, &ae) {
|
||||
var code network.ConnErrorCode
|
||||
if ae.ErrorCode > math.MaxUint32 {
|
||||
code = network.ConnCodeOutOfRange
|
||||
} else {
|
||||
code = network.ConnErrorCode(ae.ErrorCode)
|
||||
}
|
||||
err = &network.ConnError{
|
||||
ErrorCode: code,
|
||||
Remote: ae.Remote,
|
||||
TransportError: ae,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (n int, err error) {
|
||||
n, err = s.Stream.Read(b)
|
||||
return n, parseStreamError(err)
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (n int, err error) {
|
||||
var streamErr *quic.StreamError
|
||||
|
||||
n, err = s.Stream.Write(b)
|
||||
if err != nil && errors.As(err, &streamErr) {
|
||||
err = network.ErrReset
|
||||
}
|
||||
return n, err
|
||||
return n, parseStreamError(err)
|
||||
}
|
||||
|
||||
func (s *stream) Reset() error {
|
||||
@@ -44,6 +70,12 @@ func (s *stream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
|
||||
s.Stream.CancelRead(quic.StreamErrorCode(errCode))
|
||||
s.Stream.CancelWrite(quic.StreamErrorCode(errCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) Close() error {
|
||||
s.Stream.CancelRead(reset)
|
||||
return s.Stream.Close()
|
||||
|
@@ -34,8 +34,6 @@ var ErrHolePunching = errors.New("hole punching attempted; no active dial")
|
||||
|
||||
var HolePunchTimeout = 5 * time.Second
|
||||
|
||||
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
|
||||
|
||||
// The Transport implements the tpt.Transport interface for QUIC connections.
|
||||
type transport struct {
|
||||
privKey ic.PrivKey
|
||||
@@ -169,7 +167,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
|
||||
remoteMultiaddr: raddr,
|
||||
}
|
||||
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) {
|
||||
pconn.CloseWithError(errorCodeConnectionGating, "connection gated")
|
||||
pconn.CloseWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated")
|
||||
return nil, fmt.Errorf("secured connection gated")
|
||||
}
|
||||
t.addConn(pconn, c)
|
||||
|
@@ -3,6 +3,7 @@ package libp2pquic
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
tpt "github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
||||
|
||||
@@ -142,8 +143,8 @@ func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.Version
|
||||
select {
|
||||
case ch <- acceptVal{conn: conn}:
|
||||
default:
|
||||
conn.CloseWithError(network.ConnRateLimited)
|
||||
// accept queue filled up, drop the connection
|
||||
conn.Close()
|
||||
log.Warn("Accept queue filled. Dropping connection.")
|
||||
}
|
||||
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/quic-go/quic-go"
|
||||
@@ -212,7 +213,7 @@ func (l *listener) Close() error {
|
||||
close(l.queue)
|
||||
// drain the queue
|
||||
for conn := range l.queue {
|
||||
conn.CloseWithError(1, "closing")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(network.ConnShutdown), "closing")
|
||||
}
|
||||
})
|
||||
return nil
|
||||
|
@@ -132,6 +132,12 @@ func (c *connection) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection ignoring the error code. As there's no way to signal
|
||||
// the remote peer on closing the underlying peerconnection, we ignore the error code.
|
||||
func (c *connection) CloseWithError(_ network.ConnErrorCode) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// closeWithError is used to Close the connection when the underlying DTLS connection fails
|
||||
func (c *connection) closeWithError(err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
|
@@ -95,6 +95,7 @@ type Message struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"`
|
||||
Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
|
||||
ErrorCode *uint32 `protobuf:"varint,3,opt,name=errorCode" json:"errorCode,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -143,24 +144,32 @@ func (x *Message) GetMessage() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Message) GetErrorCode() uint32 {
|
||||
if x != nil && x.ErrorCode != nil {
|
||||
return *x.ErrorCode
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
var File_p2p_transport_webrtc_pb_message_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_p2p_transport_webrtc_pb_message_proto_rawDesc = string([]byte{
|
||||
0x0a, 0x25, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f,
|
||||
0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9f, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67,
|
||||
0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10,
|
||||
0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e,
|
||||
0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b,
|
||||
0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70,
|
||||
0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74,
|
||||
0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f,
|
||||
0x70, 0x62,
|
||||
0x12, 0x1c, 0x0a, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x0d, 0x52, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x39,
|
||||
0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12,
|
||||
0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10,
|
||||
0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07,
|
||||
0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74,
|
||||
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67,
|
||||
0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61,
|
||||
0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62,
|
||||
})
|
||||
|
||||
var (
|
||||
|
@@ -21,4 +21,6 @@ message Message {
|
||||
optional Flag flag=1;
|
||||
|
||||
optional bytes message = 2;
|
||||
|
||||
optional uint32 errorCode = 3;
|
||||
}
|
||||
|
@@ -69,8 +69,9 @@ type stream struct {
|
||||
|
||||
// readerMx ensures that only a single goroutine reads from the reader. Read is not threadsafe
|
||||
// But we may need to read from reader for control messages from a different goroutine.
|
||||
readerMx sync.Mutex
|
||||
reader pbio.Reader
|
||||
readerMx sync.Mutex
|
||||
reader pbio.Reader
|
||||
readError error
|
||||
|
||||
// this buffer is limited up to a single message. Reason we need it
|
||||
// is because a reader might read a message midway, and so we need a
|
||||
@@ -82,6 +83,7 @@ type stream struct {
|
||||
writeStateChanged chan struct{}
|
||||
sendState sendState
|
||||
writeDeadline time.Time
|
||||
writeError error
|
||||
|
||||
controlMessageReaderOnce sync.Once
|
||||
// controlMessageReaderEndTime is the end time for reading FIN_ACK from the control
|
||||
@@ -146,6 +148,10 @@ func (s *stream) Close() error {
|
||||
}
|
||||
|
||||
func (s *stream) Reset() error {
|
||||
return s.ResetWithError(0)
|
||||
}
|
||||
|
||||
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
|
||||
s.mx.Lock()
|
||||
isClosed := s.closeForShutdownErr != nil
|
||||
s.mx.Unlock()
|
||||
@@ -154,8 +160,8 @@ func (s *stream) Reset() error {
|
||||
}
|
||||
|
||||
defer s.cleanup()
|
||||
cancelWriteErr := s.cancelWrite()
|
||||
closeReadErr := s.CloseRead()
|
||||
cancelWriteErr := s.cancelWrite(errCode)
|
||||
closeReadErr := s.closeRead(errCode, false)
|
||||
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
|
||||
return errors.Join(closeReadErr, cancelWriteErr)
|
||||
}
|
||||
@@ -175,19 +181,20 @@ func (s *stream) SetDeadline(t time.Time) error {
|
||||
return s.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// processIncomingFlag process the flag on an incoming message
|
||||
// processIncomingFlag processes the flag(FIN/RST/etc) on msg.
|
||||
// It needs to be called while the mutex is locked.
|
||||
func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
|
||||
if flag == nil {
|
||||
func (s *stream) processIncomingFlag(msg *pb.Message) {
|
||||
if msg.Flag == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch *flag {
|
||||
switch msg.GetFlag() {
|
||||
case pb.Message_STOP_SENDING:
|
||||
// We must process STOP_SENDING after sending a FIN(sendStateDataSent). Remote peer
|
||||
// may not send a FIN_ACK once it has sent a STOP_SENDING
|
||||
if s.sendState == sendStateSending || s.sendState == sendStateDataSent {
|
||||
s.sendState = sendStateReset
|
||||
s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
|
||||
}
|
||||
s.notifyWriteStateChanged()
|
||||
case pb.Message_FIN_ACK:
|
||||
@@ -206,6 +213,11 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
|
||||
case pb.Message_RESET:
|
||||
if s.receiveState == receiveStateReceiving {
|
||||
s.receiveState = receiveStateReset
|
||||
s.readError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
|
||||
}
|
||||
if s.sendState == sendStateSending || s.sendState == sendStateDataSent {
|
||||
s.sendState = sendStateReset
|
||||
s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
|
||||
}
|
||||
s.spawnControlMessageReader()
|
||||
}
|
||||
@@ -235,7 +247,7 @@ func (s *stream) spawnControlMessageReader() {
|
||||
s.readerMx.Unlock()
|
||||
|
||||
if s.nextMessage != nil {
|
||||
s.processIncomingFlag(s.nextMessage.Flag)
|
||||
s.processIncomingFlag(s.nextMessage)
|
||||
s.nextMessage = nil
|
||||
}
|
||||
var msg pb.Message
|
||||
@@ -266,7 +278,7 @@ func (s *stream) spawnControlMessageReader() {
|
||||
}
|
||||
return
|
||||
}
|
||||
s.processIncomingFlag(msg.Flag)
|
||||
s.processIncomingFlag(&msg)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
@@ -22,7 +22,7 @@ func (s *stream) Read(b []byte) (int, error) {
|
||||
case receiveStateDataRead:
|
||||
return 0, io.EOF
|
||||
case receiveStateReset:
|
||||
return 0, network.ErrReset
|
||||
return 0, s.readError
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
@@ -52,10 +52,11 @@ func (s *stream) Read(b []byte) (int, error) {
|
||||
// datachannel. For these implementations a stream reset will be observed as an
|
||||
// abrupt closing of the datachannel.
|
||||
s.receiveState = receiveStateReset
|
||||
return 0, network.ErrReset
|
||||
s.readError = &network.StreamError{Remote: true}
|
||||
return 0, s.readError
|
||||
}
|
||||
if s.receiveState == receiveStateReset {
|
||||
return 0, network.ErrReset
|
||||
return 0, s.readError
|
||||
}
|
||||
if s.receiveState == receiveStateDataRead {
|
||||
return 0, io.EOF
|
||||
@@ -73,7 +74,7 @@ func (s *stream) Read(b []byte) (int, error) {
|
||||
}
|
||||
|
||||
// process flags on the message after reading all the data
|
||||
s.processIncomingFlag(s.nextMessage.Flag)
|
||||
s.processIncomingFlag(s.nextMessage)
|
||||
s.nextMessage = nil
|
||||
if s.closeForShutdownErr != nil {
|
||||
return read, s.closeForShutdownErr
|
||||
@@ -82,7 +83,7 @@ func (s *stream) Read(b []byte) (int, error) {
|
||||
case receiveStateDataRead:
|
||||
return read, io.EOF
|
||||
case receiveStateReset:
|
||||
return read, network.ErrReset
|
||||
return read, s.readError
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,12 +102,18 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (s *stream) CloseRead() error {
|
||||
return s.closeRead(0, false)
|
||||
}
|
||||
|
||||
func (s *stream) closeRead(errCode network.StreamErrorCode, remote bool) error {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
var err error
|
||||
if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil {
|
||||
err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
|
||||
code := uint32(errCode)
|
||||
err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum(), ErrorCode: &code})
|
||||
s.receiveState = receiveStateReset
|
||||
s.readError = &network.StreamError{Remote: remote, ErrorCode: errCode}
|
||||
}
|
||||
s.spawnControlMessageReader()
|
||||
return err
|
||||
|
@@ -24,7 +24,7 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||
}
|
||||
switch s.sendState {
|
||||
case sendStateReset:
|
||||
return 0, network.ErrReset
|
||||
return 0, s.writeError
|
||||
case sendStateDataSent, sendStateDataReceived:
|
||||
return 0, errWriteAfterClose
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||
}
|
||||
switch s.sendState {
|
||||
case sendStateReset:
|
||||
return n, network.ErrReset
|
||||
return n, s.writeError
|
||||
case sendStateDataSent, sendStateDataReceived:
|
||||
return n, errWriteAfterClose
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func (s *stream) availableSendSpace() int {
|
||||
return availableSpace
|
||||
}
|
||||
|
||||
func (s *stream) cancelWrite() error {
|
||||
func (s *stream) cancelWrite(errCode network.StreamErrorCode) error {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
@@ -129,10 +129,12 @@ func (s *stream) cancelWrite() error {
|
||||
return nil
|
||||
}
|
||||
s.sendState = sendStateReset
|
||||
s.writeError = &network.StreamError{Remote: false, ErrorCode: errCode}
|
||||
// Remove reference to this stream from data channel
|
||||
s.dataChannel.OnBufferedAmountLow(nil)
|
||||
s.notifyWriteStateChanged()
|
||||
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()})
|
||||
code := uint32(errCode)
|
||||
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum(), ErrorCode: &code})
|
||||
}
|
||||
|
||||
func (s *stream) CloseWrite() error {
|
||||
|
@@ -78,6 +78,10 @@ func (c *conn) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
|
||||
func (c *conn) Scope() network.ConnScope { return c.scope }
|
||||
func (c *conn) Transport() tpt.Transport { return c.transport }
|
||||
|
@@ -56,6 +56,17 @@ func (s *stream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetWithError resets the stream ignoring the error code. Error codes aren't
|
||||
// specified for WebTransport as the current implementation of WebTransport in
|
||||
// browsers(https://www.ietf.org/archive/id/draft-kinnear-webtransport-http2-02.html)
|
||||
// only supports 1 byte error codes. For more details, see
|
||||
// https://github.com/libp2p/specs/blob/4eca305185c7aef219e936bef76c48b1ab0a8b43/error-codes/README.md?plain=1#L84
|
||||
func (s *stream) ResetWithError(_ network.StreamErrorCode) error {
|
||||
s.Stream.CancelRead(reset)
|
||||
s.Stream.CancelWrite(reset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) Close() error {
|
||||
s.Stream.CancelRead(reset)
|
||||
return s.Stream.Close()
|
||||
|
@@ -43,7 +43,7 @@ require (
|
||||
github.com/libp2p/go-nat v0.2.0 // indirect
|
||||
github.com/libp2p/go-netroute v0.2.2 // indirect
|
||||
github.com/libp2p/go-reuseport v0.4.0 // indirect
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2 // indirect
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0 // indirect
|
||||
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/miekg/dns v1.1.63 // indirect
|
||||
|
@@ -149,8 +149,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP
|
||||
github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE=
|
||||
github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s=
|
||||
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU=
|
||||
github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o=
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po=
|
||||
github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU=
|
||||
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
||||
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk=
|
||||
|
Reference in New Issue
Block a user