mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00
core: add ErrPeerIDMismatch error type to replace ad-hoc errors (#2451)
* feat: add ErrPeerIDMismatch error type to replace ad-hoc errors * test: add tests demonstrating the ability to discover a peer's peerID during security negotiation * noise: add tests for ErrPeerIDMismatch * tls: add error assertions for ErrPeerIDMismatch --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package sec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
@@ -29,3 +30,14 @@ type SecureTransport interface {
|
||||
// ID is the protocol ID of the security protocol.
|
||||
ID() protocol.ID
|
||||
}
|
||||
|
||||
type ErrPeerIDMismatch struct {
|
||||
Expected peer.ID
|
||||
Actual peer.ID
|
||||
}
|
||||
|
||||
func (e ErrPeerIDMismatch) Error() string {
|
||||
return fmt.Sprintf("peer id mismatch: expected %s, but remote key matches %s", e.Expected, e.Actual)
|
||||
}
|
||||
|
||||
var _ error = (*ErrPeerIDMismatch)(nil)
|
||||
|
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/libp2p/go-libp2p/internal/sha256"
|
||||
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
|
||||
|
||||
@@ -276,7 +277,7 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
|
||||
|
||||
// check the peer ID if enabled
|
||||
if s.checkPeerID && s.remoteID != id {
|
||||
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
|
||||
return nil, sec.ErrPeerIDMismatch{Expected: s.remoteID, Actual: id}
|
||||
}
|
||||
|
||||
// verify payload is signed by asserted remote libp2p key.
|
||||
|
@@ -212,7 +212,10 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
|
||||
|
||||
initErr := <-errChan
|
||||
require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error")
|
||||
require.Contains(t, initErr.Error(), "but remote key matches")
|
||||
var mismatchErr sec.ErrPeerIDMismatch
|
||||
require.ErrorAs(t, initErr, &mismatchErr)
|
||||
require.Equal(t, peer.ID("a-random-peer-id"), mismatchErr.Expected)
|
||||
require.Equal(t, respTransport.localID, mismatchErr.Actual)
|
||||
}
|
||||
|
||||
func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
|
||||
@@ -231,6 +234,10 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
|
||||
|
||||
_, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id")
|
||||
require.Error(t, err, "expected responder to fail with peer ID mismatch error")
|
||||
var mismatchErr sec.ErrPeerIDMismatch
|
||||
require.ErrorAs(t, err, &mismatchErr)
|
||||
require.Equal(t, peer.ID("a-random-peer-id"), mismatchErr.Expected)
|
||||
require.Equal(t, initTransport.localID, mismatchErr.Actual)
|
||||
<-done
|
||||
}
|
||||
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
ic "github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
)
|
||||
|
||||
const certValidityPeriod = 100 * 365 * 24 * time.Hour // ~100 years
|
||||
@@ -129,7 +130,7 @@ func (i *Identity) ConfigForPeer(remote peer.ID) (*tls.Config, <-chan ic.PubKey)
|
||||
if err != nil {
|
||||
peerID = peer.ID(fmt.Sprintf("(not determined: %s)", err.Error()))
|
||||
}
|
||||
return fmt.Errorf("peer IDs don't match: expected %s, got %s", remote, peerID)
|
||||
return sec.ErrPeerIDMismatch{Expected: remote, Actual: peerID}
|
||||
}
|
||||
keyCh <- pubKey
|
||||
return nil
|
||||
|
@@ -376,7 +376,10 @@ func TestPeerIDMismatch(t *testing.T) {
|
||||
thirdPartyID, _ := createPeer(t)
|
||||
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "peer IDs don't match")
|
||||
var mismatchErr sec.ErrPeerIDMismatch
|
||||
require.ErrorAs(t, err, &mismatchErr)
|
||||
require.Equal(t, thirdPartyID, mismatchErr.Expected)
|
||||
require.Equal(t, serverID, mismatchErr.Actual)
|
||||
|
||||
var serverErr error
|
||||
select {
|
||||
@@ -392,8 +395,8 @@ func TestPeerIDMismatch(t *testing.T) {
|
||||
clientInsecureConn, serverInsecureConn := connect(t)
|
||||
|
||||
errChan := make(chan error)
|
||||
thirdPartyID, _ := createPeer(t)
|
||||
go func() {
|
||||
thirdPartyID, _ := createPeer(t)
|
||||
// expect the wrong peer ID
|
||||
_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID)
|
||||
errChan <- err
|
||||
@@ -412,7 +415,10 @@ func TestPeerIDMismatch(t *testing.T) {
|
||||
t.Fatal("expected handshake to return on the server side")
|
||||
}
|
||||
require.Error(t, serverErr)
|
||||
require.Contains(t, serverErr.Error(), "peer IDs don't match")
|
||||
var mismatchErr sec.ErrPeerIDMismatch
|
||||
require.ErrorAs(t, serverErr, &mismatchErr)
|
||||
require.Equal(t, thirdPartyID, mismatchErr.Expected)
|
||||
require.Equal(t, clientTransport.localPeer, mismatchErr.Actual)
|
||||
})
|
||||
}
|
||||
|
||||
|
@@ -20,11 +20,14 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
|
||||
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
|
||||
"github.com/libp2p/go-libp2p/p2p/net/swarm"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
|
||||
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -607,3 +610,76 @@ func TestStreamReadDeadline(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) {
|
||||
// extracts the peerID of the dialed peer from the error
|
||||
extractPeerIDFromError := func(inputErr error) (peer.ID, error) {
|
||||
var dialErr *swarm.DialError
|
||||
if !errors.As(inputErr, &dialErr) {
|
||||
return "", inputErr
|
||||
}
|
||||
innerErr := dialErr.DialErrors[0].Cause
|
||||
|
||||
var peerIDMismatchErr sec.ErrPeerIDMismatch
|
||||
if errors.As(innerErr, &peerIDMismatchErr) {
|
||||
return peerIDMismatchErr.Actual, nil
|
||||
}
|
||||
|
||||
return "", inputErr
|
||||
}
|
||||
|
||||
// runs a test to verify we can extract the peer ID from a target with just its address
|
||||
runTest := func(t *testing.T, h host.Host) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Use a bogus peer ID so that when we connect to the target we get an error telling
|
||||
// us the targets real peer ID
|
||||
bogusPeerId, err := peer.Decode("QmadAdJ3f63JyNs65X7HHzqDwV53ynvCcKtNFvdNaz3nhk")
|
||||
if err != nil {
|
||||
t.Fatal("the hard coded bogus peerID is invalid")
|
||||
}
|
||||
|
||||
ai := &peer.AddrInfo{
|
||||
ID: bogusPeerId,
|
||||
Addrs: []multiaddr.Multiaddr{h.Addrs()[0]},
|
||||
}
|
||||
|
||||
testHost, err := libp2p.New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Try connecting with the bogus peer ID
|
||||
if err := testHost.Connect(ctx, *ai); err != nil {
|
||||
// Extract the actual peer ID from the error
|
||||
newPeerId, err := extractPeerIDFromError(err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ai.ID = newPeerId
|
||||
|
||||
// Make sure the new ID is what we expected
|
||||
if ai.ID != h.ID() {
|
||||
t.Fatalf("peerID mismatch: expected %s, got %s", h.ID(), ai.ID)
|
||||
}
|
||||
|
||||
// and just to double-check try connecting again to make sure it works
|
||||
if err := testHost.Connect(ctx, *ai); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal("somehow we successfully connected to a bogus peerID!")
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range transportsToTest {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
h := tc.HostGenerator(t, TransportTestCaseOpts{})
|
||||
defer h.Close()
|
||||
|
||||
runTest(t, h)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user