diff --git a/http/middleware/session/session.go b/http/middleware/session/session.go index c09bc16e..b17dc296 100644 --- a/http/middleware/session/session.go +++ b/http/middleware/session/session.go @@ -2,6 +2,7 @@ package session import ( "bytes" + "fmt" "io" "net/http" "net/url" @@ -87,45 +88,11 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { req := c.Request() path := req.URL.Path + referrer := req.Header.Get("Referer") - data := map[string]interface{}{} - - e := util.DefaultContext[interface{}](c, "session", nil) - if e != nil { - var ok bool - data, ok = e.(map[string]interface{}) - if !ok { - return api.Err(http.StatusForbidden, "", "invalid session data") - } - - if match, ok := data["match"].(string); ok { - if ok, err := glob.Match(match, path, '/'); !ok { - if err != nil { - return api.Err(http.StatusForbidden, "", "no match for '%s' in %s: %s", match, path, err.Error()) - } - - return api.Err(http.StatusForbidden, "", "no match for '%s' in %s", match, path) - } - } - - referrer := req.Header.Get("Referer") - if u, err := url.Parse(referrer); err == nil { - referrer = u.Host - } - - if remote, ok := data["remote"].([]string); ok && len(remote) != 0 { - match := false - for _, r := range remote { - if ok, _ := glob.Match(r, referrer, '.'); ok { - match = true - break - } - } - - if !match { - return api.Err(http.StatusForbidden, "", "remote not allowed") - } - } + data, err := verifySession(util.DefaultContext[interface{}](c, "session", nil), path, referrer) + if err != nil { + return api.Err(http.StatusForbidden, "", "invalid session data") } data["name"] = ctxuser @@ -148,6 +115,64 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { } } +func verifySession(raw interface{}, path, referrer string) (map[string]interface{}, error) { + data := map[string]interface{}{} + + if raw == nil { + return data, nil + } + + var ok bool + data, ok = raw.(map[string]interface{}) + if !ok { + return data, fmt.Errorf("invalid session data") + } + + if match, ok := data["match"].(string); ok { + if ok, err := glob.Match(match, path, '/'); !ok { + if err != nil { + return data, fmt.Errorf("no match for '%s' in %s: %s", match, path, err.Error()) + } + + return data, fmt.Errorf("no match for '%s' in %s", match, path) + } + } + + if u, err := url.Parse(referrer); err == nil { + referrer = u.Host + } + + if remote, ok := data["remote"].([]interface{}); ok && len(remote) != 0 { + if len(referrer) == 0 { + return data, fmt.Errorf("remote not allowed") + } + + remotes := []string{} + for _, r := range remote { + v, ok := r.(string) + if !ok { + continue + } + + remotes = append(remotes, v) + } + + match := false + for _, r := range remotes { + if ok, _ := glob.Match(r, referrer, '.'); ok { + match = true + break + } + } + + if !match { + return data, fmt.Errorf("remote not allowed") + } + } + + return data, nil +} + func headerSize(header http.Header) int64 { var buffer bytes.Buffer diff --git a/http/middleware/session/session_test.go b/http/middleware/session/session_test.go new file mode 100644 index 00000000..0bf2c82e --- /dev/null +++ b/http/middleware/session/session_test.go @@ -0,0 +1,135 @@ +package session + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVerifySession(t *testing.T) { + jsondata := []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**", + "remote": [ + "foo.example.com" + ] + }`) + + var rawdata interface{} + + err := json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://foo.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.example.com") + require.Error(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-0000-963869718d8d/main.m3u8", "http://foo.example.com") + require.Error(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "") + require.Error(t, err) +} + +func TestVerifySessionNoRemote(t *testing.T) { + jsondata := []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**", + "remote": [] + }`) + + var rawdata interface{} + + err := json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://cm.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "") + require.NoError(t, err) + + jsondata = []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**" + }`) + + err = json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://cm.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "") + require.NoError(t, err) +} + +func TestVerifySessionWildcardRemote(t *testing.T) { + jsondata := []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**", + "remote": [ + "*.example.com" + ] + }`) + + var rawdata interface{} + + err := json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://foo.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://sub.bar.example.com") + require.Error(t, err) +} + +func TestVerifySessionSuperWildcardRemote(t *testing.T) { + jsondata := []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**", + "remote": [ + "**.example.com" + ] + }`) + + var rawdata interface{} + + err := json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://foo.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://sub.bar.example.com") + require.NoError(t, err) +} + +func TestVerifySessionMultipleRemote(t *testing.T) { + jsondata := []byte(`{ + "match": "/memfs/6faad99a-c440-4df1-9344-963869718d8d/**", + "remote": [ + "foo.example.com", + "bar.otherdomain.com" + ] + }`) + + var rawdata interface{} + + err := json.Unmarshal(jsondata, &rawdata) + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://foo.example.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.otherdomain.com") + require.NoError(t, err) + + _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.example.com") + require.Error(t, err) +}