core: add ErrPeerIDMismatch error type to replace ad-hoc errors (#2451)

* feat: add ErrPeerIDMismatch error type to replace ad-hoc errors

* test: add tests demonstrating the ability to discover a peer's peerID during security negotiation

* noise: add tests for ErrPeerIDMismatch

* tls: add error assertions for ErrPeerIDMismatch

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Adin Schmahmann
2023-08-24 22:44:23 -04:00
committed by GitHub
parent fea268babb
commit 4005fe67fa
6 changed files with 109 additions and 6 deletions

View File

@@ -3,6 +3,7 @@ package sec
import (
"context"
"fmt"
"net"
"github.com/libp2p/go-libp2p/core/network"
@@ -29,3 +30,14 @@ type SecureTransport interface {
// ID is the protocol ID of the security protocol.
ID() protocol.ID
}
type ErrPeerIDMismatch struct {
Expected peer.ID
Actual peer.ID
}
func (e ErrPeerIDMismatch) Error() string {
return fmt.Sprintf("peer id mismatch: expected %s, but remote key matches %s", e.Expected, e.Actual)
}
var _ error = (*ErrPeerIDMismatch)(nil)

View File

@@ -12,6 +12,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/internal/sha256"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
@@ -276,7 +277,7 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
// check the peer ID if enabled
if s.checkPeerID && s.remoteID != id {
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
return nil, sec.ErrPeerIDMismatch{Expected: s.remoteID, Actual: id}
}
// verify payload is signed by asserted remote libp2p key.

View File

@@ -212,7 +212,10 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
initErr := <-errChan
require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error")
require.Contains(t, initErr.Error(), "but remote key matches")
var mismatchErr sec.ErrPeerIDMismatch
require.ErrorAs(t, initErr, &mismatchErr)
require.Equal(t, peer.ID("a-random-peer-id"), mismatchErr.Expected)
require.Equal(t, respTransport.localID, mismatchErr.Actual)
}
func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
@@ -231,6 +234,10 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
_, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id")
require.Error(t, err, "expected responder to fail with peer ID mismatch error")
var mismatchErr sec.ErrPeerIDMismatch
require.ErrorAs(t, err, &mismatchErr)
require.Equal(t, peer.ID("a-random-peer-id"), mismatchErr.Expected)
require.Equal(t, initTransport.localID, mismatchErr.Actual)
<-done
}

View File

@@ -18,6 +18,7 @@ import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
)
const certValidityPeriod = 100 * 365 * 24 * time.Hour // ~100 years
@@ -129,7 +130,7 @@ func (i *Identity) ConfigForPeer(remote peer.ID) (*tls.Config, <-chan ic.PubKey)
if err != nil {
peerID = peer.ID(fmt.Sprintf("(not determined: %s)", err.Error()))
}
return fmt.Errorf("peer IDs don't match: expected %s, got %s", remote, peerID)
return sec.ErrPeerIDMismatch{Expected: remote, Actual: peerID}
}
keyCh <- pubKey
return nil

View File

@@ -376,7 +376,10 @@ func TestPeerIDMismatch(t *testing.T) {
thirdPartyID, _ := createPeer(t)
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID)
require.Error(t, err)
require.Contains(t, err.Error(), "peer IDs don't match")
var mismatchErr sec.ErrPeerIDMismatch
require.ErrorAs(t, err, &mismatchErr)
require.Equal(t, thirdPartyID, mismatchErr.Expected)
require.Equal(t, serverID, mismatchErr.Actual)
var serverErr error
select {
@@ -392,8 +395,8 @@ func TestPeerIDMismatch(t *testing.T) {
clientInsecureConn, serverInsecureConn := connect(t)
errChan := make(chan error)
thirdPartyID, _ := createPeer(t)
go func() {
thirdPartyID, _ := createPeer(t)
// expect the wrong peer ID
_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID)
errChan <- err
@@ -412,7 +415,10 @@ func TestPeerIDMismatch(t *testing.T) {
t.Fatal("expected handshake to return on the server side")
}
require.Error(t, serverErr)
require.Contains(t, serverErr.Error(), "peer IDs don't match")
var mismatchErr sec.ErrPeerIDMismatch
require.ErrorAs(t, serverErr, &mismatchErr)
require.Equal(t, thirdPartyID, mismatchErr.Expected)
require.Equal(t, clientTransport.localPeer, mismatchErr.Actual)
})
}

View File

@@ -20,11 +20,14 @@ import (
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
@@ -607,3 +610,76 @@ func TestStreamReadDeadline(t *testing.T) {
})
}
}
func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) {
// extracts the peerID of the dialed peer from the error
extractPeerIDFromError := func(inputErr error) (peer.ID, error) {
var dialErr *swarm.DialError
if !errors.As(inputErr, &dialErr) {
return "", inputErr
}
innerErr := dialErr.DialErrors[0].Cause
var peerIDMismatchErr sec.ErrPeerIDMismatch
if errors.As(innerErr, &peerIDMismatchErr) {
return peerIDMismatchErr.Actual, nil
}
return "", inputErr
}
// runs a test to verify we can extract the peer ID from a target with just its address
runTest := func(t *testing.T, h host.Host) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Use a bogus peer ID so that when we connect to the target we get an error telling
// us the targets real peer ID
bogusPeerId, err := peer.Decode("QmadAdJ3f63JyNs65X7HHzqDwV53ynvCcKtNFvdNaz3nhk")
if err != nil {
t.Fatal("the hard coded bogus peerID is invalid")
}
ai := &peer.AddrInfo{
ID: bogusPeerId,
Addrs: []multiaddr.Multiaddr{h.Addrs()[0]},
}
testHost, err := libp2p.New()
if err != nil {
t.Fatal(err)
}
// Try connecting with the bogus peer ID
if err := testHost.Connect(ctx, *ai); err != nil {
// Extract the actual peer ID from the error
newPeerId, err := extractPeerIDFromError(err)
if err != nil {
t.Fatal(err)
}
ai.ID = newPeerId
// Make sure the new ID is what we expected
if ai.ID != h.ID() {
t.Fatalf("peerID mismatch: expected %s, got %s", h.ID(), ai.ID)
}
// and just to double-check try connecting again to make sure it works
if err := testHost.Connect(ctx, *ai); err != nil {
t.Fatal(err)
}
} else {
t.Fatal("somehow we successfully connected to a bogus peerID!")
}
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
h := tc.HostGenerator(t, TransportTestCaseOpts{})
defer h.Close()
runTest(t, h)
})
}
}