Add options to send custom WebSocket headers on connect (#1919)

Example:

```go
	headers := make(http.Header)
	headers.Set("X-Client-ID", "go-example-client")
	headers.Set("X-Custom-Header", "static-value")

	// Connect to NATS server via WebSocket with custom headers
	nc, err := nats.Connect("ws://localhost:8080",
		nats.WebSocketConnectionHeaders(headers),
	)
```

Co-authored-by: Waldemar Quevedo <wally@nats.io>
This commit is contained in:
Saurabh Kumar Ojha
2025-10-10 11:28:30 +05:30
committed by GitHub
parent 1610417578
commit 2ab8185635
4 changed files with 269 additions and 0 deletions

40
nats.go
View File

@@ -151,6 +151,7 @@ var (
ErrMaxAccountConnectionsExceeded = errors.New("nats: maximum account active connections exceeded")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
ErrMaxSubscriptionsExceeded = errors.New("nats: server maximum subscriptions exceeded")
ErrWebSocketHeadersAlreadySet = errors.New("nats: websocket connection headers already set")
)
// GetDefaultOptions returns default configuration options for the client.
@@ -250,6 +251,9 @@ type UserInfoCB func() (string, string)
// whole list of URLs and failed to reconnect.
type ReconnectDelayHandler func(attempts int) time.Duration
// WebSocketHeadersHandler is an optional callback handler for generating token used for WebSocket connections.
type WebSocketHeadersHandler func() (http.Header, error)
// asyncCB is used to preserve order for async callbacks.
type asyncCB struct {
f func()
@@ -524,6 +528,12 @@ type Options struct {
// from SubscribeSync if the server returns a permissions error for a subscription.
// Defaults to false.
PermissionErrOnSubscribe bool
// WebSocketConnectionHeaders is an optional http request headers to be sent with the WebSocket request.
WebSocketConnectionHeaders http.Header
// WebSocketConnectionHeadersHandler is an optional callback handler for generating token used for WebSocket connections.
WebSocketConnectionHeadersHandler WebSocketHeadersHandler
}
const (
@@ -1472,6 +1482,36 @@ func TLSHandshakeFirst() Option {
}
}
// WebSocketConnectionHeaders sets a fixed set of HTTP headers that will be
// sent during the WebSocket connection handshake.
// This option is mutually exclusive with WebSocketConnectionHeadersHandler;
// if a headers handler has already been configured, it returns
// ErrWebSocketHeadersAlreadySet.
func WebSocketConnectionHeaders(headers http.Header) Option {
return func(o *Options) error {
if o.WebSocketConnectionHeadersHandler != nil {
return ErrWebSocketHeadersAlreadySet
}
o.WebSocketConnectionHeaders = headers
return nil
}
}
// WebSocketConnectionHeadersHandler registers a callback used to supply HTTP
// headers for the WebSocket connection handshake.
// This option is mutually exclusive with WebSocketConnectionHeaders; if
// non-empty static headers have already been configured, it returns
// ErrWebSocketHeadersAlreadySet.
func WebSocketConnectionHeadersHandler(cb WebSocketHeadersHandler) Option {
return func(o *Options) error {
if len(o.WebSocketConnectionHeaders) != 0 {
return ErrWebSocketHeadersAlreadySet
}
o.WebSocketConnectionHeadersHandler = cb
return nil
}
}
// Handler processing
// SetDisconnectHandler will set the disconnect event handler.

View File

@@ -21,6 +21,7 @@ import (
"fmt"
"math/rand"
"net"
"net/http"
"runtime"
"strings"
"sync"
@@ -611,3 +612,88 @@ func TestWSNoDeadlockOnAuthFailure(t *testing.T) {
tm.Stop()
}
func TestWsWithCustomHeaders(t *testing.T) {
sopts := testWSGetDefaultOptions(t, false)
s := RunServerWithOptions(sopts)
defer s.Shutdown()
staticHeader := make(http.Header, 0)
staticHeader.Set("Authorization", "Bearer Random Token")
headerProvider := func() (http.Header, error) {
return staticHeader, nil
}
for _, test := range []struct {
name string
connectionOptions []nats.Option
wantErr bool
}{
{
name: "Failure 1: Both headers and handler present",
connectionOptions: []nats.Option{
nats.WebSocketConnectionHeadersHandler(headerProvider),
nats.WebSocketConnectionHeaders(staticHeader),
},
wantErr: true,
},
{
name: "Success 1: Headers present as static headers",
connectionOptions: []nats.Option{
nats.WebSocketConnectionHeaders(staticHeader),
},
wantErr: false,
},
{
name: "Success 2: Header supplied through handler",
connectionOptions: []nats.Option{
nats.WebSocketConnectionHeadersHandler(headerProvider),
},
wantErr: false,
},
} {
t.Run(test.name, func(t *testing.T) {
url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port)
nc, err := nats.Connect(url, test.connectionOptions...)
if err != nil && test.wantErr {
return
}
if err != nil && !test.wantErr {
t.Fatalf("Did not expect error, found error: %v", err)
}
defer nc.Close()
sub, err := nc.SubscribeSync("foo")
if err != nil {
t.Fatalf("Error on subscribe: %v", err)
}
msgs := make([][]byte, 100)
for i := 0; i < len(msgs); i++ {
msg := make([]byte, 100)
for j := 0; j < len(msg); j++ {
msg[j] = 'A'
}
msgs[i] = msg
}
for i, msg := range msgs {
if err := nc.Publish("foo", msg); err != nil {
t.Fatalf("Error on publish: %v", err)
}
// Make sure that compression/masking does not touch user data
if !bytes.Equal(msgs[i], msg) {
t.Fatalf("User content has been changed: %v, got %v", msgs[i], msg)
}
}
for i := 0; i < len(msgs); i++ {
msg, err := sub.NextMsg(time.Second)
if err != nil {
t.Fatalf("Error getting next message (%d): %v", i+1, err)
}
if !bytes.Equal(msgs[i], msg.Data) {
t.Fatalf("Expected message (%d): %v, got %v", i+1, msgs[i], msg)
}
}
})
}
}

22
ws.go
View File

@@ -610,6 +610,9 @@ func (nc *Conn) wsInitHandshake(u *url.URL) error {
if compress {
req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue)
}
if err := nc.wsUpdateConnectionHeaders(req); err != nil {
return err
}
if err := req.Write(nc.conn); err != nil {
return err
}
@@ -728,6 +731,25 @@ func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload
nc.bw.flush()
}
func (nc *Conn) wsUpdateConnectionHeaders(req *http.Request) error {
var headers http.Header
var err error
if nc.Opts.WebSocketConnectionHeadersHandler != nil {
headers, err = nc.Opts.WebSocketConnectionHeadersHandler()
if err != nil {
return err
}
} else {
headers = nc.Opts.WebSocketConnectionHeaders
}
for key, values := range headers {
for _, val := range values {
req.Header.Add(key, val)
}
}
return nil
}
func wsPMCExtensionSupport(header http.Header) (bool, bool) {
for _, extensionList := range header["Sec-Websocket-Extensions"] {
extensions := strings.Split(extensionList, ",")

View File

@@ -14,6 +14,7 @@
package nats
import (
"bufio"
"bytes"
"context"
"fmt"
@@ -608,3 +609,123 @@ func TestWSProxyPath(t *testing.T) {
})
}
}
// --- helpers ---
func startHeaderCatcher(t *testing.T) (addr string, got chan []string, closer func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
got = make(chan []string, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
// surface nothing; test will timeout
return
}
defer conn.Close()
r := bufio.NewReader(conn)
var lines []string
for {
s, err := r.ReadString('\n')
if err != nil {
break
}
s = strings.TrimRight(s, "\r\n")
if s == "" { // end of HTTP headers
break
}
lines = append(lines, s)
}
got <- lines
}()
return ln.Addr().String(), got, func() { _ = ln.Close() }
}
func hasHeaderValue(headers []string, name, want string) bool {
prefix := strings.ToLower(name) + ":"
for _, h := range headers {
if !strings.HasPrefix(strings.ToLower(h), prefix) {
continue
}
val := strings.TrimSpace(strings.SplitN(h, ":", 2)[1])
for _, part := range strings.Split(val, ",") {
if strings.EqualFold(strings.TrimSpace(part), want) {
return true
}
}
}
return false
}
func TestWSHeaders_StaticAppliedOnHandshake(t *testing.T) {
addr, got, closeLn := startHeaderCatcher(t)
defer closeLn()
static := make(http.Header)
static.Set("Authorization", "Bearer Random Token")
static.Add("X-Multi", "v1")
static.Add("X-Multi", "v2")
// Intentionally connect to our fake server; it won't complete the upgrade.
opts := GetDefaultOptions()
opts.WebSocketConnectionHeaders = static
opts.Url = "ws://" + addr
_, err := opts.Connect()
if err == nil {
t.Fatalf("expected connect to fail because server does not reply")
}
var headers []string
select {
case headers = <-got:
case <-time.After(2 * time.Second):
t.Fatal("did not capture headers in time")
}
if !hasHeaderValue(headers, "Authorization", "Bearer Random Token") {
t.Fatalf("Authorization header missing: %v", headers)
}
if !hasHeaderValue(headers, "X-Multi", "v1") || !hasHeaderValue(headers, "X-Multi", "v2") {
t.Fatalf("X-Multi headers missing/combined incorrectly: %v", headers)
}
}
func TestWSHeaders_HandlerAppliedOnHandshake(t *testing.T) {
addr, got, closeLn := startHeaderCatcher(t)
defer closeLn()
provider := func() (http.Header, error) {
h := make(http.Header)
h.Set("Authorization", "Bearer FromHandler")
h.Add("X-Multi", "h1")
h.Add("X-Multi", "h2")
return h, nil
}
opts := GetDefaultOptions()
opts.WebSocketConnectionHeadersHandler = provider
opts.Url = "ws://" + addr
_, err := opts.Connect()
if err == nil {
t.Fatalf("expected connect to fail because server does not reply")
}
var headers []string
select {
case headers = <-got:
case <-time.After(2 * time.Second):
t.Fatal("did not capture headers in time")
}
if !hasHeaderValue(headers, "Authorization", "Bearer FromHandler") {
t.Fatalf("Authorization header missing: %v", headers)
}
if !hasHeaderValue(headers, "X-Multi", "h1") || !hasHeaderValue(headers, "X-Multi", "h2") {
t.Fatalf("X-Multi headers missing/combined incorrectly: %v", headers)
}
}