Add SkipTLSHandshake small interface to CustomDialer

Signed-off-by: Waldemar Quevedo <wally@nats.io>

Co-authored-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
Waldemar Quevedo
2022-11-29 12:29:28 -08:00
parent c9b2fd86f7
commit 65b787031b
4 changed files with 102 additions and 19 deletions

View File

@@ -17,6 +17,7 @@ import (
"context"
"fmt"
"log"
"net"
"time"
"github.com/nats-io/nats.go"
@@ -44,6 +45,40 @@ func ExampleConnect() {
nc.Close()
}
type skipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}
func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}
func (sd *skipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}
func ExampleCustomDialer() {
// Given the following CustomDialer implementation:
//
// type skipTLSDialer struct {
// dialer *net.Dialer
// skipTLS bool
// }
//
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
// return sd.dialer.Dial(network, address)
// }
//
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
// return true
// }
//
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
defer nc.Close()
}
// This Example shows an asynchronous subscriber.
func ExampleConn_Subscribe() {
nc, _ := nats.Connect(nats.DefaultURL)

30
nats.go
View File

@@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
// Option is a function on the options for a connection.
type Option func(*Options) error
// CustomDialer can be used to specify any dialer, not necessarily
// a *net.Dialer.
// CustomDialer can be used to specify any dialer, not necessarily a
// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
// in order to skip the TLS handshake in case not required.
type CustomDialer interface {
Dial(network, address string) (net.Conn, error)
}
@@ -303,10 +304,6 @@ type Options struct {
// transports.
TLSConfig *tls.Config
// SkipTLSWrapper does not upgrade the connection to TLS and is
// meant to be used if the custom dialer does handle TLS itself
SkipTLSWrapper bool
// AllowReconnect enables reconnection logic to be used when we
// encounter a disconnect from the current server.
AllowReconnect bool
@@ -1189,16 +1186,6 @@ func SetCustomDialer(dialer CustomDialer) Option {
}
}
// SetSkipTLSWrapper is an Option to be used with the CustomDialer which
// will not wrap the connection with TLS. Use it if the CustomDialer did
// already handle TLS
func SetSkipTLSWrapper(skip bool) Option {
return func(o *Options) error {
o.SkipTLSWrapper = skip
return nil
}
}
// UseOldRequestStyle is an Option to force usage of the old Request style.
func UseOldRequestStyle() Option {
return func(o *Options) error {
@@ -1906,11 +1893,18 @@ func (nc *Conn) createConn() (err error) {
return nil
}
type skipTLSDialer interface {
SkipTLSHandshake() bool
}
// makeTLSConn will wrap an existing Conn using TLS
func (nc *Conn) makeTLSConn() error {
if nc.Opts.SkipTLSWrapper {
if nc.Opts.CustomDialer != nil {
// we do nothing when asked to skip the TLS wrapper
return nil
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
if ok && sd.SkipTLSHandshake() {
return nil
}
}
// Allow the user to configure their own tls.Config structure.
var tlsCopy *tls.Config

View File

@@ -31,7 +31,7 @@ import (
type (
// Service is an interface for sevice management.
// Service is an interface for service management.
// It exposes methods to stop/reset a service, as well as get information on a service.
Service interface {
ID() string

View File

@@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
}
}
type testSkipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}
func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}
func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}
func TestWSWithTLSCustomDialer(t *testing.T) {
sopts := testWSGetDefaultOptions(t, true)
s := RunServerWithOptions(sopts)
defer s.Shutdown()
sd := &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: true,
}
// Connect with CustomDialer that fails since TLSHandshake is disabled.
copts := make([]Option, 0)
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err == nil {
t.Fatalf("Expected error on connect: %v", err)
}
if err.Error() != `invalid websocket connection` {
t.Logf("Expected invalid websocket connection: %v", err)
}
// Retry with the dialer.
copts = make([]Option, 0)
sd = &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: false,
}
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()
}
func TestWSTlsNoConfig(t *testing.T) {
opts := GetDefaultOptions()
opts.Servers = []string{"wss://localhost:443"}