From 6107dea9a082afa82f228f970cb1bb089da4cb60 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Wed, 17 Sep 2025 21:30:11 +0200 Subject: [PATCH] add RTSP-over-WebSocket (#891) (#898) --- README.md | 4 +- client.go | 15 +- client_test.go | 100 ++++++++++++- ...nt_http_tunnel.go => client_tunnel_http.go | 22 +-- client_tunnel_websocket.go | 81 +++++++++++ conn_transport.go | 1 + go.mod | 1 + go.sum | 2 + internal/teste2e/client_vs_server_test.go | 30 +++- server.go | 2 +- server_conn.go | 13 +- server_conn_reader.go | 106 +++++++++----- server_session.go | 2 +- server_test.go | 50 ++++++- ...er_http_tunnel.go => server_tunnel_http.go | 0 server_tunnel_websocket.go | 133 ++++++++++++++++++ 16 files changed, 493 insertions(+), 69 deletions(-) rename client_http_tunnel.go => client_tunnel_http.go (85%) create mode 100644 client_tunnel_websocket.go rename server_http_tunnel.go => server_tunnel_http.go (100%) create mode 100644 server_tunnel_websocket.go diff --git a/README.md b/README.md index a11fbd85..a265d6f8 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Features: * Client * Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY) - * Support RTSP-over-HTTP, RTSP-over-HTTPS + * Support tunneling (RTSP-over-HTTP, RTSP-over-WebSocket) * Query servers about available media streams * Read media streams from a server ("play") * Read streams with the UDP, UDP-multicast or TCP transport protocol @@ -30,7 +30,7 @@ Features: * Pause without disconnecting from the server * Server * Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY) - * Support RTSP-over-HTTP, RTSP-over-HTTPS + * Support tunneling (RTSP-over-HTTP, RTSP-over-WebSocket) * Handle requests from clients * Validate client credentials * Read media streams from clients ("record") diff --git a/client.go b/client.go index d58607a5..bec686b0 100644 --- a/client.go +++ b/client.go @@ -1137,13 +1137,22 @@ func (c *Client) connOpen() error { var nconn net.Conn - if c.Tunnel == TunnelHTTP { + switch c.Tunnel { + case TunnelHTTP: var err error - nconn, err = newClientHTTPTunnel(dialCtx, c.DialContext, addr, tlsConfig) + nconn, err = newClientTunnelHTTP(dialCtx, c.DialContext, addr, tlsConfig) if err != nil { return err } - } else { + + case TunnelWebSocket: + var err error + nconn, err = newClientTunnelWebSocket(dialCtx, c.DialContext, addr, tlsConfig) + if err != nil { + return err + } + + default: var err error nconn, err = c.DialContext(dialCtx, "tcp", addr) if err != nil { diff --git a/client_test.go b/client_test.go index 308f6a34..56b3f835 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package gortsplib import ( "bufio" "bytes" + "context" "crypto/tls" "net" "net/http" @@ -593,7 +594,7 @@ func TestClientRelativeContentBase(t *testing.T) { require.Equal(t, "rtsp://localhost:8554/relative-content-base", desc.BaseURL.String()) } -func TestClientHTTPTunnel(t *testing.T) { +func TestClientTunnelHTTP(t *testing.T) { for _, ca := range []string{"http", "https"} { t.Run(ca, func(t *testing.T) { var l net.Listener @@ -768,8 +769,103 @@ func TestClientHTTPTunnel(t *testing.T) { require.NoError(t, err) defer c.Close() - _, _, err = c.Describe(u) + _, res, err := c.Describe(u) require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + }) + } +} + +func TestClientTunnelWebSocket(t *testing.T) { + for _, ca := range []string{"ws", "wss"} { + t.Run(ca, func(t *testing.T) { + var scheme string + if ca == "ws" { + scheme = "rtsp" + } else { + scheme = "rtsps" + } + + s := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, r.Header.Get("Sec-WebSocket-Protocol"), "rtsp.onvif.org") + + wconn, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + defer wconn.Close() //nolint:errcheck + + conn := conn.NewConn(bufio.NewReader(&wsReader{wc: wconn}), &wsWriter{wc: wconn}) + + req, err2 := conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + }, ", ")}, + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Describe, req.Method) + require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) + + medias := []*description.Media{testH264Media} + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp; charset=utf-8"}, + "Content-Base": base.HeaderValue{"/relative-content-base"}, + }, + Body: mediasToSDP(medias), + }) + require.NoError(t, err2) + }), + } + + var ln net.Listener + + if ca == "ws" { + var err error + ln, err = net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + } else { + cert, err := tls.X509KeyPair(serverCert, serverKey) + require.NoError(t, err) + + ln, err = tls.Listen("tcp", "localhost:8554", &tls.Config{Certificates: []tls.Certificate{cert}}) + require.NoError(t, err) + defer ln.Close() + } + + go s.Serve(ln) + defer s.Shutdown(context.Background()) + + u, err := base.ParseURL(scheme + "://localhost:8554/teststream") + require.NoError(t, err) + + c := Client{ + Scheme: u.Scheme, + Host: u.Host, + Tunnel: TunnelWebSocket, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + err = c.Start() + require.NoError(t, err) + defer c.Close() + + _, res, err := c.Describe(u) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) }) } } diff --git a/client_http_tunnel.go b/client_tunnel_http.go similarity index 85% rename from client_http_tunnel.go rename to client_tunnel_http.go index e9732ed4..1a1476ec 100644 --- a/client_http_tunnel.go +++ b/client_tunnel_http.go @@ -14,53 +14,53 @@ import ( "github.com/google/uuid" ) -type clientHTTPTunnel struct { +type clientTunnelHTTP struct { readChan net.Conn readBuf *bufio.Reader writeChan net.Conn } -func (c *clientHTTPTunnel) Read(p []byte) (n int, err error) { +func (c *clientTunnelHTTP) Read(p []byte) (n int, err error) { return c.readBuf.Read(p) } -func (c *clientHTTPTunnel) Write(p []byte) (n int, err error) { +func (c *clientTunnelHTTP) Write(p []byte) (n int, err error) { return c.writeChan.Write([]byte(base64.StdEncoding.EncodeToString(p))) } -func (c *clientHTTPTunnel) Close() error { +func (c *clientTunnelHTTP) Close() error { c.readChan.Close() c.writeChan.Close() return nil } -func (c *clientHTTPTunnel) LocalAddr() net.Addr { +func (c *clientTunnelHTTP) LocalAddr() net.Addr { return c.readChan.LocalAddr() } -func (c *clientHTTPTunnel) RemoteAddr() net.Addr { +func (c *clientTunnelHTTP) RemoteAddr() net.Addr { return c.readChan.RemoteAddr() } -func (c *clientHTTPTunnel) SetDeadline(_ time.Time) error { +func (c *clientTunnelHTTP) SetDeadline(_ time.Time) error { panic("unimplemented") } -func (c *clientHTTPTunnel) SetReadDeadline(t time.Time) error { +func (c *clientTunnelHTTP) SetReadDeadline(t time.Time) error { return c.readChan.SetReadDeadline(t) } -func (c *clientHTTPTunnel) SetWriteDeadline(t time.Time) error { +func (c *clientTunnelHTTP) SetWriteDeadline(t time.Time) error { return c.writeChan.SetWriteDeadline(t) } -func newClientHTTPTunnel( +func newClientTunnelHTTP( ctx context.Context, dialContext func(ctx context.Context, network, address string) (net.Conn, error), addr string, tlsConfig *tls.Config, ) (net.Conn, error) { - c := &clientHTTPTunnel{} + c := &clientTunnelHTTP{} var err error c.readChan, err = dialContext(ctx, "tcp", addr) diff --git a/client_tunnel_websocket.go b/client_tunnel_websocket.go new file mode 100644 index 00000000..56267b03 --- /dev/null +++ b/client_tunnel_websocket.go @@ -0,0 +1,81 @@ +package gortsplib + +import ( + "context" + "crypto/tls" + "io" + "net" + "time" + + "github.com/gorilla/websocket" +) + +type clientTunnelWebSocket struct { + wconn *websocket.Conn + r io.Reader + w io.Writer +} + +func (tu *clientTunnelWebSocket) Read(b []byte) (int, error) { + return tu.r.Read(b) +} + +func (tu *clientTunnelWebSocket) Write(b []byte) (int, error) { + return tu.w.Write(b) +} + +func (tu *clientTunnelWebSocket) Close() error { + return tu.wconn.Close() +} + +func (tu *clientTunnelWebSocket) LocalAddr() net.Addr { + return tu.wconn.LocalAddr() +} + +func (tu *clientTunnelWebSocket) RemoteAddr() net.Addr { + return tu.wconn.RemoteAddr() +} + +func (tu *clientTunnelWebSocket) SetDeadline(_ time.Time) error { + return nil +} + +func (tu *clientTunnelWebSocket) SetReadDeadline(t time.Time) error { + return tu.wconn.SetReadDeadline(t) +} + +func (tu *clientTunnelWebSocket) SetWriteDeadline(t time.Time) error { + return tu.wconn.SetWriteDeadline(t) +} + +func newClientTunnelWebSocket( + ctx context.Context, + dialContext func(ctx context.Context, network, address string) (net.Conn, error), + addr string, + tlsConfig *tls.Config, +) (net.Conn, error) { + c := &clientTunnelWebSocket{} + + var ur string + if tlsConfig != nil { + ur = "wss" + } else { + ur = "ws" + } + ur += "://" + addr + "/" + + var err error + c.wconn, _, err = (&websocket.Dialer{ + NetDialContext: dialContext, + TLSClientConfig: tlsConfig, + Subprotocols: []string{"rtsp.onvif.org"}, + }).DialContext(ctx, ur, nil) //nolint:bodyclose + if err != nil { + return nil, err + } + + c.r = &wsReader{wc: c.wconn} + c.w = &wsWriter{wc: c.wconn} + + return c, nil +} diff --git a/conn_transport.go b/conn_transport.go index 1129ab96..e5f37710 100644 --- a/conn_transport.go +++ b/conn_transport.go @@ -7,6 +7,7 @@ type Tunnel int const ( TunnelNone Tunnel = iota TunnelHTTP + TunnelWebSocket ) // ConnTransport contains details about the transport of a connection. diff --git a/go.mod b/go.mod index 35c3b10b..81a42ccf 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 require ( github.com/bluenviron/mediacommon/v2 v2.4.2 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/pion/rtcp v1.2.15 github.com/pion/rtp v1.8.22 github.com/pion/sdp/v3 v3.0.16 diff --git a/go.sum b/go.sum index 165e938d..98b385bb 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= diff --git a/internal/teste2e/client_vs_server_test.go b/internal/teste2e/client_vs_server_test.go index 60fe2179..808b0ce0 100644 --- a/internal/teste2e/client_vs_server_test.go +++ b/internal/teste2e/client_vs_server_test.go @@ -132,6 +132,22 @@ func TestClientVsServer(t *testing.T) { readerProto: "tcp", readerTunnel: "http", }, + { + publisherScheme: "rtsp", + publisherProto: "tcp", + publisherTunnel: "websocket", + readerScheme: "rtsp", + readerProto: "udp", + readerTunnel: "none", + }, + { + publisherScheme: "rtsp", + publisherProto: "tcp", + publisherTunnel: "none", + readerScheme: "rtsp", + readerProto: "tcp", + readerTunnel: "websocket", + }, } { t.Run(strings.Join([]string{ ca.publisherScheme, @@ -166,9 +182,12 @@ func TestClientVsServer(t *testing.T) { } var publisherTunnel gortsplib.Tunnel - if ca.publisherTunnel == "http" { + switch ca.publisherTunnel { + case "http": publisherTunnel = gortsplib.TunnelHTTP - } else { + case "websocket": + publisherTunnel = gortsplib.TunnelWebSocket + default: publisherTunnel = gortsplib.TunnelNone } @@ -192,9 +211,12 @@ func TestClientVsServer(t *testing.T) { time.Sleep(1 * time.Second) var readerTunnel gortsplib.Tunnel - if ca.readerTunnel == "http" { + switch ca.readerTunnel { + case "http": readerTunnel = gortsplib.TunnelHTTP - } else { + case "websocket": + readerTunnel = gortsplib.TunnelWebSocket + default: readerTunnel = gortsplib.TunnelNone } diff --git a/server.go b/server.go index 1d8147a2..83387af4 100644 --- a/server.go +++ b/server.go @@ -420,7 +420,7 @@ func (s *Server) runInner() error { sc := &ServerConn{ s: s, nconn: nconn, - isHTTP: true, + tunnel: TunnelHTTP, } sc.initialize() s.conns[sc] = struct{}{} diff --git a/server_conn.go b/server_conn.go index a8d5e7f5..a42306be 100644 --- a/server_conn.go +++ b/server_conn.go @@ -205,7 +205,7 @@ type readReq struct { type ServerConn struct { s *Server nconn net.Conn - isHTTP bool + tunnel Tunnel ctx context.Context ctxCancel func() @@ -231,7 +231,7 @@ type ServerConn struct { func (sc *ServerConn) initialize() { ctx, ctxCancel := context.WithCancel(sc.s.ctx) - if sc.s.TLSConfig != nil && !sc.isHTTP { + if sc.s.TLSConfig != nil && sc.tunnel == TunnelNone { sc.nconn = tls.Server(sc.nconn, sc.s.TLSConfig) } @@ -278,13 +278,10 @@ func (sc *ServerConn) Session() *ServerSession { // Transport returns transport details. func (sc *ServerConn) Transport() *ConnTransport { + sc.propsMutex.RLock() + defer sc.propsMutex.RUnlock() return &ConnTransport{ - Tunnel: func() Tunnel { - if sc.isHTTP { - return TunnelHTTP - } - return TunnelNone - }(), + Tunnel: sc.tunnel, } } diff --git a/server_conn_reader.go b/server_conn_reader.go index 25b0776c..21d8ed90 100644 --- a/server_conn_reader.go +++ b/server_conn_reader.go @@ -13,8 +13,28 @@ import ( "github.com/bluenviron/gortsplib/v5/pkg/conn" "github.com/bluenviron/gortsplib/v5/pkg/liberrors" "github.com/bluenviron/mediacommon/v2/pkg/rewindablereader" + "github.com/gorilla/websocket" ) +func isHTTPTunnel(req *http.Request) bool { + return ((req.Method == http.MethodGet && req.Header.Get("Accept") == "application/x-rtsp-tunnelled") || + (req.Method == http.MethodPost && req.Header.Get("Content-Type") == "application/x-rtsp-tunnelled")) && + req.Header.Get("X-Sessioncookie") != "" +} + +func isWebSocketTunnel(req *http.Request) bool { + return req.Method == http.MethodGet && + req.Header.Get("Connection") == "Upgrade" && + req.Header.Get("Upgrade") == "websocket" && + req.Header.Get("Sec-WebSocket-Protocol") == "rtsp.onvif.org" +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { + return true + }, +} + func makeReadWriter(r io.Reader, w io.Writer) io.ReadWriter { return struct { io.Reader @@ -65,9 +85,9 @@ func (cr *serverConnReader) run() { func (cr *serverConnReader) runInner() error { var rw io.ReadWriter = cr.sc.bc - if !cr.sc.isHTTP { + if cr.sc.tunnel == TunnelNone { var err error - rw, err = cr.upgradeToHTTP(rw) + rw, err = cr.handleTunneling(rw) if err != nil { return err } @@ -94,7 +114,7 @@ func (cr *serverConnReader) runInner() error { } } -func (cr *serverConnReader) upgradeToHTTP(in io.ReadWriter) (io.ReadWriter, error) { +func (cr *serverConnReader) handleTunneling(in io.ReadWriter) (io.ReadWriter, error) { rr := &rewindablereader.Reader{R: in} buf := make([]byte, 4) @@ -114,10 +134,53 @@ func (cr *serverConnReader) upgradeToHTTP(in io.ReadWriter) (io.ReadWriter, erro return nil, err } - if (req.Method != http.MethodGet && req.Method != http.MethodPost) || - (req.Method == http.MethodGet && req.Header.Get("Accept") != "application/x-rtsp-tunnelled") || - (req.Method == http.MethodPost && req.Header.Get("Content-Type") != "application/x-rtsp-tunnelled") || - req.Header.Get("X-Sessioncookie") == "" { + switch { + case isHTTPTunnel(req): + h := http.Header{} + h.Set("Cache-Control", "no-cache") + h.Set("Connection", "close") + h.Set("Content-Type", "application/x-rtsp-tunnelled") + h.Set("Pragma", "no-cache") + res := http.Response{ + StatusCode: http.StatusOK, + ProtoMajor: 1, + ProtoMinor: req.ProtoMinor, + Header: h, + ContentLength: -1, + } + var buf2 bytes.Buffer + res.Write(&buf2) //nolint:errcheck + cr.sc.nconn.SetWriteDeadline(time.Now().Add(cr.sc.s.WriteTimeout)) + _, err = in.Write(buf2.Bytes()) + if err != nil { + return nil, err + } + + cr.sc.httpReadBuf = buf + + err = cr.sc.s.handleHTTPChannel(sessionHandleHTTPChannelReq{ + sc: cr.sc, + write: (req.Method == http.MethodPost), + tunnelID: req.Header.Get("X-Sessioncookie"), + }) + return nil, err + + case isWebSocketTunnel(req): + resw := &wsResponseWriter{r: cr.sc.nconn, buf: buf, w: in, req: req} + resw.initialize() + var wconn *websocket.Conn + wconn, err = upgrader.Upgrade(resw, req, nil) + if err != nil { + return nil, err + } + + cr.sc.propsMutex.Lock() + cr.sc.tunnel = TunnelWebSocket + cr.sc.propsMutex.Unlock() + + return makeReadWriter(&wsReader{wc: wconn}, &wsWriter{wc: wconn}), nil + + default: res := http.Response{ StatusCode: http.StatusBadRequest, ProtoMajor: req.ProtoMajor, @@ -134,35 +197,6 @@ func (cr *serverConnReader) upgradeToHTTP(in io.ReadWriter) (io.ReadWriter, erro return nil, fmt.Errorf("invalid HTTP request") } - - h := http.Header{} - h.Set("Cache-Control", "no-cache") - h.Set("Connection", "close") - h.Set("Content-Type", "application/x-rtsp-tunnelled") - h.Set("Pragma", "no-cache") - res := http.Response{ - StatusCode: http.StatusOK, - ProtoMajor: 1, - ProtoMinor: req.ProtoMinor, - Header: h, - ContentLength: -1, - } - var buf2 bytes.Buffer - res.Write(&buf2) //nolint:errcheck - cr.sc.nconn.SetWriteDeadline(time.Now().Add(cr.sc.s.WriteTimeout)) - _, err = in.Write(buf2.Bytes()) - if err != nil { - return nil, err - } - - cr.sc.httpReadBuf = buf - - err = cr.sc.s.handleHTTPChannel(sessionHandleHTTPChannelReq{ - sc: cr.sc, - write: (req.Method == http.MethodPost), - tunnelID: req.Header.Get("X-Sessioncookie"), - }) - return nil, err } return makeReadWriter(rr, in), nil diff --git a/server_session.go b/server_session.go index adaf504a..fc927de4 100644 --- a/server_session.go +++ b/server_session.go @@ -196,7 +196,7 @@ func isTransportSupported(sc *ServerConn, tr *headers.Transport) bool { } // prevent using UDP with tunneling - if sc.isHTTP { + if sc.tunnel != TunnelNone { return false } diff --git a/server_test.go b/server_test.go index 6172c180..57dad2a3 100644 --- a/server_test.go +++ b/server_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/require" "github.com/bluenviron/gortsplib/v5/pkg/auth" @@ -1295,7 +1296,7 @@ func TestServerStreamErrorNoServer(t *testing.T) { require.Error(t, err) } -func TestServerHTTPTunnel(t *testing.T) { +func TestServerTunnelHTTP(t *testing.T) { for _, ca := range []string{"http", "https"} { t.Run(ca, func(t *testing.T) { done := make(chan struct{}) @@ -1426,3 +1427,50 @@ func TestServerHTTPTunnel(t *testing.T) { }) } } + +func TestServerTunnelWebSocket(t *testing.T) { + for _, ca := range []string{"ws", "wss"} { + t.Run(ca, func(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{ + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + if ca == "wss" { + cert, err := tls.X509KeyPair(serverCert, serverKey) + require.NoError(t, err) + s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + h := http.Header{} + h.Set("Sec-WebSocket-Protocol", "rtsp.onvif.org") + c, _, err := (&websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }).Dial(ca+"://localhost:8554", h) //nolint:bodyclose + require.NoError(t, err) + defer c.Close() //nolint:errcheck + + conn := conn.NewConn(bufio.NewReader(&wsReader{wc: c}), &wsWriter{wc: c}) + + rres, err := writeReqReadRes(conn, base.Request{ + Method: base.Describe, + URL: mustParseURL("rtsp://localhost:8554/teststream?param=value"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusNotFound, rres.StatusCode) + }) + } +} diff --git a/server_http_tunnel.go b/server_tunnel_http.go similarity index 100% rename from server_http_tunnel.go rename to server_tunnel_http.go diff --git a/server_tunnel_websocket.go b/server_tunnel_websocket.go new file mode 100644 index 00000000..62b8790c --- /dev/null +++ b/server_tunnel_websocket.go @@ -0,0 +1,133 @@ +package gortsplib + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type wsNetConn struct { + r io.Reader + buf *bufio.Reader + w io.Writer +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *wsNetConn) Close() error { + panic("unimplemented") +} + +func (c *wsNetConn) LocalAddr() net.Addr { + panic("unimplemented") +} + +func (c *wsNetConn) RemoteAddr() net.Addr { + panic("unimplemented") +} + +func (c *wsNetConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *wsNetConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *wsNetConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +type wsResponseWriter struct { + r io.Reader + buf *bufio.Reader + w io.Writer + req *http.Request + + h http.Header +} + +func (w *wsResponseWriter) initialize() { + w.h = make(http.Header) +} + +func (w *wsResponseWriter) Header() http.Header { + return w.h +} + +func (w *wsResponseWriter) Write(p []byte) (int, error) { + return w.w.Write(p) +} + +func (w *wsResponseWriter) WriteHeader(statusCode int) { + res := http.Response{ + StatusCode: statusCode, + ProtoMajor: w.req.ProtoMajor, + ProtoMinor: w.req.ProtoMinor, + Header: w.h, + Request: w.req, + } + var buf2 bytes.Buffer + res.Write(&buf2) //nolint:errcheck + w.w.Write(buf2.Bytes()) +} + +func (w *wsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return &wsNetConn{r: w.r, buf: w.buf, w: w.w}, bufio.NewReadWriter(w.buf, bufio.NewWriter(w.w)), nil +} + +type wsReader struct { + wc *websocket.Conn + + buf []byte +} + +func (r *wsReader) Read(p []byte) (int, error) { + if len(r.buf) == 0 { + var msgType int + var err error + msgType, r.buf, err = r.wc.ReadMessage() + if err != nil { + return 0, err + } + + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("unxpected message type %v", msgType) + } + } + + n := copy(p, r.buf) + r.buf = r.buf[n:] + + return n, nil +} + +type wsWriter struct { + wc *websocket.Conn + + mutex sync.Mutex +} + +func (w *wsWriter) Write(p []byte) (int, error) { + w.mutex.Lock() + defer w.mutex.Unlock() + + err := w.wc.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return 0, err + } + return len(p), nil +}