mirror of
https://github.com/nats-io/nats.go.git
synced 2025-09-27 04:46:01 +08:00
Add SkipTLSHandshake small interface to CustomDialer
Signed-off-by: Waldemar Quevedo <wally@nats.io> Co-authored-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
@@ -44,6 +45,40 @@ func ExampleConnect() {
|
||||
nc.Close()
|
||||
}
|
||||
|
||||
type skipTLSDialer struct {
|
||||
dialer *net.Dialer
|
||||
skipTLS bool
|
||||
}
|
||||
|
||||
func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
|
||||
return sd.dialer.Dial(network, address)
|
||||
}
|
||||
|
||||
func (sd *skipTLSDialer) SkipTLSHandshake() bool {
|
||||
return sd.skipTLS
|
||||
}
|
||||
|
||||
func ExampleCustomDialer() {
|
||||
// Given the following CustomDialer implementation:
|
||||
//
|
||||
// type skipTLSDialer struct {
|
||||
// dialer *net.Dialer
|
||||
// skipTLS bool
|
||||
// }
|
||||
//
|
||||
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
|
||||
// return sd.dialer.Dial(network, address)
|
||||
// }
|
||||
//
|
||||
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
|
||||
// return true
|
||||
// }
|
||||
//
|
||||
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
|
||||
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
|
||||
defer nc.Close()
|
||||
}
|
||||
|
||||
// This Example shows an asynchronous subscriber.
|
||||
func ExampleConn_Subscribe() {
|
||||
nc, _ := nats.Connect(nats.DefaultURL)
|
||||
|
30
nats.go
30
nats.go
@@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
|
||||
// Option is a function on the options for a connection.
|
||||
type Option func(*Options) error
|
||||
|
||||
// CustomDialer can be used to specify any dialer, not necessarily
|
||||
// a *net.Dialer.
|
||||
// CustomDialer can be used to specify any dialer, not necessarily a
|
||||
// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
|
||||
// in order to skip the TLS handshake in case not required.
|
||||
type CustomDialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
}
|
||||
@@ -303,10 +304,6 @@ type Options struct {
|
||||
// transports.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// SkipTLSWrapper does not upgrade the connection to TLS and is
|
||||
// meant to be used if the custom dialer does handle TLS itself
|
||||
SkipTLSWrapper bool
|
||||
|
||||
// AllowReconnect enables reconnection logic to be used when we
|
||||
// encounter a disconnect from the current server.
|
||||
AllowReconnect bool
|
||||
@@ -1189,16 +1186,6 @@ func SetCustomDialer(dialer CustomDialer) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// SetSkipTLSWrapper is an Option to be used with the CustomDialer which
|
||||
// will not wrap the connection with TLS. Use it if the CustomDialer did
|
||||
// already handle TLS
|
||||
func SetSkipTLSWrapper(skip bool) Option {
|
||||
return func(o *Options) error {
|
||||
o.SkipTLSWrapper = skip
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// UseOldRequestStyle is an Option to force usage of the old Request style.
|
||||
func UseOldRequestStyle() Option {
|
||||
return func(o *Options) error {
|
||||
@@ -1906,11 +1893,18 @@ func (nc *Conn) createConn() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
type skipTLSDialer interface {
|
||||
SkipTLSHandshake() bool
|
||||
}
|
||||
|
||||
// makeTLSConn will wrap an existing Conn using TLS
|
||||
func (nc *Conn) makeTLSConn() error {
|
||||
if nc.Opts.SkipTLSWrapper {
|
||||
if nc.Opts.CustomDialer != nil {
|
||||
// we do nothing when asked to skip the TLS wrapper
|
||||
return nil
|
||||
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
|
||||
if ok && sd.SkipTLSHandshake() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Allow the user to configure their own tls.Config structure.
|
||||
var tlsCopy *tls.Config
|
||||
|
@@ -31,7 +31,7 @@ import (
|
||||
|
||||
type (
|
||||
|
||||
// Service is an interface for sevice management.
|
||||
// Service is an interface for service management.
|
||||
// It exposes methods to stop/reset a service, as well as get information on a service.
|
||||
Service interface {
|
||||
ID() string
|
||||
|
54
ws_test.go
54
ws_test.go
@@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type testSkipTLSDialer struct {
|
||||
dialer *net.Dialer
|
||||
skipTLS bool
|
||||
}
|
||||
|
||||
func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
|
||||
return sd.dialer.Dial(network, address)
|
||||
}
|
||||
|
||||
func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
|
||||
return sd.skipTLS
|
||||
}
|
||||
|
||||
func TestWSWithTLSCustomDialer(t *testing.T) {
|
||||
sopts := testWSGetDefaultOptions(t, true)
|
||||
s := RunServerWithOptions(sopts)
|
||||
defer s.Shutdown()
|
||||
|
||||
sd := &testSkipTLSDialer{
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
},
|
||||
skipTLS: true,
|
||||
}
|
||||
|
||||
// Connect with CustomDialer that fails since TLSHandshake is disabled.
|
||||
copts := make([]Option, 0)
|
||||
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
|
||||
copts = append(copts, SetCustomDialer(sd))
|
||||
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error on connect: %v", err)
|
||||
}
|
||||
if err.Error() != `invalid websocket connection` {
|
||||
t.Logf("Expected invalid websocket connection: %v", err)
|
||||
}
|
||||
|
||||
// Retry with the dialer.
|
||||
copts = make([]Option, 0)
|
||||
sd = &testSkipTLSDialer{
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
},
|
||||
skipTLS: false,
|
||||
}
|
||||
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
|
||||
copts = append(copts, SetCustomDialer(sd))
|
||||
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on connect: %v", err)
|
||||
}
|
||||
defer nc.Close()
|
||||
}
|
||||
|
||||
func TestWSTlsNoConfig(t *testing.T) {
|
||||
opts := GetDefaultOptions()
|
||||
opts.Servers = []string{"wss://localhost:443"}
|
||||
|
Reference in New Issue
Block a user