diff --git a/cmd/api/api.go b/cmd/api/api.go index 3158c128..55b935d1 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -5,7 +5,6 @@ import ( "github.com/AlexxIT/go2rtc/cmd/app" "github.com/AlexxIT/go2rtc/cmd/streams" "github.com/AlexxIT/go2rtc/pkg/streamer" - "github.com/gorilla/websocket" "github.com/rs/zerolog" "net" "net/http" @@ -133,7 +132,8 @@ func streamsHandler(w http.ResponseWriter, r *http.Request) { func apiWS(w http.ResponseWriter, r *http.Request) { ctx := new(Context) if err := ctx.Upgrade(w, r); err != nil { - log.Error().Err(err).Msg("[api.ws] upgrade") + origin := r.Header.Get("Origin") + log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin) return } defer ctx.Close() diff --git a/cmd/api/ws.go b/cmd/api/ws.go index 09d90190..50a8ced9 100644 --- a/cmd/api/ws.go +++ b/cmd/api/ws.go @@ -4,6 +4,8 @@ import ( "github.com/AlexxIT/go2rtc/pkg/streamer" "github.com/gorilla/websocket" "net/http" + "net/url" + "strings" "sync" ) @@ -13,7 +15,30 @@ func initWS(origin string) { WriteBufferSize: 512000, } - if origin == "*" { + switch origin { + case "": + // same origin + ignore port + wsUp.CheckOrigin = func(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + o, err := url.Parse(origin[0]) + if err != nil { + return false + } + if o.Host == r.Host { + return true + } + log.Trace().Msgf("[api.ws] origin=%s, host=%s", o.Host, r.Host) + // https://github.com/AlexxIT/go2rtc/issues/118 + if i := strings.IndexByte(o.Host, ':'); i > 0 { + return o.Host[:i] == r.Host + } + return false + } + case "*": + // any origin wsUp.CheckOrigin = func(r *http.Request) bool { return true }