Add SkipTLSHandshake small interface to CustomDialer

Signed-off-by: Waldemar Quevedo <wally@nats.io>

Co-authored-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
Waldemar Quevedo
2022-11-29 12:29:28 -08:00
parent c9b2fd86f7
commit 65b787031b
4 changed files with 102 additions and 19 deletions

View File

@@ -17,6 +17,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net"
"time" "time"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@@ -44,6 +45,40 @@ func ExampleConnect() {
nc.Close() nc.Close()
} }
type skipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}
func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}
func (sd *skipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}
func ExampleCustomDialer() {
// Given the following CustomDialer implementation:
//
// type skipTLSDialer struct {
// dialer *net.Dialer
// skipTLS bool
// }
//
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
// return sd.dialer.Dial(network, address)
// }
//
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
// return true
// }
//
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
defer nc.Close()
}
// This Example shows an asynchronous subscriber. // This Example shows an asynchronous subscriber.
func ExampleConn_Subscribe() { func ExampleConn_Subscribe() {
nc, _ := nats.Connect(nats.DefaultURL) nc, _ := nats.Connect(nats.DefaultURL)

28
nats.go
View File

@@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
// Option is a function on the options for a connection. // Option is a function on the options for a connection.
type Option func(*Options) error type Option func(*Options) error
// CustomDialer can be used to specify any dialer, not necessarily // CustomDialer can be used to specify any dialer, not necessarily a
// a *net.Dialer. // *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
// in order to skip the TLS handshake in case not required.
type CustomDialer interface { type CustomDialer interface {
Dial(network, address string) (net.Conn, error) Dial(network, address string) (net.Conn, error)
} }
@@ -303,10 +304,6 @@ type Options struct {
// transports. // transports.
TLSConfig *tls.Config TLSConfig *tls.Config
// SkipTLSWrapper does not upgrade the connection to TLS and is
// meant to be used if the custom dialer does handle TLS itself
SkipTLSWrapper bool
// AllowReconnect enables reconnection logic to be used when we // AllowReconnect enables reconnection logic to be used when we
// encounter a disconnect from the current server. // encounter a disconnect from the current server.
AllowReconnect bool AllowReconnect bool
@@ -1189,16 +1186,6 @@ func SetCustomDialer(dialer CustomDialer) Option {
} }
} }
// SetSkipTLSWrapper is an Option to be used with the CustomDialer which
// will not wrap the connection with TLS. Use it if the CustomDialer did
// already handle TLS
func SetSkipTLSWrapper(skip bool) Option {
return func(o *Options) error {
o.SkipTLSWrapper = skip
return nil
}
}
// UseOldRequestStyle is an Option to force usage of the old Request style. // UseOldRequestStyle is an Option to force usage of the old Request style.
func UseOldRequestStyle() Option { func UseOldRequestStyle() Option {
return func(o *Options) error { return func(o *Options) error {
@@ -1906,12 +1893,19 @@ func (nc *Conn) createConn() (err error) {
return nil return nil
} }
type skipTLSDialer interface {
SkipTLSHandshake() bool
}
// makeTLSConn will wrap an existing Conn using TLS // makeTLSConn will wrap an existing Conn using TLS
func (nc *Conn) makeTLSConn() error { func (nc *Conn) makeTLSConn() error {
if nc.Opts.SkipTLSWrapper { if nc.Opts.CustomDialer != nil {
// we do nothing when asked to skip the TLS wrapper // we do nothing when asked to skip the TLS wrapper
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
if ok && sd.SkipTLSHandshake() {
return nil return nil
} }
}
// Allow the user to configure their own tls.Config structure. // Allow the user to configure their own tls.Config structure.
var tlsCopy *tls.Config var tlsCopy *tls.Config
if nc.Opts.TLSConfig != nil { if nc.Opts.TLSConfig != nil {

View File

@@ -31,7 +31,7 @@ import (
type ( type (
// Service is an interface for sevice management. // Service is an interface for service management.
// It exposes methods to stop/reset a service, as well as get information on a service. // It exposes methods to stop/reset a service, as well as get information on a service.
Service interface { Service interface {
ID() string ID() string

View File

@@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
} }
} }
type testSkipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}
func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}
func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}
func TestWSWithTLSCustomDialer(t *testing.T) {
sopts := testWSGetDefaultOptions(t, true)
s := RunServerWithOptions(sopts)
defer s.Shutdown()
sd := &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: true,
}
// Connect with CustomDialer that fails since TLSHandshake is disabled.
copts := make([]Option, 0)
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err == nil {
t.Fatalf("Expected error on connect: %v", err)
}
if err.Error() != `invalid websocket connection` {
t.Logf("Expected invalid websocket connection: %v", err)
}
// Retry with the dialer.
copts = make([]Option, 0)
sd = &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: false,
}
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()
}
func TestWSTlsNoConfig(t *testing.T) { func TestWSTlsNoConfig(t *testing.T) {
opts := GetDefaultOptions() opts := GetDefaultOptions()
opts.Servers = []string{"wss://localhost:443"} opts.Servers = []string{"wss://localhost:443"}