Add capability to perform websocket fallback

This commit is contained in:
Jacob
2024-02-16 07:01:04 -05:00
parent 60124d8450
commit ea261ca186
3 changed files with 50 additions and 13 deletions

16
dial.go
View File

@@ -1,17 +1,22 @@
package rtcnet
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/pion/webrtc/v3"
)
func Dial(address string, tlsConfig *tls.Config, ordered bool, iceServers []string) (*Conn, error) {
wSock, err := dialWebsocket(address, tlsConfig)
dialCtx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) // TODO: pass in timeout
defer cancel()
wSock, err := dialWebsocket(address, tlsConfig, dialCtx)
if err != nil {
return nil, err
}
@@ -85,7 +90,9 @@ func Dial(address string, tlsConfig *tls.Config, ordered bool, iceServers []stri
logger.Error().
Err(err).
Msg("Failed to read from websocket")
conn.pushErrorData(err)
// TODO: We don't want this to cause an error, if it closed for normal reasons. Else we do want it to cause an error
// conn.pushErrorData(err)
return
}
@@ -254,6 +261,11 @@ func Dial(address string, tlsConfig *tls.Config, ordered bool, iceServers []stri
// Wait until the webrtc connection is finished getting setup
select {
case <-dialCtx.Done():
logger.Error().
Err(err).
Msg("Dial: context Done")
return nil, dialCtx.Err()
case err := <-conn.errorChan:
logger.Error().
Err(err).

View File

@@ -15,11 +15,12 @@ type ListenConfig struct {
TlsConfig *tls.Config
OriginPatterns []string
IceServers []string
// AllowWebsocketFallback bool // TODO: Restriction?
}
type Listener struct {
wsListener *websocketListener
pendingAccepts chan *Conn // TODO - should this get buffered?
pendingAccepts chan net.Conn // TODO - should this get buffered?
pendingAcceptErrors chan error // TODO - should this get buffered?
closed atomic.Bool
iceServers []string
@@ -33,7 +34,7 @@ func NewListener(address string, config ListenConfig) (*Listener, error) {
rtcListener := &Listener{
wsListener: wsl,
pendingAccepts: make(chan *Conn),
pendingAccepts: make(chan net.Conn),
pendingAcceptErrors: make(chan error),
iceServers: config.IceServers,
}
@@ -50,8 +51,13 @@ func NewListener(address string, config ListenConfig) (*Listener, error) {
return // If closed then just exit
}
// Try and negotiate a webrtc connection for the websocket connection
go rtcListener.attemptWebRtcNegotiation(wsConn)
fallback, isFallback := wsConn.(wsFallback)
if isFallback {
rtcListener.pendingAccepts <- fallback.Conn
} else {
// Try and negotiate a webrtc connection for the websocket connection
go rtcListener.attemptWebRtcNegotiation(wsConn)
}
}
}()

33
ws.go
View File

@@ -13,9 +13,8 @@ import (
)
// Returns a connected socket or fails with an error
func dialWebsocket(address string, tlsConfig *tls.Config) (net.Conn, error) {
// TODO: make timeout configurable
ctx, _ := context.WithTimeout(context.Background(), 10 * time.Second)
func dialWebsocket(address string, tlsConfig *tls.Config, ctx context.Context) (net.Conn, error) {
// ctx, _ := context.WithTimeout(context.Background(), 10 * time.Second)
url := "wss://" + address
wsConn, err := dialWs(ctx, url, tlsConfig)
@@ -23,6 +22,7 @@ func dialWebsocket(address string, tlsConfig *tls.Config) (net.Conn, error) {
return nil, err
}
// Note: The entire websocket net.Conn lifetime is managed by the context too
// ctx, cancel := context.WithCancel(context.Background())
conn := websocket.NetConn(ctx, wsConn, websocket.MessageBinary)
@@ -85,6 +85,10 @@ func newWebsocketListener(address string, config ListenConfig) (*websocketListen
return wsl, nil
}
type wsFallback struct {
net.Conn
}
func (l *websocketListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wsConn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: l.originPatterns,
@@ -95,11 +99,26 @@ func (l *websocketListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Build the net.Conn and push to the channel
ctx := context.Background() // TODO: configurable context?
conn := websocket.NetConn(ctx, wsConn, websocket.MessageBinary)
fallback := false
if r.URL != nil {
if r.URL.Path == "/wss" {
logger.Warn().Msg("Dialer requested wss fallback socket!")
fallback = true
}
}
l.pendingAccepts <- conn
// Build the net.Conn and push to the channel
if fallback {
ctx := context.Background() // Note: This has to be background if it is a fallback path
conn := websocket.NetConn(ctx, wsConn, websocket.MessageBinary)
conn = wsFallback{conn}
l.pendingAccepts <- conn
} else {
// TODO: make timeout configurable?
ctx, _ := context.WithTimeout(context.Background(), 30 * time.Second)
conn := websocket.NetConn(ctx, wsConn, websocket.MessageBinary)
l.pendingAccepts <- conn
}
}
func (l *websocketListener) Accept() (net.Conn, error) {