Add DTLS Handshake hooks to SettingEngine

This commit is contained in:
theodorsm
2024-08-17 14:10:14 +02:00
committed by Sean DuBois
parent 4a97b7d67e
commit 64a837f688
3 changed files with 67 additions and 11 deletions

View File

@@ -355,6 +355,9 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook
dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook
dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook
// Connect as DTLS Client/Server, function is blocking and we // Connect as DTLS Client/Server, function is blocking and we
// must not hold the DTLSTransport lock // must not hold the DTLSTransport lock

View File

@@ -15,6 +15,7 @@ import (
"github.com/pion/dtls/v3" "github.com/pion/dtls/v3"
dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic" dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
@@ -63,17 +64,20 @@ type SettingEngine struct {
SRTCP *uint SRTCP *uint
} }
dtls struct { dtls struct {
insecureSkipHelloVerify bool insecureSkipHelloVerify bool
disableInsecureSkipVerify bool disableInsecureSkipVerify bool
retransmissionInterval time.Duration retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func()) connectContextMaker func() (context.Context, func())
extendedMasterSecret dtls.ExtendedMasterSecretType extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool clientCAs *x509.CertPool
rootCAs *x509.CertPool rootCAs *x509.CertPool
keyLogWriter io.Writer keyLogWriter io.Writer
customCipherSuites func() []dtls.CipherSuite customCipherSuites func() []dtls.CipherSuite
clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message
serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message
certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message
} }
sctp struct { sctp struct {
maxReceiveBufferSize uint32 maxReceiveBufferSize uint32
@@ -455,6 +459,24 @@ func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() []
e.dtls.customCipherSuites = customCipherSuites e.dtls.customCipherSuites = customCipherSuites
} }
// SetDTLSClientHelloMessageHook if not nil, is called when a DTLS Client Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSClientHelloMessageHook(hook func(handshake.MessageClientHello) handshake.Message) {
e.dtls.clientHelloMessageHook = hook
}
// SetDTLSServerHelloMessageHook if not nil, is called when a DTLS Server Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSServerHelloMessageHook(hook func(handshake.MessageServerHello) handshake.Message) {
e.dtls.serverHelloMessageHook = hook
}
// SetDTLSCertificateRequestMessageHook if not nil, is called when a DTLS Certificate Request message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSCertificateRequestMessageHook(hook func(handshake.MessageCertificateRequest) handshake.Message) {
e.dtls.certificateRequestMessageHook = hook
}
// SetSCTPRTOMax sets the maximum retransmission timeout. // SetSCTPRTOMax sets the maximum retransmission timeout.
// Leave this 0 for the default timeout. // Leave this 0 for the default timeout.
func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) { func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) {

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4" "github.com/pion/ice/v4"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
@@ -309,3 +310,33 @@ func TestSetICEBindingRequestHandler(t *testing.T) {
<-seenICEControlling.Done() <-seenICEControlling.Done()
closePairNow(t, pcOffer, pcAnswer) closePairNow(t, pcOffer, pcAnswer)
} }
func TestSetHooks(t *testing.T) {
s := SettingEngine{}
if s.dtls.clientHelloMessageHook != nil ||
s.dtls.serverHelloMessageHook != nil ||
s.dtls.certificateRequestMessageHook != nil {
t.Fatalf("SettingEngine defaults aren't as expected.")
}
s.SetDTLSClientHelloMessageHook(func(msg handshake.MessageClientHello) handshake.Message {
return &msg
})
s.SetDTLSServerHelloMessageHook(func(msg handshake.MessageServerHello) handshake.Message {
return &msg
})
s.SetDTLSCertificateRequestMessageHook(func(msg handshake.MessageCertificateRequest) handshake.Message {
return &msg
})
if s.dtls.clientHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Client Hello Hook")
}
if s.dtls.serverHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Server Hello Hook")
}
if s.dtls.certificateRequestMessageHook == nil {
t.Errorf("Failed to set DTLS Certificate Request Hook")
}
}