mirror of
https://github.com/nats-io/nats.go.git
synced 2025-12-24 13:08:06 +08:00
Add options to send custom WebSocket headers on connect (#1919)
Example:
```go
headers := make(http.Header)
headers.Set("X-Client-ID", "go-example-client")
headers.Set("X-Custom-Header", "static-value")
// Connect to NATS server via WebSocket with custom headers
nc, err := nats.Connect("ws://localhost:8080",
nats.WebSocketConnectionHeaders(headers),
)
```
Co-authored-by: Waldemar Quevedo <wally@nats.io>
This commit is contained in:
committed by
GitHub
parent
1610417578
commit
2ab8185635
40
nats.go
40
nats.go
@@ -151,6 +151,7 @@ var (
|
||||
ErrMaxAccountConnectionsExceeded = errors.New("nats: maximum account active connections exceeded")
|
||||
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
|
||||
ErrMaxSubscriptionsExceeded = errors.New("nats: server maximum subscriptions exceeded")
|
||||
ErrWebSocketHeadersAlreadySet = errors.New("nats: websocket connection headers already set")
|
||||
)
|
||||
|
||||
// GetDefaultOptions returns default configuration options for the client.
|
||||
@@ -250,6 +251,9 @@ type UserInfoCB func() (string, string)
|
||||
// whole list of URLs and failed to reconnect.
|
||||
type ReconnectDelayHandler func(attempts int) time.Duration
|
||||
|
||||
// WebSocketHeadersHandler is an optional callback handler for generating token used for WebSocket connections.
|
||||
type WebSocketHeadersHandler func() (http.Header, error)
|
||||
|
||||
// asyncCB is used to preserve order for async callbacks.
|
||||
type asyncCB struct {
|
||||
f func()
|
||||
@@ -524,6 +528,12 @@ type Options struct {
|
||||
// from SubscribeSync if the server returns a permissions error for a subscription.
|
||||
// Defaults to false.
|
||||
PermissionErrOnSubscribe bool
|
||||
|
||||
// WebSocketConnectionHeaders is an optional http request headers to be sent with the WebSocket request.
|
||||
WebSocketConnectionHeaders http.Header
|
||||
|
||||
// WebSocketConnectionHeadersHandler is an optional callback handler for generating token used for WebSocket connections.
|
||||
WebSocketConnectionHeadersHandler WebSocketHeadersHandler
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -1472,6 +1482,36 @@ func TLSHandshakeFirst() Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocketConnectionHeaders sets a fixed set of HTTP headers that will be
|
||||
// sent during the WebSocket connection handshake.
|
||||
// This option is mutually exclusive with WebSocketConnectionHeadersHandler;
|
||||
// if a headers handler has already been configured, it returns
|
||||
// ErrWebSocketHeadersAlreadySet.
|
||||
func WebSocketConnectionHeaders(headers http.Header) Option {
|
||||
return func(o *Options) error {
|
||||
if o.WebSocketConnectionHeadersHandler != nil {
|
||||
return ErrWebSocketHeadersAlreadySet
|
||||
}
|
||||
o.WebSocketConnectionHeaders = headers
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocketConnectionHeadersHandler registers a callback used to supply HTTP
|
||||
// headers for the WebSocket connection handshake.
|
||||
// This option is mutually exclusive with WebSocketConnectionHeaders; if
|
||||
// non-empty static headers have already been configured, it returns
|
||||
// ErrWebSocketHeadersAlreadySet.
|
||||
func WebSocketConnectionHeadersHandler(cb WebSocketHeadersHandler) Option {
|
||||
return func(o *Options) error {
|
||||
if len(o.WebSocketConnectionHeaders) != 0 {
|
||||
return ErrWebSocketHeadersAlreadySet
|
||||
}
|
||||
o.WebSocketConnectionHeadersHandler = cb
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handler processing
|
||||
|
||||
// SetDisconnectHandler will set the disconnect event handler.
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -611,3 +612,88 @@ func TestWSNoDeadlockOnAuthFailure(t *testing.T) {
|
||||
|
||||
tm.Stop()
|
||||
}
|
||||
|
||||
func TestWsWithCustomHeaders(t *testing.T) {
|
||||
sopts := testWSGetDefaultOptions(t, false)
|
||||
s := RunServerWithOptions(sopts)
|
||||
defer s.Shutdown()
|
||||
|
||||
staticHeader := make(http.Header, 0)
|
||||
staticHeader.Set("Authorization", "Bearer Random Token")
|
||||
headerProvider := func() (http.Header, error) {
|
||||
return staticHeader, nil
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
connectionOptions []nats.Option
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Failure 1: Both headers and handler present",
|
||||
connectionOptions: []nats.Option{
|
||||
nats.WebSocketConnectionHeadersHandler(headerProvider),
|
||||
nats.WebSocketConnectionHeaders(staticHeader),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Success 1: Headers present as static headers",
|
||||
connectionOptions: []nats.Option{
|
||||
nats.WebSocketConnectionHeaders(staticHeader),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Success 2: Header supplied through handler",
|
||||
connectionOptions: []nats.Option{
|
||||
nats.WebSocketConnectionHeadersHandler(headerProvider),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port)
|
||||
nc, err := nats.Connect(url, test.connectionOptions...)
|
||||
if err != nil && test.wantErr {
|
||||
return
|
||||
}
|
||||
if err != nil && !test.wantErr {
|
||||
t.Fatalf("Did not expect error, found error: %v", err)
|
||||
}
|
||||
defer nc.Close()
|
||||
sub, err := nc.SubscribeSync("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Error on subscribe: %v", err)
|
||||
}
|
||||
|
||||
msgs := make([][]byte, 100)
|
||||
for i := 0; i < len(msgs); i++ {
|
||||
msg := make([]byte, 100)
|
||||
for j := 0; j < len(msg); j++ {
|
||||
msg[j] = 'A'
|
||||
}
|
||||
msgs[i] = msg
|
||||
}
|
||||
for i, msg := range msgs {
|
||||
if err := nc.Publish("foo", msg); err != nil {
|
||||
t.Fatalf("Error on publish: %v", err)
|
||||
}
|
||||
// Make sure that compression/masking does not touch user data
|
||||
if !bytes.Equal(msgs[i], msg) {
|
||||
t.Fatalf("User content has been changed: %v, got %v", msgs[i], msg)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(msgs); i++ {
|
||||
msg, err := sub.NextMsg(time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting next message (%d): %v", i+1, err)
|
||||
}
|
||||
if !bytes.Equal(msgs[i], msg.Data) {
|
||||
t.Fatalf("Expected message (%d): %v, got %v", i+1, msgs[i], msg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
22
ws.go
22
ws.go
@@ -610,6 +610,9 @@ func (nc *Conn) wsInitHandshake(u *url.URL) error {
|
||||
if compress {
|
||||
req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue)
|
||||
}
|
||||
if err := nc.wsUpdateConnectionHeaders(req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := req.Write(nc.conn); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -728,6 +731,25 @@ func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload
|
||||
nc.bw.flush()
|
||||
}
|
||||
|
||||
func (nc *Conn) wsUpdateConnectionHeaders(req *http.Request) error {
|
||||
var headers http.Header
|
||||
var err error
|
||||
if nc.Opts.WebSocketConnectionHeadersHandler != nil {
|
||||
headers, err = nc.Opts.WebSocketConnectionHeadersHandler()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
headers = nc.Opts.WebSocketConnectionHeaders
|
||||
}
|
||||
for key, values := range headers {
|
||||
for _, val := range values {
|
||||
req.Header.Add(key, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wsPMCExtensionSupport(header http.Header) (bool, bool) {
|
||||
for _, extensionList := range header["Sec-Websocket-Extensions"] {
|
||||
extensions := strings.Split(extensionList, ",")
|
||||
|
||||
121
ws_test.go
121
ws_test.go
@@ -14,6 +14,7 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
@@ -608,3 +609,123 @@ func TestWSProxyPath(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func startHeaderCatcher(t *testing.T) (addr string, got chan []string, closer func()) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
got = make(chan []string, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
// surface nothing; test will timeout
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
r := bufio.NewReader(conn)
|
||||
var lines []string
|
||||
for {
|
||||
s, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
s = strings.TrimRight(s, "\r\n")
|
||||
if s == "" { // end of HTTP headers
|
||||
break
|
||||
}
|
||||
lines = append(lines, s)
|
||||
}
|
||||
got <- lines
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), got, func() { _ = ln.Close() }
|
||||
}
|
||||
|
||||
func hasHeaderValue(headers []string, name, want string) bool {
|
||||
prefix := strings.ToLower(name) + ":"
|
||||
for _, h := range headers {
|
||||
if !strings.HasPrefix(strings.ToLower(h), prefix) {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(strings.SplitN(h, ":", 2)[1])
|
||||
for _, part := range strings.Split(val, ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(part), want) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestWSHeaders_StaticAppliedOnHandshake(t *testing.T) {
|
||||
addr, got, closeLn := startHeaderCatcher(t)
|
||||
defer closeLn()
|
||||
|
||||
static := make(http.Header)
|
||||
static.Set("Authorization", "Bearer Random Token")
|
||||
static.Add("X-Multi", "v1")
|
||||
static.Add("X-Multi", "v2")
|
||||
|
||||
// Intentionally connect to our fake server; it won't complete the upgrade.
|
||||
opts := GetDefaultOptions()
|
||||
opts.WebSocketConnectionHeaders = static
|
||||
opts.Url = "ws://" + addr
|
||||
_, err := opts.Connect()
|
||||
if err == nil {
|
||||
t.Fatalf("expected connect to fail because server does not reply")
|
||||
}
|
||||
|
||||
var headers []string
|
||||
select {
|
||||
case headers = <-got:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not capture headers in time")
|
||||
}
|
||||
|
||||
if !hasHeaderValue(headers, "Authorization", "Bearer Random Token") {
|
||||
t.Fatalf("Authorization header missing: %v", headers)
|
||||
}
|
||||
if !hasHeaderValue(headers, "X-Multi", "v1") || !hasHeaderValue(headers, "X-Multi", "v2") {
|
||||
t.Fatalf("X-Multi headers missing/combined incorrectly: %v", headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSHeaders_HandlerAppliedOnHandshake(t *testing.T) {
|
||||
addr, got, closeLn := startHeaderCatcher(t)
|
||||
defer closeLn()
|
||||
|
||||
provider := func() (http.Header, error) {
|
||||
h := make(http.Header)
|
||||
h.Set("Authorization", "Bearer FromHandler")
|
||||
h.Add("X-Multi", "h1")
|
||||
h.Add("X-Multi", "h2")
|
||||
return h, nil
|
||||
}
|
||||
|
||||
opts := GetDefaultOptions()
|
||||
opts.WebSocketConnectionHeadersHandler = provider
|
||||
opts.Url = "ws://" + addr
|
||||
_, err := opts.Connect()
|
||||
if err == nil {
|
||||
t.Fatalf("expected connect to fail because server does not reply")
|
||||
}
|
||||
|
||||
var headers []string
|
||||
select {
|
||||
case headers = <-got:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not capture headers in time")
|
||||
}
|
||||
|
||||
if !hasHeaderValue(headers, "Authorization", "Bearer FromHandler") {
|
||||
t.Fatalf("Authorization header missing: %v", headers)
|
||||
}
|
||||
if !hasHeaderValue(headers, "X-Multi", "h1") || !hasHeaderValue(headers, "X-Multi", "h2") {
|
||||
t.Fatalf("X-Multi headers missing/combined incorrectly: %v", headers)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user