Implement error codes spec (#2927)

This commit is contained in:
sukun
2025-02-09 21:39:18 +05:30
committed by GitHub
parent 4957d357e0
commit 02ab795c92
40 changed files with 673 additions and 93 deletions

View File

@@ -2,15 +2,60 @@ package network
import (
"context"
"fmt"
"io"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
ma "github.com/multiformats/go-multiaddr"
)
type ConnErrorCode uint32
type ConnError struct {
Remote bool
ErrorCode ConnErrorCode
TransportError error
}
func (c *ConnError) Error() string {
side := "local"
if c.Remote {
side = "remote"
}
if c.TransportError != nil {
return fmt.Sprintf("connection closed (%s): code: 0x%x: transport error: %s", side, c.ErrorCode, c.TransportError)
}
return fmt.Sprintf("connection closed (%s): code: 0x%x", side, c.ErrorCode)
}
func (c *ConnError) Is(target error) bool {
if tce, ok := target.(*ConnError); ok {
return tce.ErrorCode == c.ErrorCode && tce.Remote == c.Remote
}
return false
}
func (c *ConnError) Unwrap() []error {
return []error{ErrReset, c.TransportError}
}
const (
ConnNoError ConnErrorCode = 0
ConnProtocolNegotiationFailed ConnErrorCode = 0x1000
ConnResourceLimitExceeded ConnErrorCode = 0x1001
ConnRateLimited ConnErrorCode = 0x1002
ConnProtocolViolation ConnErrorCode = 0x1003
ConnSupplanted ConnErrorCode = 0x1004
ConnGarbageCollected ConnErrorCode = 0x1005
ConnShutdown ConnErrorCode = 0x1006
ConnGated ConnErrorCode = 0x1007
ConnCodeOutOfRange ConnErrorCode = 0x1008
)
// Conn is a connection to a remote peer. It multiplexes streams.
// Usually there is no need to use a Conn directly, but it may
// be useful to get information about the peer on the other side:
@@ -24,6 +69,11 @@ type Conn interface {
ConnStat
ConnScoper
// CloseWithError closes the connection with errCode. The errCode is sent to the
// peer on a best effort basis. For transports that do not support sending error
// codes on connection close, the behavior is identical to calling Close.
CloseWithError(errCode ConnErrorCode) error
// ID returns an identifier that uniquely identifies this Conn within this
// host, during this run. Connection IDs may repeat across restarts.
ID() string

View File

@@ -3,6 +3,7 @@ package network
import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
@@ -11,6 +12,49 @@ import (
// ErrReset is returned when reading or writing on a reset stream.
var ErrReset = errors.New("stream reset")
type StreamErrorCode uint32
type StreamError struct {
ErrorCode StreamErrorCode
Remote bool
TransportError error
}
func (s *StreamError) Error() string {
side := "local"
if s.Remote {
side = "remote"
}
if s.TransportError != nil {
return fmt.Sprintf("stream reset (%s): code: 0x%x: transport error: %s", side, s.ErrorCode, s.TransportError)
}
return fmt.Sprintf("stream reset (%s): code: 0x%x", side, s.ErrorCode)
}
func (s *StreamError) Is(target error) bool {
if tse, ok := target.(*StreamError); ok {
return tse.ErrorCode == s.ErrorCode && tse.Remote == s.Remote
}
return false
}
func (s *StreamError) Unwrap() []error {
return []error{ErrReset, s.TransportError}
}
const (
StreamNoError StreamErrorCode = 0
StreamProtocolNegotiationFailed StreamErrorCode = 0x1001
StreamResourceLimitExceeded StreamErrorCode = 0x1002
StreamRateLimited StreamErrorCode = 0x1003
StreamProtocolViolation StreamErrorCode = 0x1004
StreamSupplanted StreamErrorCode = 0x1005
StreamGarbageCollected StreamErrorCode = 0x1006
StreamShutdown StreamErrorCode = 0x1007
StreamGated StreamErrorCode = 0x1008
StreamCodeOutOfRange StreamErrorCode = 0x1009
)
// MuxedStream is a bidirectional io pipe within a connection.
type MuxedStream interface {
io.Reader
@@ -56,6 +100,11 @@ type MuxedStream interface {
// side to hang up and go away.
Reset() error
// ResetWithError aborts both ends of the stream with `errCode`. `errCode` is sent
// to the peer on a best effort basis. For transports that do not support sending
// error codes to remote peer, the behavior is identical to calling Reset
ResetWithError(errCode StreamErrorCode) error
SetDeadline(time.Time) error
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
@@ -75,6 +124,10 @@ type MuxedConn interface {
// Close closes the stream muxer and the the underlying net.Conn.
io.Closer
// CloseWithError closes the connection with errCode. The errCode is sent
// to the peer.
CloseWithError(errCode ConnErrorCode) error
// IsClosed returns whether a connection is fully closed, so it can
// be garbage collected.
IsClosed() bool

View File

@@ -27,4 +27,8 @@ type Stream interface {
// Scope returns the user's view of this stream's resource scope
Scope() StreamScope
// ResetWithError closes both ends of the stream with errCode. The errCode is sent
// to the peer.
ResetWithError(errCode StreamErrorCode) error
}

2
go.mod
View File

@@ -30,7 +30,7 @@ require (
github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.2.2
github.com/libp2p/go-reuseport v0.4.0
github.com/libp2p/go-yamux/v4 v4.0.2
github.com/libp2p/go-yamux/v5 v5.0.0
github.com/libp2p/zeroconf/v2 v2.2.0
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b

4
go.sum
View File

@@ -193,8 +193,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP
github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE=
github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s=
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU=
github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o=
github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po=
github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU=
github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q=
github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=

View File

@@ -464,7 +464,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
} else {
log.Debugf("protocol mux failed: %s (took %s, id:%s, remote peer:%s, remote addr:%v)", err, took, s.ID(), s.Conn().RemotePeer(), s.Conn().RemoteMultiaddr())
}
s.Reset()
s.ResetWithError(network.StreamProtocolNegotiationFailed)
return
}
@@ -478,7 +478,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
if err := s.SetProtocol(protoID); err != nil {
log.Debugf("error setting stream protocol: %s", err)
s.Reset()
s.ResetWithError(network.StreamResourceLimitExceeded)
return
}
@@ -717,7 +717,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
}
defer func() {
if strErr != nil && s != nil {
s.Reset()
s.ResetWithError(network.StreamProtocolNegotiationFailed)
}
}()
@@ -761,13 +761,14 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, fmt.Errorf("failed to negotiate protocol: %w", err)
}
case <-ctx.Done():
s.Reset()
s.ResetWithError(network.StreamProtocolNegotiationFailed)
// wait for `SelectOneOf` to error out because of resetting the stream.
<-errCh
return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err())
}
if err := s.SetProtocol(selected); err != nil {
s.ResetWithError(network.StreamResourceLimitExceeded)
return nil, err
}
_ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
crand "crypto/rand"
"errors"
"fmt"
"io"
mrand "math/rand"
@@ -462,7 +463,7 @@ func SubtestStreamReset(t *testing.T, tr network.Multiplexer) {
time.Sleep(time.Millisecond * 50)
_, err = s.Write([]byte("foo"))
if err != network.ErrReset {
if !errors.Is(err, network.ErrReset) {
t.Error("should have been stream reset")
}
s.Close()

View File

@@ -5,7 +5,7 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
"github.com/libp2p/go-yamux/v5"
)
// conn implements mux.MuxedConn over yamux.Session.
@@ -23,6 +23,10 @@ func (c *conn) Close() error {
return c.yamux().Close()
}
func (c *conn) CloseWithError(errCode network.ConnErrorCode) error {
return c.yamux().CloseWithError(uint32(errCode))
}
// IsClosed checks if yamux.Session is in closed state.
func (c *conn) IsClosed() bool {
return c.yamux().IsClosed()
@@ -32,7 +36,7 @@ func (c *conn) IsClosed() bool {
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
s, err := c.yamux().OpenStream(ctx)
if err != nil {
return nil, err
return nil, parseError(err)
}
return (*stream)(s), nil
@@ -41,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
s, err := c.yamux().AcceptStream()
return (*stream)(s), err
return (*stream)(s), parseError(err)
}
func (c *conn) yamux() *yamux.Session {

View File

@@ -1,11 +1,13 @@
package yamux
import (
"errors"
"fmt"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
"github.com/libp2p/go-yamux/v5"
)
// stream implements mux.MuxedStream over yamux.Stream.
@@ -13,22 +15,32 @@ type stream yamux.Stream
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.yamux().Read(b)
if err == yamux.ErrStreamReset {
err = network.ErrReset
func parseError(err error) error {
if err == nil {
return err
}
se := &yamux.StreamError{}
if errors.As(err, &se) {
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode), TransportError: err}
}
ce := &yamux.GoAwayError{}
if errors.As(err, &ce) {
return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode), TransportError: err}
}
if errors.Is(err, yamux.ErrStreamReset) {
return fmt.Errorf("%w: %w", network.ErrReset, err)
}
return err
}
return n, err
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.yamux().Read(b)
return n, parseError(err)
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.yamux().Write(b)
if err == yamux.ErrStreamReset {
err = network.ErrReset
}
return n, err
return n, parseError(err)
}
func (s *stream) Close() error {
@@ -39,6 +51,10 @@ func (s *stream) Reset() error {
return s.yamux().Reset()
}
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
return s.yamux().ResetWithError(uint32(errCode))
}
func (s *stream) CloseRead() error {
return s.yamux().CloseRead()
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-yamux/v4"
"github.com/libp2p/go-yamux/v5"
)
var DefaultTransport *Transport

View File

@@ -175,7 +175,8 @@ func (cm *BasicConnMgr) memoryEmergency() {
// Trim connections without paying attention to the silence period.
for _, c := range cm.getConnsToCloseEmergency(target) {
log.Infow("low on memory. closing conn", "peer", c.RemotePeer())
c.Close()
c.CloseWithError(network.ConnGarbageCollected)
}
// finally, update the last trim time.
@@ -388,7 +389,7 @@ func (cm *BasicConnMgr) trim() {
// do the actual trim.
for _, c := range cm.getConnsToClose() {
log.Debugw("closing conn", "peer", c.RemotePeer())
c.Close()
c.CloseWithError(network.ConnGarbageCollected)
}
}

View File

@@ -11,8 +11,11 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
tu "github.com/libp2p/go-libp2p/core/test"
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
@@ -33,6 +36,14 @@ func (c *tconn) Close() error {
return nil
}
func (c *tconn) CloseWithError(code network.ConnErrorCode) error {
atomic.StoreUint32(&c.closed, 1)
if c.disconnectNotify != nil {
c.disconnectNotify(nil, c)
}
return nil
}
func (c *tconn) isClosed() bool {
return atomic.LoadUint32(&c.closed) == 1
}
@@ -794,6 +805,7 @@ type mockConn struct {
}
func (m mockConn) Close() error { panic("implement me") }
func (m mockConn) CloseWithError(errCode network.ConnErrorCode) error { panic("implement me") }
func (m mockConn) LocalPeer() peer.ID { panic("implement me") }
func (m mockConn) RemotePeer() peer.ID { panic("implement me") }
func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") }
@@ -986,3 +998,79 @@ type testLimitGetter struct {
func (g testLimitGetter) GetConnLimit() int {
return g.limit
}
func TestErrorCode(t *testing.T) {
sw1, sw2, sw3 := swarmt.GenSwarm(t), swarmt.GenSwarm(t), swarmt.GenSwarm(t)
defer sw1.Close()
defer sw2.Close()
defer sw3.Close()
cm, err := NewConnManager(1, 1, WithGracePeriod(0), WithSilencePeriod(10))
require.NoError(t, err)
defer cm.Close()
sw1.Peerstore().AddAddrs(sw2.LocalPeer(), sw2.ListenAddresses(), peerstore.PermanentAddrTTL)
sw1.Peerstore().AddAddrs(sw3.LocalPeer(), sw3.ListenAddresses(), peerstore.PermanentAddrTTL)
c12, err := sw1.DialPeer(context.Background(), sw2.LocalPeer())
require.NoError(t, err)
var c21 network.Conn
require.Eventually(t, func() bool {
conns := sw2.ConnsToPeer(sw1.LocalPeer())
if len(conns) == 0 {
return false
}
c21 = conns[0]
return true
}, 10*time.Second, 100*time.Millisecond)
c13, err := sw1.DialPeer(context.Background(), sw3.LocalPeer())
require.NoError(t, err)
var c31 network.Conn
require.Eventually(t, func() bool {
conns := sw3.ConnsToPeer(sw1.LocalPeer())
if len(conns) == 0 {
return false
}
c31 = conns[0]
return true
}, 10*time.Second, 100*time.Millisecond)
not := cm.Notifee()
not.Connected(sw1, c12)
not.Connected(sw1, c13)
cm.TrimOpenConns(context.Background())
require.True(t, c12.IsClosed() || c13.IsClosed())
var c, cr network.Conn
if c12.IsClosed() {
c = c12
require.Eventually(t, func() bool {
conns := sw2.ConnsToPeer(sw1.LocalPeer())
if len(conns) == 0 {
cr = c21
return true
}
return false
}, 5*time.Second, 100*time.Millisecond)
} else {
c = c13
require.Eventually(t, func() bool {
conns := sw3.ConnsToPeer(sw1.LocalPeer())
if len(conns) == 0 {
cr = c31
return true
}
return false
}, 5*time.Second, 100*time.Millisecond)
}
_, err = c.NewStream(context.Background())
require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: false})
_, err = cr.NewStream(context.Background())
require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: true})
}

View File

@@ -185,3 +185,7 @@ func (c *conn) Stat() network.ConnStats {
func (c *conn) Scope() network.ConnScope {
return &network.NullScope{}
}
func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
return c.Close()
}

View File

@@ -144,6 +144,24 @@ func (s *stream) Reset() error {
return nil
}
// ResetWithError resets the stream. It ignores the provided error code.
// TODO: Implement error code support.
func (s *stream) ResetWithError(_ network.StreamErrorCode) error {
// Cancel any pending reads/writes with an error.
s.write.CloseWithError(network.ErrReset)
s.read.CloseWithError(network.ErrReset)
select {
case s.reset <- struct{}{}:
default:
}
<-s.closed
// No meaningful error case here.
return nil
}
func (s *stream) teardown() {
// at this point, no streams are writing.
s.conn.removeStream(s)

View File

@@ -385,8 +385,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
// If we do this in the Upgrader, we will not be able to do this.
if s.gater != nil {
if allow, _ := s.gater.InterceptUpgraded(c); !allow {
// TODO Send disconnect with reason here
err := tc.Close()
err := tc.CloseWithError(network.ConnGated)
if err != nil {
log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err)
}
@@ -845,6 +844,14 @@ func (c *connWithMetrics) Close() error {
return c.closeErr
}
func (c *connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error {
c.once.Do(func() {
c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
c.closeErr = c.CapableConn.CloseWithError(errCode)
})
return c.closeErr
}
func (c *connWithMetrics) Stat() network.ConnStats {
if cs, ok := c.CapableConn.(network.ConnStat); ok {
return cs.Stat()

View File

@@ -58,11 +58,20 @@ func (c *Conn) ID() string {
// open notifications must finish before we can fire off the close
// notifications).
func (c *Conn) Close() error {
c.closeOnce.Do(c.doClose)
c.closeOnce.Do(func() {
c.doClose(0)
})
return c.err
}
func (c *Conn) doClose() {
func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error {
c.closeOnce.Do(func() {
c.doClose(errCode)
})
return c.err
}
func (c *Conn) doClose(errCode network.ConnErrorCode) {
c.swarm.removeConn(c)
// Prevent new streams from opening.
@@ -71,7 +80,11 @@ func (c *Conn) doClose() {
c.streams.m = nil
c.streams.Unlock()
if errCode != 0 {
c.err = c.conn.CloseWithError(errCode)
} else {
c.err = c.conn.Close()
}
// Send the connectedness event after closing the connection.
// This ensures that both remote connection close and local connection
@@ -121,7 +134,7 @@ func (c *Conn) start() {
}
scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound)
if err != nil {
ts.Reset()
ts.ResetWithError(network.StreamResourceLimitExceeded)
continue
}
c.swarm.refs.Add(1)

View File

@@ -91,6 +91,12 @@ func (s *Stream) Reset() error {
return err
}
func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error {
err := s.stream.ResetWithError(errCode)
s.closeAndRemoveStream()
return err
}
func (s *Stream) closeAndRemoveStream() {
s.closeMx.Lock()
defer s.closeMx.Unlock()

View File

@@ -538,7 +538,7 @@ func TestResourceManagerAcceptStream(t *testing.T) {
if err == nil {
_, err = str.Read([]byte{0})
}
require.EqualError(t, err, "stream reset")
require.ErrorContains(t, err, "stream reset")
}
func TestListenCloseCount(t *testing.T) {

View File

@@ -63,3 +63,8 @@ func (t *transportConn) ConnState() network.ConnectionState {
UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation,
}
}
func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error {
defer t.scope.Done()
return t.MuxedConn.CloseWithError(errCode)
}

View File

@@ -162,7 +162,7 @@ func (l *listener) handleIncoming() {
// if we stop accepting connections for some reason,
// we'll eventually close all the open ones
// instead of hanging onto them.
conn.Close()
conn.CloseWithError(network.ConnRateLimited)
}
}()
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"testing"
@@ -267,12 +268,12 @@ func TestRelayLimitTime(t *testing.T) {
if n > 0 {
t.Fatalf("expected to write 0 bytes, wrote %d", n)
}
if err != network.ErrReset {
if !errors.Is(err, network.ErrReset) {
t.Fatalf("expected reset, but got %s", err)
}
err = <-rch
if err != network.ErrReset {
if !errors.Is(err, network.ErrReset) {
t.Fatalf("expected reset, but got %s", err)
}
}
@@ -300,7 +301,7 @@ func TestRelayLimitData(t *testing.T) {
}
n, err := s.Read(buf)
if err != network.ErrReset {
if !errors.Is(err, network.ErrReset) {
t.Fatalf("expected reset but got %s", err)
}
rch <- n

View File

@@ -36,6 +36,7 @@ import (
"go.uber.org/mock/gomock"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -866,3 +867,170 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
})
}
}
func TestErrorCodes(t *testing.T) {
assertStreamErrors := func(s network.Stream, expectedError error) {
buf := make([]byte, 10)
_, err := s.Read(buf)
require.ErrorIs(t, err, expectedError)
_, err = s.Write(buf)
require.ErrorIs(t, err, expectedError)
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
// setup stream handler
remoteStreamQ := make(chan network.Stream)
server.SetStreamHandler("/test", func(s network.Stream) {
b := make([]byte, 10)
n, err := s.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = s.Write(b[:n])
if !assert.NoError(t, err) {
return
}
remoteStreamQ <- s
})
// pingPong writes and reads "hello" on the stream
pingPong := func(s network.Stream) {
buf := []byte("hello")
_, err := s.Write(buf)
require.NoError(t, err)
_, err = s.Read(buf)
require.NoError(t, err)
require.Equal(t, buf, []byte("hello"))
}
t.Run("StreamResetWithError", func(t *testing.T) {
if tc.Name == "WebTransport" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)
pingPong(s)
remoteStream := <-remoteStreamQ
defer remoteStream.Reset()
err = s.ResetWithError(42)
require.NoError(t, err)
assertStreamErrors(s, &network.StreamError{
ErrorCode: 42,
Remote: false,
})
assertStreamErrors(remoteStream, &network.StreamError{
ErrorCode: 42,
Remote: true,
})
})
t.Run("StreamResetWithErrorByRemote", func(t *testing.T) {
if tc.Name == "WebTransport" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)
pingPong(s)
remoteStream := <-remoteStreamQ
err = remoteStream.ResetWithError(42)
require.NoError(t, err)
assertStreamErrors(s, &network.StreamError{
ErrorCode: 42,
Remote: true,
})
assertStreamErrors(remoteStream, &network.StreamError{
ErrorCode: 42,
Remote: false,
})
})
t.Run("StreamResetByConnCloseWithError", func(t *testing.T) {
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)
pingPong(s)
remoteStream := <-remoteStreamQ
defer remoteStream.Reset()
err = s.Conn().CloseWithError(42)
require.NoError(t, err)
assertStreamErrors(s, &network.ConnError{
ErrorCode: 42,
Remote: false,
})
assertStreamErrors(remoteStream, &network.ConnError{
ErrorCode: 42,
Remote: true,
})
})
t.Run("NewStreamErrorByConnCloseWithError", func(t *testing.T) {
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)
pingPong(s)
err = s.Conn().CloseWithError(42)
require.NoError(t, err)
remoteStream := <-remoteStreamQ
defer remoteStream.Reset()
localErr := &network.ConnError{
ErrorCode: 42,
Remote: false,
}
remoteErr := &network.ConnError{
ErrorCode: 42,
Remote: true,
}
// assert these first to ensure that remote has closed the connection
assertStreamErrors(remoteStream, remoteErr)
_, err = s.Conn().NewStream(ctx)
require.ErrorIs(t, err, localErr)
_, err = remoteStream.Conn().NewStream(ctx)
require.ErrorIs(t, err, remoteErr)
})
})
}
}

View File

@@ -34,6 +34,13 @@ func (c *conn) Close() error {
return c.closeWithError(0, "")
}
// CloseWithError closes the connection
// It must be called even if the peer closed the connection in order for
// garbage collection to properly work in this package.
func (c *conn) CloseWithError(errCode network.ConnErrorCode) error {
return c.closeWithError(quic.ApplicationErrorCode(errCode), "")
}
func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error {
c.transport.removeConn(c.quicConn)
err := c.quicConn.CloseWithError(errCode, errString)
@@ -53,13 +60,19 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
// OpenStream creates a new stream.
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
qstr, err := c.quicConn.OpenStreamSync(ctx)
return &stream{Stream: qstr}, err
if err != nil {
return nil, parseStreamError(err)
}
return &stream{Stream: qstr}, nil
}
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
qstr, err := c.quicConn.AcceptStream(context.Background())
return &stream{Stream: qstr}, err
if err != nil {
return nil, parseStreamError(err)
}
return &stream{Stream: qstr}, nil
}
// LocalPeer returns our peer ID

View File

@@ -270,6 +270,9 @@ func TestStreams(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
testStreams(t, tc)
})
t.Run(tc.Name, func(t *testing.T) {
testStreamsErrorCode(t, tc)
})
}
}
@@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) {
require.Equal(t, data, []byte("foobar"))
}
func testStreamsErrorCode(t *testing.T, tc *connTestCase) {
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
defer ln.Close()
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
serverConn, err := ln.Accept()
require.NoError(t, err)
defer serverConn.Close()
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
err = str.ResetWithError(42)
require.NoError(t, err)
sstr, err := serverConn.AcceptStream()
require.NoError(t, err)
_, err = io.ReadAll(sstr)
require.Error(t, err)
se := &network.StreamError{}
if errors.As(err, &se) {
require.Equal(t, se.ErrorCode, network.StreamErrorCode(42))
require.True(t, se.Remote)
} else {
t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err)
}
}
func TestHandshakeFailPeerIDMismatch(t *testing.T) {
for _, tc := range connTestCases {
t.Run(tc.Name, func(t *testing.T) {

View File

@@ -11,7 +11,6 @@ import (
tpt "github.com/libp2p/go-libp2p/core/transport"
p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
)
@@ -54,12 +53,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
c, err := l.wrapConn(qconn)
if err != nil {
log.Debugf("failed to setup connection: %s", err)
qconn.CloseWithError(1, "")
qconn.CloseWithError(quic.ApplicationErrorCode(network.ConnResourceLimitExceeded), "")
continue
}
l.transport.addConn(qconn, c)
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) {
c.closeWithError(errorCodeConnectionGating, "connection gated")
c.closeWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated")
continue
}

View File

@@ -159,10 +159,11 @@ func TestCleanupConnWhenBlocked(t *testing.T) {
s.SetReadDeadline(time.Now().Add(10 * time.Second))
b := [1]byte{}
_, err = s.Read(b[:])
if err != nil && errors.As(err, &quicErr) {
connError := &network.ConnError{}
if err != nil && errors.As(err, &connError) {
// We hit our expected application error
return
}
t.Fatalf("expected application error, got %v", err)
t.Fatalf("expected network.ConnError, got %v", err)
}

View File

@@ -2,6 +2,7 @@ package libp2pquic
import (
"errors"
"math"
"github.com/libp2p/go-libp2p/core/network"
@@ -18,24 +19,49 @@ type stream struct {
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
var streamErr *quic.StreamError
n, err = s.Stream.Read(b)
if err != nil && errors.As(err, &streamErr) {
err = network.ErrReset
func parseStreamError(err error) error {
if err == nil {
return err
}
return n, err
se := &quic.StreamError{}
if errors.As(err, &se) {
var code network.StreamErrorCode
if se.ErrorCode > math.MaxUint32 {
code = network.StreamCodeOutOfRange
} else {
code = network.StreamErrorCode(se.ErrorCode)
}
err = &network.StreamError{
ErrorCode: code,
Remote: se.Remote,
TransportError: se,
}
}
ae := &quic.ApplicationError{}
if errors.As(err, &ae) {
var code network.ConnErrorCode
if ae.ErrorCode > math.MaxUint32 {
code = network.ConnCodeOutOfRange
} else {
code = network.ConnErrorCode(ae.ErrorCode)
}
err = &network.ConnError{
ErrorCode: code,
Remote: ae.Remote,
TransportError: ae,
}
}
return err
}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.Stream.Read(b)
return n, parseStreamError(err)
}
func (s *stream) Write(b []byte) (n int, err error) {
var streamErr *quic.StreamError
n, err = s.Stream.Write(b)
if err != nil && errors.As(err, &streamErr) {
err = network.ErrReset
}
return n, err
return n, parseStreamError(err)
}
func (s *stream) Reset() error {
@@ -44,6 +70,12 @@ func (s *stream) Reset() error {
return nil
}
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
s.Stream.CancelRead(quic.StreamErrorCode(errCode))
s.Stream.CancelWrite(quic.StreamErrorCode(errCode))
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()

View File

@@ -34,8 +34,6 @@ var ErrHolePunching = errors.New("hole punching attempted; no active dial")
var HolePunchTimeout = 5 * time.Second
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
// The Transport implements the tpt.Transport interface for QUIC connections.
type transport struct {
privKey ic.PrivKey
@@ -169,7 +167,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
remoteMultiaddr: raddr,
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) {
pconn.CloseWithError(errorCodeConnectionGating, "connection gated")
pconn.CloseWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated")
return nil, fmt.Errorf("secured connection gated")
}
t.addConn(pconn, c)

View File

@@ -3,6 +3,7 @@ package libp2pquic
import (
"sync"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
@@ -142,8 +143,8 @@ func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.Version
select {
case ch <- acceptVal{conn: conn}:
default:
conn.CloseWithError(network.ConnRateLimited)
// accept queue filled up, drop the connection
conn.Close()
log.Warn("Accept queue filled. Dropping connection.")
}

View File

@@ -10,6 +10,7 @@ import (
"strings"
"sync"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
@@ -212,7 +213,7 @@ func (l *listener) Close() error {
close(l.queue)
// drain the queue
for conn := range l.queue {
conn.CloseWithError(1, "closing")
conn.CloseWithError(quic.ApplicationErrorCode(network.ConnShutdown), "closing")
}
})
return nil

View File

@@ -132,6 +132,12 @@ func (c *connection) Close() error {
return nil
}
// CloseWithError closes the connection ignoring the error code. As there's no way to signal
// the remote peer on closing the underlying peerconnection, we ignore the error code.
func (c *connection) CloseWithError(_ network.ConnErrorCode) error {
return c.Close()
}
// closeWithError is used to Close the connection when the underlying DTLS connection fails
func (c *connection) closeWithError(err error) {
c.closeOnce.Do(func() {

View File

@@ -95,6 +95,7 @@ type Message struct {
state protoimpl.MessageState `protogen:"open.v1"`
Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"`
Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
ErrorCode *uint32 `protobuf:"varint,3,opt,name=errorCode" json:"errorCode,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -143,24 +144,32 @@ func (x *Message) GetMessage() []byte {
return nil
}
func (x *Message) GetErrorCode() uint32 {
if x != nil && x.ErrorCode != nil {
return *x.ErrorCode
}
return 0
}
var File_p2p_transport_webrtc_pb_message_proto protoreflect.FileDescriptor
var file_p2p_transport_webrtc_pb_message_proto_rawDesc = string([]byte{
0x0a, 0x25, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f,
0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73,
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9f, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28,
0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67,
0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10,
0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e,
0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b,
0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70,
0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74,
0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f,
0x70, 0x62,
0x12, 0x1c, 0x0a, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20,
0x01, 0x28, 0x0d, 0x52, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x39,
0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12,
0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10,
0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07,
0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74,
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67,
0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61,
0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62,
})
var (

View File

@@ -21,4 +21,6 @@ message Message {
optional Flag flag=1;
optional bytes message = 2;
optional uint32 errorCode = 3;
}

View File

@@ -71,6 +71,7 @@ type stream struct {
// But we may need to read from reader for control messages from a different goroutine.
readerMx sync.Mutex
reader pbio.Reader
readError error
// this buffer is limited up to a single message. Reason we need it
// is because a reader might read a message midway, and so we need a
@@ -82,6 +83,7 @@ type stream struct {
writeStateChanged chan struct{}
sendState sendState
writeDeadline time.Time
writeError error
controlMessageReaderOnce sync.Once
// controlMessageReaderEndTime is the end time for reading FIN_ACK from the control
@@ -146,6 +148,10 @@ func (s *stream) Close() error {
}
func (s *stream) Reset() error {
return s.ResetWithError(0)
}
func (s *stream) ResetWithError(errCode network.StreamErrorCode) error {
s.mx.Lock()
isClosed := s.closeForShutdownErr != nil
s.mx.Unlock()
@@ -154,8 +160,8 @@ func (s *stream) Reset() error {
}
defer s.cleanup()
cancelWriteErr := s.cancelWrite()
closeReadErr := s.CloseRead()
cancelWriteErr := s.cancelWrite(errCode)
closeReadErr := s.closeRead(errCode, false)
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
return errors.Join(closeReadErr, cancelWriteErr)
}
@@ -175,19 +181,20 @@ func (s *stream) SetDeadline(t time.Time) error {
return s.SetWriteDeadline(t)
}
// processIncomingFlag process the flag on an incoming message
// processIncomingFlag processes the flag(FIN/RST/etc) on msg.
// It needs to be called while the mutex is locked.
func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if flag == nil {
func (s *stream) processIncomingFlag(msg *pb.Message) {
if msg.Flag == nil {
return
}
switch *flag {
switch msg.GetFlag() {
case pb.Message_STOP_SENDING:
// We must process STOP_SENDING after sending a FIN(sendStateDataSent). Remote peer
// may not send a FIN_ACK once it has sent a STOP_SENDING
if s.sendState == sendStateSending || s.sendState == sendStateDataSent {
s.sendState = sendStateReset
s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
}
s.notifyWriteStateChanged()
case pb.Message_FIN_ACK:
@@ -206,6 +213,11 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
case pb.Message_RESET:
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateReset
s.readError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
}
if s.sendState == sendStateSending || s.sendState == sendStateDataSent {
s.sendState = sendStateReset
s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())}
}
s.spawnControlMessageReader()
}
@@ -235,7 +247,7 @@ func (s *stream) spawnControlMessageReader() {
s.readerMx.Unlock()
if s.nextMessage != nil {
s.processIncomingFlag(s.nextMessage.Flag)
s.processIncomingFlag(s.nextMessage)
s.nextMessage = nil
}
var msg pb.Message
@@ -266,7 +278,7 @@ func (s *stream) spawnControlMessageReader() {
}
return
}
s.processIncomingFlag(msg.Flag)
s.processIncomingFlag(&msg)
}
}()
})

View File

@@ -22,7 +22,7 @@ func (s *stream) Read(b []byte) (int, error) {
case receiveStateDataRead:
return 0, io.EOF
case receiveStateReset:
return 0, network.ErrReset
return 0, s.readError
}
if len(b) == 0 {
@@ -52,10 +52,11 @@ func (s *stream) Read(b []byte) (int, error) {
// datachannel. For these implementations a stream reset will be observed as an
// abrupt closing of the datachannel.
s.receiveState = receiveStateReset
return 0, network.ErrReset
s.readError = &network.StreamError{Remote: true}
return 0, s.readError
}
if s.receiveState == receiveStateReset {
return 0, network.ErrReset
return 0, s.readError
}
if s.receiveState == receiveStateDataRead {
return 0, io.EOF
@@ -73,7 +74,7 @@ func (s *stream) Read(b []byte) (int, error) {
}
// process flags on the message after reading all the data
s.processIncomingFlag(s.nextMessage.Flag)
s.processIncomingFlag(s.nextMessage)
s.nextMessage = nil
if s.closeForShutdownErr != nil {
return read, s.closeForShutdownErr
@@ -82,7 +83,7 @@ func (s *stream) Read(b []byte) (int, error) {
case receiveStateDataRead:
return read, io.EOF
case receiveStateReset:
return read, network.ErrReset
return read, s.readError
}
}
}
@@ -101,12 +102,18 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error {
}
func (s *stream) CloseRead() error {
return s.closeRead(0, false)
}
func (s *stream) closeRead(errCode network.StreamErrorCode, remote bool) error {
s.mx.Lock()
defer s.mx.Unlock()
var err error
if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil {
err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
code := uint32(errCode)
err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum(), ErrorCode: &code})
s.receiveState = receiveStateReset
s.readError = &network.StreamError{Remote: remote, ErrorCode: errCode}
}
s.spawnControlMessageReader()
return err

View File

@@ -24,7 +24,7 @@ func (s *stream) Write(b []byte) (int, error) {
}
switch s.sendState {
case sendStateReset:
return 0, network.ErrReset
return 0, s.writeError
case sendStateDataSent, sendStateDataReceived:
return 0, errWriteAfterClose
}
@@ -48,7 +48,7 @@ func (s *stream) Write(b []byte) (int, error) {
}
switch s.sendState {
case sendStateReset:
return n, network.ErrReset
return n, s.writeError
case sendStateDataSent, sendStateDataReceived:
return n, errWriteAfterClose
}
@@ -119,7 +119,7 @@ func (s *stream) availableSendSpace() int {
return availableSpace
}
func (s *stream) cancelWrite() error {
func (s *stream) cancelWrite(errCode network.StreamErrorCode) error {
s.mx.Lock()
defer s.mx.Unlock()
@@ -129,10 +129,12 @@ func (s *stream) cancelWrite() error {
return nil
}
s.sendState = sendStateReset
s.writeError = &network.StreamError{Remote: false, ErrorCode: errCode}
// Remove reference to this stream from data channel
s.dataChannel.OnBufferedAmountLow(nil)
s.notifyWriteStateChanged()
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()})
code := uint32(errCode)
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum(), ErrorCode: &code})
}
func (s *stream) CloseWrite() error {

View File

@@ -78,6 +78,10 @@ func (c *conn) Close() error {
return err
}
func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
return c.Close()
}
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }

View File

@@ -56,6 +56,17 @@ func (s *stream) Reset() error {
return nil
}
// ResetWithError resets the stream ignoring the error code. Error codes aren't
// specified for WebTransport as the current implementation of WebTransport in
// browsers(https://www.ietf.org/archive/id/draft-kinnear-webtransport-http2-02.html)
// only supports 1 byte error codes. For more details, see
// https://github.com/libp2p/specs/blob/4eca305185c7aef219e936bef76c48b1ab0a8b43/error-codes/README.md?plain=1#L84
func (s *stream) ResetWithError(_ network.StreamErrorCode) error {
s.Stream.CancelRead(reset)
s.Stream.CancelWrite(reset)
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()

View File

@@ -43,7 +43,7 @@ require (
github.com/libp2p/go-nat v0.2.0 // indirect
github.com/libp2p/go-netroute v0.2.2 // indirect
github.com/libp2p/go-reuseport v0.4.0 // indirect
github.com/libp2p/go-yamux/v4 v4.0.2 // indirect
github.com/libp2p/go-yamux/v5 v5.0.0 // indirect
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.63 // indirect

View File

@@ -149,8 +149,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP
github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE=
github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s=
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU=
github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o=
github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po=
github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk=