diff --git a/app/api/api.go b/app/api/api.go index 457a5e9a..f4de261f 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -34,6 +34,7 @@ import ( "github.com/datarhei/core/v16/restream" restreamapp "github.com/datarhei/core/v16/restream/app" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" restreamstore "github.com/datarhei/core/v16/restream/store" "github.com/datarhei/core/v16/rtmp" "github.com/datarhei/core/v16/service" @@ -440,6 +441,9 @@ func (a *api) start() error { return fmt.Errorf("iam: %w", err) } + // Create default policies for anonymous users in order to mimic + // the behaviour before IAM + iam.RemovePolicy("$anon", "$none", "", "") iam.RemovePolicy("$localhost", "$none", "", "") @@ -465,6 +469,14 @@ func (a *api) start() error { iam.AddPolicy("$localhost", "$none", "fs:/memfs/**", "GET|HEAD|OPTIONS|POST|PUT|DELETE") } + if cfg.RTMP.Enable && len(cfg.RTMP.Token) == 0 { + iam.AddPolicy("$anon", "$none", "rtmp:/**", "PUBLISH|PLAY") + } + + if cfg.SRT.Enable && len(cfg.SRT.Token) == 0 { + iam.AddPolicy("$anon", "$none", "srt:**", "PUBLISH|PLAY") + } + a.iam = iam } @@ -592,6 +604,35 @@ func (a *api) start() error { a.ffmpeg = ffmpeg + var rw rewrite.Rewriter + + { + baseAddress := func(address string) string { + var base string + host, port, _ := gonet.SplitHostPort(address) + if len(host) == 0 { + base = "localhost:" + port + } else { + base = address + } + + return base + } + + httpBase := baseAddress(cfg.Address) + rtmpBase := baseAddress(cfg.RTMP.Address) + cfg.RTMP.App + srtBase := baseAddress(cfg.SRT.Address) + + rw, err = rewrite.New(rewrite.Config{ + HTTPBase: "http://" + httpBase, + RTMPBase: "rtmp://" + rtmpBase, + SRTBase: "srt://" + srtBase, + }) + if err != nil { + return fmt.Errorf("unable to create url rewriter: %w", err) + } + } + a.replacer = replace.New() { @@ -627,8 +668,8 @@ func (a *api) start() error { } template += "/{name}" - if len(cfg.RTMP.Token) != 0 { - template += "?token=" + cfg.RTMP.Token + if identity, _ := a.iam.GetIdentity(config.Owner); identity != nil { + template += "/" + identity.GetServiceToken() } return template @@ -643,14 +684,14 @@ func (a *api) start() error { template := "srt://" + host + ":" + port + "?mode=caller&transtype=live&latency={latency}&streamid={name}" if section == "output" { template += ",mode:publish" - } else { - template += ",mode:request" } - if len(cfg.SRT.Token) != 0 { - template += ",token:" + cfg.SRT.Token + + if identity, _ := a.iam.GetIdentity(config.Owner); identity != nil { + template += ",token:" + identity.GetServiceToken() } + if len(cfg.SRT.Passphrase) != 0 { - template += "&passphrase=" + cfg.SRT.Passphrase + template += "&passphrase=" + url.QueryEscape(cfg.SRT.Passphrase) } return template @@ -693,8 +734,10 @@ func (a *api) start() error { Store: store, Filesystems: filesystems, Replace: a.replacer, + Rewrite: rw, FFmpeg: a.ffmpeg, MaxProcesses: cfg.FFmpeg.MaxProcesses, + IAM: a.iam, Logger: a.log.logger.core.WithComponent("Process"), }) @@ -703,49 +746,7 @@ func (a *api) start() error { } a.restream = restream - /* - var httpjwt jwt.JWT - if cfg.API.Auth.Enable { - secret := rand.String(32) - if len(cfg.API.Auth.JWT.Secret) != 0 { - secret = cfg.API.Auth.Username + cfg.API.Auth.Password + cfg.API.Auth.JWT.Secret - } - - var err error - httpjwt, err = jwt.New(jwt.Config{ - Realm: app.Name, - Secret: secret, - SkipLocalhost: cfg.API.Auth.DisableLocalhost, - }) - - if err != nil { - return fmt.Errorf("unable to create JWT provider: %w", err) - } - - if validator, err := jwt.NewLocalValidator(a.iam); err == nil { - if err := httpjwt.AddValidator(app.Name, validator); err != nil { - return fmt.Errorf("unable to add local JWT validator: %w", err) - } - } else { - return fmt.Errorf("unable to create local JWT validator: %w", err) - } - - if cfg.API.Auth.Auth0.Enable { - for _, t := range cfg.API.Auth.Auth0.Tenants { - if validator, err := jwt.NewAuth0Validator(a.iam); err == nil { - if err := httpjwt.AddValidator("https://"+t.Domain+"/", validator); err != nil { - return fmt.Errorf("unable to add Auth0 JWT validator: %w", err) - } - } else { - return fmt.Errorf("unable to create Auth0 JWT validator: %w", err) - } - } - } - } - - a.httpjwt = httpjwt - */ metrics, err := monitor.NewHistory(monitor.HistoryConfig{ Enable: cfg.Metrics.Enable, Timerange: time.Duration(cfg.Metrics.Range) * time.Second, diff --git a/app/casbin/adapter.go b/app/casbin/adapter.go deleted file mode 100644 index c00daafb..00000000 --- a/app/casbin/adapter.go +++ /dev/null @@ -1,502 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - "sync" - - "github.com/casbin/casbin/v2/model" - "github.com/casbin/casbin/v2/persist" -) - -// Adapter is the file adapter for Casbin. -// It can load policy from file or save policy to file. -type adapter struct { - filePath string - groups []Group - lock sync.Mutex -} - -func NewAdapter(filePath string) persist.Adapter { - return &adapter{filePath: filePath} -} - -// Adapter -func (a *adapter) LoadPolicy(model model.Model) error { - a.lock.Lock() - defer a.lock.Unlock() - - if a.filePath == "" { - return fmt.Errorf("invalid file path, file path cannot be empty") - } - - /* - logger := &log.DefaultLogger{} - logger.EnableLog(true) - - model.SetLogger(logger) - */ - - return a.loadPolicyFile(model) -} - -func (a *adapter) loadPolicyFile(model model.Model) error { - if _, err := os.Stat(a.filePath); os.IsNotExist(err) { - a.groups = []Group{} - return nil - } - - data, err := os.ReadFile(a.filePath) - if err != nil { - return err - } - - groups := []Group{} - - err = json.Unmarshal(data, &groups) - if err != nil { - return err - } - - rule := [5]string{} - for _, group := range groups { - rule[0] = "p" - rule[2] = group.Name - for name, roles := range group.Roles { - rule[1] = "role:" + name - for _, role := range roles { - rule[3] = role.Resource - rule[4] = role.Actions - - if err := a.importPolicy(model, rule[0:5]); err != nil { - return err - } - } - } - - for _, policy := range group.Policies { - rule[1] = policy.Username - rule[3] = policy.Resource - rule[4] = policy.Actions - - if err := a.importPolicy(model, rule[0:5]); err != nil { - return err - } - } - - rule[0] = "g" - rule[3] = group.Name - - for _, ug := range group.UserRoles { - rule[1] = ug.Username - rule[2] = "role:" + ug.Role - - if err := a.importPolicy(model, rule[0:4]); err != nil { - return err - } - } - } - - a.groups = groups - - return nil -} - -func (a *adapter) importPolicy(model model.Model, rule []string) error { - copiedRule := make([]string, len(rule)) - copy(copiedRule, rule) - - ok, err := model.HasPolicyEx(copiedRule[0], copiedRule[0], copiedRule[1:]) - if err != nil { - return err - } - if ok { - return nil // skip duplicated policy - } - - model.AddPolicy(copiedRule[0], copiedRule[0], copiedRule[1:]) - - return nil -} - -// Adapter -func (a *adapter) SavePolicy(model model.Model) error { - a.lock.Lock() - defer a.lock.Unlock() - - return a.savePolicyFile() -} - -func (a *adapter) savePolicyFile() error { - if a.filePath == "" { - return fmt.Errorf("invalid file path, file path cannot be empty") - } - - jsondata, err := json.MarshalIndent(a.groups, "", " ") - if err != nil { - return err - } - - dir, filename := filepath.Split(a.filePath) - - tmpfile, err := os.CreateTemp(dir, filename) - if err != nil { - return err - } - - defer os.Remove(tmpfile.Name()) - - if _, err := tmpfile.Write(jsondata); err != nil { - return err - } - - if err := tmpfile.Close(); err != nil { - return err - } - - if err := os.Rename(tmpfile.Name(), a.filePath); err != nil { - return err - } - - return nil -} - -// Adapter (auto-save) -func (a *adapter) AddPolicy(sec, ptype string, rule []string) error { - a.lock.Lock() - defer a.lock.Unlock() - - err := a.addPolicy(ptype, rule) - if err != nil { - return err - } - - return a.savePolicyFile() -} - -// BatchAdapter (auto-save) -func (a *adapter) AddPolicies(sec string, ptype string, rules [][]string) error { - a.lock.Lock() - defer a.lock.Unlock() - - for _, rule := range rules { - err := a.addPolicy(ptype, rule) - if err != nil { - return err - } - } - - return a.savePolicyFile() -} - -func (a *adapter) addPolicy(ptype string, rule []string) error { - ok, err := a.hasPolicy(ptype, rule) - if err != nil { - return err - } - - if ok { - // the policy is already there, nothing to add - return nil - } - - username := "" - role := "" - domain := "" - resource := "" - actions := "" - - if ptype == "p" { - username = rule[0] - domain = rule[1] - resource = rule[2] - actions = rule[3] - } else if ptype == "g" { - username = rule[0] - role = rule[1] - domain = rule[2] - } else { - return fmt.Errorf("unknown ptype: %s", ptype) - } - - var group *Group = nil - for i := range a.groups { - if a.groups[i].Name == domain { - group = &a.groups[i] - } - } - - if group == nil { - g := Group{ - Name: domain, - } - - a.groups = append(a.groups, g) - group = &g - } - - if ptype == "p" { - if strings.HasPrefix(username, "role:") { - if group.Roles == nil { - group.Roles = make(map[string][]Role) - } - - role := strings.TrimPrefix(username, "role:") - group.Roles[role] = append(group.Roles[role], Role{ - Resource: resource, - Actions: actions, - }) - } else { - group.Policies = append(group.Policies, Policy{ - Username: rule[0], - Role: Role{ - Resource: resource, - Actions: actions, - }, - }) - } - } else { - group.UserRoles = append(group.UserRoles, MapUserRole{ - Username: username, - Role: strings.TrimPrefix(role, "role:"), - }) - } - - return nil -} - -func (a *adapter) hasPolicy(ptype string, rule []string) (bool, error) { - var username string - var role string - var domain string - var resource string - var actions string - - if ptype == "p" { - if len(rule) != 4 { - return false, fmt.Errorf("invalid rule length. must be 'user/role, domain, resource, actions'") - } - - username = rule[0] - domain = rule[1] - resource = rule[2] - actions = rule[3] - } else if ptype == "g" { - username = rule[0] - role = rule[1] - domain = rule[2] - } else { - return false, fmt.Errorf("unknown ptype: %s", ptype) - } - - var group *Group = nil - for _, g := range a.groups { - if g.Name == domain { - group = &g - break - } - } - - if group == nil { - // if we can't find any group with that name, then the policy doesn't exist - return false, nil - } - - if ptype == "p" { - isRole := false - if strings.HasPrefix(username, "role:") { - isRole = true - username = strings.TrimPrefix(username, "role:") - } - - if isRole { - roles, ok := group.Roles[username] - if !ok { - // unknown role, policy doesn't exist - return false, nil - } - - for _, role := range roles { - if role.Resource == resource && role.Actions == actions { - return true, nil - } - } - } else { - for _, p := range group.Policies { - if p.Username == username && p.Resource == resource && p.Actions == actions { - return true, nil - } - } - } - } else { - role = strings.TrimPrefix(role, "role:") - for _, user := range group.UserRoles { - if user.Username == username && user.Role == role { - return true, nil - } - } - } - - return false, nil -} - -// Adapter (auto-save) -func (a *adapter) RemovePolicy(sec string, ptype string, rule []string) error { - a.lock.Lock() - defer a.lock.Unlock() - - err := a.removePolicy(ptype, rule) - if err != nil { - return err - } - - return a.savePolicyFile() -} - -// BatchAdapter (auto-save) -func (a *adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { - a.lock.Lock() - defer a.lock.Unlock() - - for _, rule := range rules { - err := a.removePolicy(ptype, rule) - if err != nil { - return err - } - } - - return a.savePolicyFile() -} - -func (a *adapter) removePolicy(ptype string, rule []string) error { - ok, err := a.hasPolicy(ptype, rule) - if err != nil { - return err - } - - if !ok { - // the policy is not there, nothing to remove - return nil - } - - username := "" - role := "" - domain := "" - resource := "" - actions := "" - - if ptype == "p" { - username = rule[0] - domain = rule[1] - resource = rule[2] - actions = rule[3] - } else if ptype == "g" { - username = rule[0] - role = rule[1] - domain = rule[2] - } else { - return fmt.Errorf("unknown ptype: %s", ptype) - } - - var group *Group = nil - for i := range a.groups { - if a.groups[i].Name == domain { - group = &a.groups[i] - } - } - - if ptype == "p" { - isRole := false - if strings.HasPrefix(username, "role:") { - isRole = true - username = strings.TrimPrefix(username, "role:") - } - - if isRole { - roles := group.Roles[username] - - newRoles := []Role{} - - for _, role := range roles { - if role.Resource == resource && role.Actions == actions { - continue - } - - newRoles = append(newRoles, role) - } - - group.Roles[username] = newRoles - } else { - policies := []Policy{} - - for _, p := range group.Policies { - if p.Username == username && p.Resource == resource && p.Actions == actions { - continue - } - - policies = append(policies, p) - } - - group.Policies = policies - } - } else { - role = strings.TrimPrefix(role, "role:") - - users := []MapUserRole{} - - for _, user := range group.UserRoles { - if user.Username == username && user.Role == role { - continue - } - - users = append(users, user) - } - - group.UserRoles = users - } - - return nil -} - -// Adapter -func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { - return fmt.Errorf("not implemented") -} - -func (a *adapter) GetAllGroupNames() []string { - a.lock.Lock() - defer a.lock.Unlock() - - groups := []string{} - - for _, group := range a.groups { - groups = append(groups, group.Name) - } - - return groups -} - -type Group struct { - Name string `json:"name"` - Roles map[string][]Role `json:"roles"` - UserRoles []MapUserRole `json:"userroles"` - Policies []Policy `json:"policies"` -} - -type Role struct { - Resource string `json:"resource"` - Actions string `json:"actions"` -} - -type MapUserRole struct { - Username string `json:"username"` - Role string `json:"role"` -} - -type Policy struct { - Username string `json:"username"` - Role -} diff --git a/app/casbin/casbin b/app/casbin/casbin deleted file mode 100755 index 0008b014..00000000 Binary files a/app/casbin/casbin and /dev/null differ diff --git a/app/casbin/main.go b/app/casbin/main.go deleted file mode 100644 index 8ffc9e88..00000000 --- a/app/casbin/main.go +++ /dev/null @@ -1,215 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - "strings" - - "github.com/casbin/casbin/v2" - "github.com/casbin/casbin/v2/model" - "github.com/gobwas/glob" -) - -func main() { - var subject string - var domain string - var object string - var action string - - flag.StringVar(&subject, "subject", "$anon", "subject of this request") - flag.StringVar(&domain, "domain", "$none", "domain of this request") - flag.StringVar(&object, "object", "", "object of this request") - flag.StringVar(&action, "action", "", "action of this request") - - flag.Parse() - - m := model.NewModel() - m.AddDef("r", "r", "sub, dom, obj, act") - m.AddDef("p", "p", "sub, dom, obj, act") - m.AddDef("g", "g", "_, _, _") - m.AddDef("e", "e", "some(where (p.eft == allow))") - m.AddDef("m", "m", `g(r.sub, p.sub, r.dom) && r.dom == p.dom && ResourceMatch(r.obj, r.dom, p.obj) && ActionMatch(r.act, p.act) || r.sub == "$superuser"`) - - a := NewAdapter("./policy.json") - - e, err := casbin.NewEnforcer(m, a) - if err != nil { - fmt.Printf("error: %s\n", err) - os.Exit(1) - } - - e.AddFunction("ResourceMatch", ResourceMatchFunc) - e.AddFunction("ActionMatch", ActionMatchFunc) - - /* - if err := addGroup(e, "foobar"); err != nil { - fmt.Printf("error: %s\n", err) - os.Exit(1) - } - - if err := addGroupUser(e, "foobar", "franz", "admin"); err != nil { - fmt.Printf("error: %s\n", err) - os.Exit(1) - } - - if err := addGroupUser(e, "foobar", "$anon", "anonymous"); err != nil { - fmt.Printf("error: %s\n", err) - os.Exit(1) - } - - e.RemovePolicy("bob", "igelcamp", "processid:*", "COMMAND") - e.AddPolicy("bob", "igelcamp", "processid:bob-*", "COMMAND") - */ - ok, reason, err := e.EnforceEx(subject, domain, object, action) - if err != nil { - fmt.Printf("error: %s\n", err) - os.Exit(1) - } - - if ok { - fmt.Printf("OK: %v\n", reason) - } else { - fmt.Printf("not OK\n") - } -} - -func ResourceMatch(request, domain, policy string) bool { - reqPrefix, reqResource := getPrefix(request) - polPrefix, polResource := getPrefix(policy) - - if reqPrefix != polPrefix { - return false - } - - fmt.Printf("prefix: %s\n", reqPrefix) - fmt.Printf("requested resource: %s\n", reqResource) - fmt.Printf("requested domain: %s\n", domain) - fmt.Printf("policy resource: %s\n", polResource) - - var match bool - var err error - - if reqPrefix == "processid" { - match, err = Match(polResource, reqResource) - if err != nil { - return false - } - } else if reqPrefix == "api" { - match, err = Match(polResource, reqResource, rune('/')) - if err != nil { - return false - } - } else if reqPrefix == "fs" { - match, err = Match(polResource, reqResource, rune('/')) - if err != nil { - return false - } - } else if reqPrefix == "rtmp" { - match, err = Match(polResource, reqResource) - if err != nil { - return false - } - } else if reqPrefix == "srt" { - match, err = Match(polResource, reqResource) - if err != nil { - return false - } - } else { - match, err = Match(polResource, reqResource) - if err != nil { - return false - } - } - - fmt.Printf("match: %v\n", match) - - return match -} - -func ResourceMatchFunc(args ...interface{}) (interface{}, error) { - name1 := args[0].(string) - name2 := args[1].(string) - name3 := args[2].(string) - - return (bool)(ResourceMatch(name1, name2, name3)), nil -} - -func ActionMatch(request string, policy string) bool { - request = strings.ToUpper(request) - actions := strings.Split(strings.ToUpper(policy), "|") - if len(actions) == 0 { - return false - } - - for _, a := range actions { - if request == a { - return true - } - } - - return false -} - -func ActionMatchFunc(args ...interface{}) (interface{}, error) { - name1 := args[0].(string) - name2 := args[1].(string) - - return (bool)(ActionMatch(name1, name2)), nil -} - -func getPrefix(s string) (string, string) { - splits := strings.SplitN(s, ":", 2) - - if len(splits) == 0 { - return "", "" - } - - if len(splits) == 1 { - return "", splits[0] - } - - return splits[0], splits[1] -} - -func Match(pattern, name string, separators ...rune) (bool, error) { - g, err := glob.Compile(pattern, separators...) - if err != nil { - return false, err - } - - return g.Match(name), nil -} - -func addGroup(e *casbin.Enforcer, name string) error { - rules := [][]string{} - - rules = append(rules, []string{"role:admin", name, "api:/process/**", "GET|POST|PUT|DELETE"}) - rules = append(rules, []string{"role:admin", name, "processid:*", "CONFIG|PROGRESS|REPORT|METADATA|COMMAND"}) - rules = append(rules, []string{"role:admin", name, "rtmp:" + name + "/*", "PUBLISH|PLAY"}) - rules = append(rules, []string{"role:admin", name, "srt:" + name + "/*", "PUBLISH|PLAY"}) - rules = append(rules, []string{"role:admin", name, "fs:/" + name + "/**", "GET|POST|PUT|DELETE"}) - rules = append(rules, []string{"role:admin", name, "fs:/memfs/" + name + "/**", "GET|POST|PUT|DELETE"}) - - rules = append(rules, []string{"role:user", name, "api:/process/**", "GET"}) - rules = append(rules, []string{"role:user", name, "processid:*", "PROGRESS"}) - rules = append(rules, []string{"role:user", name, "rtmp:" + name + "/*", "PLAY"}) - rules = append(rules, []string{"role:user", name, "srt:" + name + "/*", "PLAY"}) - rules = append(rules, []string{"role:user", name, "fs:/" + name + "/**", "GET"}) - rules = append(rules, []string{"role:user", name, "fs:/memfs/" + name + "/**", "GET"}) - - rules = append(rules, []string{"role:anonymous", name, "rtmp:" + name + "/*", "PLAY"}) - rules = append(rules, []string{"role:anonymous", name, "srt:" + name + "/*", "PLAY"}) - rules = append(rules, []string{"role:anonymous", name, "fs:/" + name + "/**", "GET"}) - rules = append(rules, []string{"role:anonymous", name, "fs:/memfs/" + name + "/**", "GET"}) - - _, err := e.AddPolicies(rules) - - return err -} - -func addGroupUser(e *casbin.Enforcer, group, username, role string) error { - _, err := e.AddGroupingPolicy(username, "role:"+role, group) - - return err -} diff --git a/app/casbin/policy.csv b/app/casbin/policy.csv deleted file mode 100644 index c4d50a0e..00000000 --- a/app/casbin/policy.csv +++ /dev/null @@ -1,17 +0,0 @@ -p, admin, igelcamp, api:/process/**, GET|POST|PUT|DELETE -p, admin, igelcamp, processid:*, GET|POST|PUT|DELETE -p, admin, igelcamp, rtmp:*, PUBLISH|PLAY -p, admin, igelcamp, srt:*, PUBLISH|PLAY -p, admin, igelcamp, fs:/igelcamp/**, GET|POST|PUT|DELETE -p, admin, igelcamp, fs:/memfs/igelcamp/**, GET|POST|PUT|DELETE -p, user, igelcamp, api:/process/**, GET -p, user, igelcamp, processid:*, GET -p, user, igelcamp, rtmp:*, PLAY -p, user, igelcamp, srt:*, PLAY -p, user, igelcamp, fs:/igelcamp/**, GET -p, user, igelcamp, fs:/memfs/igelcamp/**, GET -p, anonymous, $none, fs:/*, GET - -g, alice, admin, igelcamp -g, alice, anonymous, $none -g, bob, user, igelcamp \ No newline at end of file diff --git a/app/casbin/policy.json b/app/casbin/policy.json deleted file mode 100644 index 08a2c830..00000000 --- a/app/casbin/policy.json +++ /dev/null @@ -1,206 +0,0 @@ -[ - { - "name": "igelcamp", - "roles": { - "admin": [ - { - "resource": "api:/process/**", - "actions": "GET|POST|PUT|DELETE" - }, - { - "resource": "processid:*", - "actions": "CONFIG|PROGRESS|REPORT|METADATA|COMMAND" - }, - { - "resource": "rtmp:igelcamp/*", - "actions": "PUBLISH|PLAY" - }, - { - "resource": "srt:igelcamp/*", - "actions": "PUBLISH|PLAY" - }, - { - "resource": "fs:/igelcamp/**", - "actions": "GET|POST|PUT|DELETE" - }, - { - "resource": "fs:/memfs/igelcamp/**", - "actions": "GET|POST|PUT|DELETE" - } - ], - "anonymous": [ - { - "resource": "rtmp:igelcamp/*", - "actions": "PLAY" - }, - { - "resource": "srt:igelcamp/*", - "actions": "PLAY" - }, - { - "resource": "fs:/igelcamp/**", - "actions": "GET" - }, - { - "resource": "fs:/memfs/igelcamp/**", - "actions": "GET" - } - ], - "user": [ - { - "resource": "api:/process/**", - "actions": "GET" - }, - { - "resource": "processid:*", - "actions": "PROGRESS" - }, - { - "resource": "rtmp:igelcamp/*", - "actions": "PLAY" - }, - { - "resource": "srt:igelcamp/*", - "actions": "PLAY" - }, - { - "resource": "fs:/igelcamp/**", - "actions": "GET" - }, - { - "resource": "fs:/memfs/igelcamp/**", - "actions": "GET" - } - ] - }, - "userroles": [ - { - "username": "alice", - "role": "admin" - }, - { - "username": "bob", - "role": "user" - }, - { - "username": "$anon", - "role": "anonymous" - } - ], - "policies": [ - { - "username": "bob", - "resource": "processid:bob-*", - "actions": "COMMAND" - } - ] - }, - { - "name": "$none", - "roles": { - "anonymous": [ - { - "resource": "fs:/*", - "actions": "GET" - } - ] - }, - "userroles": [ - { - "username": "$anon", - "role": "anonymous" - }, - { - "username": "alice", - "role": "anonymous" - }, - { - "username": "bob", - "role": "anonymous" - } - ], - "policies": null - }, - { - "name": "foobar", - "roles": { - "admin": [ - { - "resource": "processid:*", - "actions": "CONFIG|PROGRESS|REPORT|METADATA|COMMAND" - }, - { - "resource": "rtmp:foobar/*", - "actions": "PUBLISH|PLAY" - }, - { - "resource": "srt:foobar/*", - "actions": "PUBLISH|PLAY" - }, - { - "resource": "fs:/foobar/**", - "actions": "GET|POST|PUT|DELETE" - }, - { - "resource": "fs:/memfs/foobar/**", - "actions": "GET|POST|PUT|DELETE" - } - ], - "anonymous": [ - { - "resource": "rtmp:foobar/*", - "actions": "PLAY" - }, - { - "resource": "srt:foobar/*", - "actions": "PLAY" - }, - { - "resource": "fs:/foobar/**", - "actions": "GET" - }, - { - "resource": "fs:/memfs/foobar/**", - "actions": "GET" - } - ], - "user": [ - { - "resource": "api:/process/**", - "actions": "GET" - }, - { - "resource": "processid:*", - "actions": "PROGRESS" - }, - { - "resource": "rtmp:foobar/*", - "actions": "PLAY" - }, - { - "resource": "srt:foobar/*", - "actions": "PLAY" - }, - { - "resource": "fs:/foobar/**", - "actions": "GET" - }, - { - "resource": "fs:/memfs/foobar/**", - "actions": "GET" - } - ] - }, - "userroles": [ - { - "username": "franz", - "role": "admin" - }, - { - "username": "$anon", - "role": "anonymous" - } - ], - "policies": null - } -] \ No newline at end of file diff --git a/app/casbin/users.json b/app/casbin/users.json deleted file mode 100644 index c33e0b30..00000000 --- a/app/casbin/users.json +++ /dev/null @@ -1,93 +0,0 @@ -[ - { - "name": "alice", - "superuser": false, - "auth": { - "api": { - "userpass": { - "enable": true, - "username": "foo", - "password": "bar" - }, - "auth0": { - "enable": true, - "user": "google|42", - "tenant": "tenant1" - } - }, - "http": { - "basic": { - "enable": true, - "username": "bar", - "password": "baz" - } - }, - "rtmp": { - "enable": true, - "token": "abc123" - }, - "srt": { - "enable": true, - "token": "xyz987" - } - } - }, - { - "name": "bob", - "superuser": true, - "auth": { - "api": { - "userpass": { - "enable": true, - "username": "foo", - "password": "baz" - }, - "auth0": { - "enable": true, - "user": "github|88", - "tenant": "tenant2" - } - }, - "http": { - "basic": { - "enable": true, - "username": "boz", - "password": "bok" - } - }, - "rtmp": { - "enable": true, - "token": "abc456" - }, - "srt": { - "enable": true, - "token": "xyz654" - } - } - }, - { - "name": "$anon", - "superuser": false, - "auth": { - "api": { - "userpass": { - "enable": false - }, - "auth0": { - "enable": false - } - }, - "http": { - "basic": { - "enable": false - } - }, - "rtmp": { - "enable": false - }, - "srt": { - "enable": false - } - } - } -] \ No newline at end of file diff --git a/app/import/import.go b/app/import/import.go index 5899c350..58817554 100644 --- a/app/import/import.go +++ b/app/import/import.go @@ -17,6 +17,7 @@ import ( "github.com/datarhei/core/v16/encoding/json" "github.com/datarhei/core/v16/ffmpeg" "github.com/datarhei/core/v16/ffmpeg/skills" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/restream" "github.com/datarhei/core/v16/restream/app" @@ -502,6 +503,7 @@ func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.StoreData, } r := store.NewStoreData() + r.Version = 4 jsondata, err := fs.ReadFile(path) if err != nil { @@ -1428,17 +1430,30 @@ func probeInput(binary string, config app.Config) app.Probe { return app.Probe{} } + iam, _ := iam.NewIAM(iam.Config{ + FS: dummyfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + + iam.AddPolicy("$anon", "$none", "process:*", "CREATE|GET|DELETE|PROBE") + rs, err := restream.New(restream.Config{ FFmpeg: ffmpeg, Store: store, + IAM: iam, }) if err != nil { return app.Probe{} } rs.AddProcess(&config) - probe := rs.Probe(config.ID) - rs.DeleteProcess(config.ID) + probe := rs.Probe(config.ID, "", "") + rs.DeleteProcess(config.ID, "", "") return probe } diff --git a/http/api/process.go b/http/api/process.go index e217b455..df56fad8 100644 --- a/http/api/process.go +++ b/http/api/process.go @@ -44,6 +44,7 @@ type ProcessConfigLimits struct { // ProcessConfig represents the configuration of an ffmpeg process type ProcessConfig struct { ID string `json:"id"` + Group string `json:"group"` Type string `json:"type" validate:"oneof='ffmpeg' ''" jsonschema:"enum=ffmpeg,enum="` Reference string `json:"reference"` Input []ProcessConfigIO `json:"input" validate:"required"` @@ -60,6 +61,7 @@ type ProcessConfig struct { func (cfg *ProcessConfig) Marshal() *app.Config { p := &app.Config{ ID: cfg.ID, + Group: cfg.Group, Reference: cfg.Reference, Options: cfg.Options, Reconnect: cfg.Reconnect, @@ -139,6 +141,7 @@ func (cfg *ProcessConfig) Unmarshal(c *app.Config) { } cfg.ID = c.ID + cfg.Group = c.Group cfg.Reference = c.Reference cfg.Type = "ffmpeg" cfg.Reconnect = c.Reconnect diff --git a/http/graph/graph/graph.go b/http/graph/graph/graph.go index 332d2fdf..0fadb572 100644 --- a/http/graph/graph/graph.go +++ b/http/graph/graph/graph.go @@ -127,8 +127,10 @@ type ComplexityRoot struct { Process struct { Config func(childComplexity int) int CreatedAt func(childComplexity int) int + Group func(childComplexity int) int ID func(childComplexity int) int Metadata func(childComplexity int) int + Owner func(childComplexity int) int Reference func(childComplexity int) int Report func(childComplexity int) int State func(childComplexity int) int @@ -137,11 +139,13 @@ type ComplexityRoot struct { ProcessConfig struct { Autostart func(childComplexity int) int + Group func(childComplexity int) int ID func(childComplexity int) int Input func(childComplexity int) int Limits func(childComplexity int) int Options func(childComplexity int) int Output func(childComplexity int) int + Owner func(childComplexity int) int Reconnect func(childComplexity int) int ReconnectDelaySeconds func(childComplexity int) int Reference func(childComplexity int) int @@ -236,10 +240,10 @@ type ComplexityRoot struct { Log func(childComplexity int) int Metrics func(childComplexity int, query models.MetricsInput) int Ping func(childComplexity int) int - PlayoutStatus func(childComplexity int, id string, input string) int - Probe func(childComplexity int, id string) int - Process func(childComplexity int, id string) int - Processes func(childComplexity int) int + PlayoutStatus func(childComplexity int, id string, group *string, input string) int + Probe func(childComplexity int, id string, group *string) int + Process func(childComplexity int, id string, group *string) int + Processes func(childComplexity int, idpattern *string, refpattern *string, group *string) int } RawAVstream struct { @@ -283,10 +287,10 @@ type QueryResolver interface { About(ctx context.Context) (*models.About, error) Log(ctx context.Context) ([]string, error) Metrics(ctx context.Context, query models.MetricsInput) (*models.Metrics, error) - PlayoutStatus(ctx context.Context, id string, input string) (*models.RawAVstream, error) - Processes(ctx context.Context) ([]*models.Process, error) - Process(ctx context.Context, id string) (*models.Process, error) - Probe(ctx context.Context, id string) (*models.Probe, error) + PlayoutStatus(ctx context.Context, id string, group *string, input string) (*models.RawAVstream, error) + Processes(ctx context.Context, idpattern *string, refpattern *string, group *string) ([]*models.Process, error) + Process(ctx context.Context, id string, group *string) (*models.Process, error) + Probe(ctx context.Context, id string, group *string) (*models.Probe, error) } type executableSchema struct { @@ -675,6 +679,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Process.CreatedAt(childComplexity), true + case "Process.group": + if e.complexity.Process.Group == nil { + break + } + + return e.complexity.Process.Group(childComplexity), true + case "Process.id": if e.complexity.Process.ID == nil { break @@ -689,6 +700,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Process.Metadata(childComplexity), true + case "Process.owner": + if e.complexity.Process.Owner == nil { + break + } + + return e.complexity.Process.Owner(childComplexity), true + case "Process.reference": if e.complexity.Process.Reference == nil { break @@ -724,6 +742,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ProcessConfig.Autostart(childComplexity), true + case "ProcessConfig.group": + if e.complexity.ProcessConfig.Group == nil { + break + } + + return e.complexity.ProcessConfig.Group(childComplexity), true + case "ProcessConfig.id": if e.complexity.ProcessConfig.ID == nil { break @@ -759,6 +784,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ProcessConfig.Output(childComplexity), true + case "ProcessConfig.owner": + if e.complexity.ProcessConfig.Owner == nil { + break + } + + return e.complexity.ProcessConfig.Owner(childComplexity), true + case "ProcessConfig.reconnect": if e.complexity.ProcessConfig.Reconnect == nil { break @@ -1243,7 +1275,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.PlayoutStatus(childComplexity, args["id"].(string), args["input"].(string)), true + return e.complexity.Query.PlayoutStatus(childComplexity, args["id"].(string), args["group"].(*string), args["input"].(string)), true case "Query.probe": if e.complexity.Query.Probe == nil { @@ -1255,7 +1287,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Probe(childComplexity, args["id"].(string)), true + return e.complexity.Query.Probe(childComplexity, args["id"].(string), args["group"].(*string)), true case "Query.process": if e.complexity.Query.Process == nil { @@ -1267,14 +1299,19 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Process(childComplexity, args["id"].(string)), true + return e.complexity.Query.Process(childComplexity, args["id"].(string), args["group"].(*string)), true case "Query.processes": if e.complexity.Query.Processes == nil { break } - return e.complexity.Query.Processes(childComplexity), true + args, err := ec.field_Query_processes_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Processes(childComplexity, args["idpattern"].(*string), args["refpattern"].(*string), args["group"].(*string)), true case "RawAVstream.aqueue": if e.complexity.RawAVstream.Aqueue == nil { @@ -1561,7 +1598,7 @@ type Metric { } `, BuiltIn: false}, {Name: "../playout.graphqls", Input: `extend type Query { - playoutStatus(id: ID!, input: ID!): RawAVstream + playoutStatus(id: ID!, group: String, input: ID!): RawAVstream } type RawAVstreamIO { @@ -1597,9 +1634,9 @@ type RawAVstream { } `, BuiltIn: false}, {Name: "../process.graphqls", Input: `extend type Query { - processes: [Process!]! - process(id: ID!): Process - probe(id: ID!): Probe! + processes(idpattern: String, refpattern: String, group: String): [Process!]! + process(id: ID!, group: String): Process + probe(id: ID!, group: String): Probe! } type ProcessConfigIO { @@ -1616,6 +1653,8 @@ type ProcessConfigLimits { type ProcessConfig { id: String! + owner: String! + group: String! type: String! reference: String! input: [ProcessConfigIO!]! @@ -1666,6 +1705,8 @@ type ProcessReport implements IProcessReportHistoryEntry { type Process { id: String! + owner: String! + group: String! type: String! reference: String! created_at: Time! @@ -1840,15 +1881,24 @@ func (ec *executionContext) field_Query_playoutStatus_args(ctx context.Context, } } args["id"] = arg0 - var arg1 string - if tmp, ok := rawArgs["input"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) - arg1, err = ec.unmarshalNID2string(ctx, tmp) + var arg1 *string + if tmp, ok := rawArgs["group"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("group")) + arg1, err = ec.unmarshalOString2ᚖstring(ctx, tmp) if err != nil { return nil, err } } - args["input"] = arg1 + args["group"] = arg1 + var arg2 string + if tmp, ok := rawArgs["input"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) + arg2, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg2 return args, nil } @@ -1864,6 +1914,15 @@ func (ec *executionContext) field_Query_probe_args(ctx context.Context, rawArgs } } args["id"] = arg0 + var arg1 *string + if tmp, ok := rawArgs["group"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("group")) + arg1, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["group"] = arg1 return args, nil } @@ -1879,6 +1938,48 @@ func (ec *executionContext) field_Query_process_args(ctx context.Context, rawArg } } args["id"] = arg0 + var arg1 *string + if tmp, ok := rawArgs["group"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("group")) + arg1, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["group"] = arg1 + return args, nil +} + +func (ec *executionContext) field_Query_processes_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *string + if tmp, ok := rawArgs["idpattern"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("idpattern")) + arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["idpattern"] = arg0 + var arg1 *string + if tmp, ok := rawArgs["refpattern"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("refpattern")) + arg1, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["refpattern"] = arg1 + var arg2 *string + if tmp, ok := rawArgs["group"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("group")) + arg2, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["group"] = arg2 return args, nil } @@ -4275,6 +4376,94 @@ func (ec *executionContext) fieldContext_Process_id(ctx context.Context, field g return fc, nil } +func (ec *executionContext) _Process_owner(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Process_owner(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Owner, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Process_owner(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Process", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Process_group(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Process_group(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Group, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Process_group(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Process", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Process_type(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Process_type(ctx, field) if err != nil { @@ -4448,6 +4637,10 @@ func (ec *executionContext) fieldContext_Process_config(ctx context.Context, fie switch field.Name { case "id": return ec.fieldContext_ProcessConfig_id(ctx, field) + case "owner": + return ec.fieldContext_ProcessConfig_owner(ctx, field) + case "group": + return ec.fieldContext_ProcessConfig_group(ctx, field) case "type": return ec.fieldContext_ProcessConfig_type(ctx, field) case "reference": @@ -4678,6 +4871,94 @@ func (ec *executionContext) fieldContext_ProcessConfig_id(ctx context.Context, f return fc, nil } +func (ec *executionContext) _ProcessConfig_owner(ctx context.Context, field graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ProcessConfig_owner(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Owner, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ProcessConfig_owner(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ProcessConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _ProcessConfig_group(ctx context.Context, field graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ProcessConfig_group(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Group, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ProcessConfig_group(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ProcessConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _ProcessConfig_type(ctx context.Context, field graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { fc, err := ec.fieldContext_ProcessConfig_type(ctx, field) if err != nil { @@ -8071,7 +8352,7 @@ func (ec *executionContext) _Query_playoutStatus(ctx context.Context, field grap }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().PlayoutStatus(rctx, fc.Args["id"].(string), fc.Args["input"].(string)) + return ec.resolvers.Query().PlayoutStatus(rctx, fc.Args["id"].(string), fc.Args["group"].(*string), fc.Args["input"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -8155,7 +8436,7 @@ func (ec *executionContext) _Query_processes(ctx context.Context, field graphql. }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Processes(rctx) + return ec.resolvers.Query().Processes(rctx, fc.Args["idpattern"].(*string), fc.Args["refpattern"].(*string), fc.Args["group"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -8182,6 +8463,10 @@ func (ec *executionContext) fieldContext_Query_processes(ctx context.Context, fi switch field.Name { case "id": return ec.fieldContext_Process_id(ctx, field) + case "owner": + return ec.fieldContext_Process_owner(ctx, field) + case "group": + return ec.fieldContext_Process_group(ctx, field) case "type": return ec.fieldContext_Process_type(ctx, field) case "reference": @@ -8200,6 +8485,17 @@ func (ec *executionContext) fieldContext_Query_processes(ctx context.Context, fi return nil, fmt.Errorf("no field named %q was found under type Process", field.Name) }, } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_processes_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } return fc, nil } @@ -8217,7 +8513,7 @@ func (ec *executionContext) _Query_process(ctx context.Context, field graphql.Co }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Process(rctx, fc.Args["id"].(string)) + return ec.resolvers.Query().Process(rctx, fc.Args["id"].(string), fc.Args["group"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -8241,6 +8537,10 @@ func (ec *executionContext) fieldContext_Query_process(ctx context.Context, fiel switch field.Name { case "id": return ec.fieldContext_Process_id(ctx, field) + case "owner": + return ec.fieldContext_Process_owner(ctx, field) + case "group": + return ec.fieldContext_Process_group(ctx, field) case "type": return ec.fieldContext_Process_type(ctx, field) case "reference": @@ -8287,7 +8587,7 @@ func (ec *executionContext) _Query_probe(ctx context.Context, field graphql.Coll }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Probe(rctx, fc.Args["id"].(string)) + return ec.resolvers.Query().Probe(rctx, fc.Args["id"].(string), fc.Args["group"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -11282,7 +11582,12 @@ func (ec *executionContext) unmarshalInputMetricInput(ctx context.Context, obj i asMap[k] = v } - for k, v := range asMap { + fieldsInOrder := [...]string{"name", "labels"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } switch k { case "name": var err error @@ -11313,7 +11618,12 @@ func (ec *executionContext) unmarshalInputMetricsInput(ctx context.Context, obj asMap[k] = v } - for k, v := range asMap { + fieldsInOrder := [...]string{"timerange_seconds", "interval_seconds", "metrics"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } switch k { case "timerange_seconds": var err error @@ -11938,6 +12248,20 @@ func (ec *executionContext) _Process(ctx context.Context, sel ast.SelectionSet, out.Values[i] = ec._Process_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "owner": + + out.Values[i] = ec._Process_owner(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "group": + + out.Values[i] = ec._Process_group(ctx, field, obj) + if out.Values[i] == graphql.Null { invalids++ } @@ -12012,6 +12336,20 @@ func (ec *executionContext) _ProcessConfig(ctx context.Context, sel ast.Selectio out.Values[i] = ec._ProcessConfig_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "owner": + + out.Values[i] = ec._ProcessConfig_owner(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "group": + + out.Values[i] = ec._ProcessConfig_group(ctx, field, obj) + if out.Values[i] == graphql.Null { invalids++ } diff --git a/http/graph/models/models_gen.go b/http/graph/models/models_gen.go index 7cda1413..eaf97ea2 100644 --- a/http/graph/models/models_gen.go +++ b/http/graph/models/models_gen.go @@ -13,6 +13,9 @@ import ( type IProcessReportHistoryEntry interface { IsIProcessReportHistoryEntry() + GetCreatedAt() time.Time + GetPrelude() []string + GetLog() []*ProcessReportLogEntry } type AVStream struct { @@ -102,6 +105,8 @@ type ProbeIo struct { type Process struct { ID string `json:"id"` + Owner string `json:"owner"` + Group string `json:"group"` Type string `json:"type"` Reference string `json:"reference"` CreatedAt time.Time `json:"created_at"` @@ -113,6 +118,8 @@ type Process struct { type ProcessConfig struct { ID string `json:"id"` + Owner string `json:"owner"` + Group string `json:"group"` Type string `json:"type"` Reference string `json:"reference"` Input []*ProcessConfigIo `json:"input"` @@ -145,6 +152,27 @@ type ProcessReport struct { } func (ProcessReport) IsIProcessReportHistoryEntry() {} +func (this ProcessReport) GetCreatedAt() time.Time { return this.CreatedAt } +func (this ProcessReport) GetPrelude() []string { + if this.Prelude == nil { + return nil + } + interfaceSlice := make([]string, 0, len(this.Prelude)) + for _, concrete := range this.Prelude { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} +func (this ProcessReport) GetLog() []*ProcessReportLogEntry { + if this.Log == nil { + return nil + } + interfaceSlice := make([]*ProcessReportLogEntry, 0, len(this.Log)) + for _, concrete := range this.Log { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} type ProcessReportHistoryEntry struct { CreatedAt time.Time `json:"created_at"` @@ -153,6 +181,27 @@ type ProcessReportHistoryEntry struct { } func (ProcessReportHistoryEntry) IsIProcessReportHistoryEntry() {} +func (this ProcessReportHistoryEntry) GetCreatedAt() time.Time { return this.CreatedAt } +func (this ProcessReportHistoryEntry) GetPrelude() []string { + if this.Prelude == nil { + return nil + } + interfaceSlice := make([]string, 0, len(this.Prelude)) + for _, concrete := range this.Prelude { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} +func (this ProcessReportHistoryEntry) GetLog() []*ProcessReportLogEntry { + if this.Log == nil { + return nil + } + interfaceSlice := make([]*ProcessReportLogEntry, 0, len(this.Log)) + for _, concrete := range this.Log { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} type ProcessReportLogEntry struct { Timestamp time.Time `json:"timestamp"` diff --git a/http/graph/playout.graphqls b/http/graph/playout.graphqls index 1ac70973..3747b715 100644 --- a/http/graph/playout.graphqls +++ b/http/graph/playout.graphqls @@ -1,5 +1,5 @@ extend type Query { - playoutStatus(id: ID!, input: ID!): RawAVstream + playoutStatus(id: ID!, group: String, input: ID!): RawAVstream } type RawAVstreamIO { diff --git a/http/graph/process.graphqls b/http/graph/process.graphqls index 1e4fff7b..66a6a75f 100644 --- a/http/graph/process.graphqls +++ b/http/graph/process.graphqls @@ -1,7 +1,7 @@ extend type Query { - processes: [Process!]! - process(id: ID!): Process - probe(id: ID!): Probe! + processes(idpattern: String, refpattern: String, group: String): [Process!]! + process(id: ID!, group: String): Process + probe(id: ID!, group: String): Probe! } type ProcessConfigIO { @@ -18,6 +18,8 @@ type ProcessConfigLimits { type ProcessConfig { id: String! + owner: String! + group: String! type: String! reference: String! input: [ProcessConfigIO!]! @@ -68,6 +70,8 @@ type ProcessReport implements IProcessReportHistoryEntry { type Process { id: String! + owner: String! + group: String! type: String! reference: String! created_at: Time! diff --git a/http/graph/resolver/about.resolvers.go b/http/graph/resolver/about.resolvers.go index 37453620..d62771b5 100644 --- a/http/graph/resolver/about.resolvers.go +++ b/http/graph/resolver/about.resolvers.go @@ -12,6 +12,7 @@ import ( "github.com/datarhei/core/v16/http/graph/scalars" ) +// About is the resolver for the about field. func (r *queryResolver) About(ctx context.Context) (*models.About, error) { createdAt := r.Restream.CreatedAt() diff --git a/http/graph/resolver/log.resolvers.go b/http/graph/resolver/log.resolvers.go index 50006348..18aa2281 100644 --- a/http/graph/resolver/log.resolvers.go +++ b/http/graph/resolver/log.resolvers.go @@ -10,6 +10,7 @@ import ( "github.com/datarhei/core/v16/log" ) +// Log is the resolver for the log field. func (r *queryResolver) Log(ctx context.Context) ([]string, error) { if r.LogBuffer == nil { r.LogBuffer = log.NewBufferWriter(log.Lsilent, 1) diff --git a/http/graph/resolver/metrics.resolvers.go b/http/graph/resolver/metrics.resolvers.go index 6a78567e..e9ae7d3e 100644 --- a/http/graph/resolver/metrics.resolvers.go +++ b/http/graph/resolver/metrics.resolvers.go @@ -12,6 +12,7 @@ import ( "github.com/datarhei/core/v16/monitor/metric" ) +// Metrics is the resolver for the metrics field. func (r *queryResolver) Metrics(ctx context.Context, query models.MetricsInput) (*models.Metrics, error) { patterns := []metric.Pattern{} diff --git a/http/graph/resolver/playout.resolvers.go b/http/graph/resolver/playout.resolvers.go index 8ec88ba1..71e03387 100644 --- a/http/graph/resolver/playout.resolvers.go +++ b/http/graph/resolver/playout.resolvers.go @@ -13,8 +13,11 @@ import ( "github.com/datarhei/core/v16/playout" ) -func (r *queryResolver) PlayoutStatus(ctx context.Context, id string, input string) (*models.RawAVstream, error) { - addr, err := r.Restream.GetPlayout(id, input) +// PlayoutStatus is the resolver for the playoutStatus field. +func (r *queryResolver) PlayoutStatus(ctx context.Context, id string, group *string, input string) (*models.RawAVstream, error) { + user, _ := ctx.Value("user").(string) + + addr, err := r.Restream.GetPlayout(id, user, *group, input) if err != nil { return nil, fmt.Errorf("unknown process or input: %w", err) } diff --git a/http/graph/resolver/process.resolvers.go b/http/graph/resolver/process.resolvers.go index d8f1ee33..deab6797 100644 --- a/http/graph/resolver/process.resolvers.go +++ b/http/graph/resolver/process.resolvers.go @@ -9,13 +9,15 @@ import ( "github.com/datarhei/core/v16/http/graph/models" ) -func (r *queryResolver) Processes(ctx context.Context) ([]*models.Process, error) { - ids := r.Restream.GetProcessIDs("", "") +// Processes is the resolver for the processes field. +func (r *queryResolver) Processes(ctx context.Context, idpattern *string, refpattern *string, group *string) ([]*models.Process, error) { + user, _ := ctx.Value("user").(string) + ids := r.Restream.GetProcessIDs(*idpattern, *refpattern, user, *group) procs := []*models.Process{} for _, id := range ids { - p, err := r.getProcess(id) + p, err := r.getProcess(id, user, *group) if err != nil { return nil, err } @@ -26,12 +28,18 @@ func (r *queryResolver) Processes(ctx context.Context) ([]*models.Process, error return procs, nil } -func (r *queryResolver) Process(ctx context.Context, id string) (*models.Process, error) { - return r.getProcess(id) +// Process is the resolver for the process field. +func (r *queryResolver) Process(ctx context.Context, id string, group *string) (*models.Process, error) { + user, _ := ctx.Value("user").(string) + + return r.getProcess(id, user, *group) } -func (r *queryResolver) Probe(ctx context.Context, id string) (*models.Probe, error) { - probe := r.Restream.Probe(id) +// Probe is the resolver for the probe field. +func (r *queryResolver) Probe(ctx context.Context, id string, group *string) (*models.Probe, error) { + user, _ := ctx.Value("user").(string) + + probe := r.Restream.Probe(id, user, *group) p := &models.Probe{} p.UnmarshalRestream(probe) diff --git a/http/graph/resolver/resolver.go b/http/graph/resolver/resolver.go index 8705fe80..c833a2cb 100644 --- a/http/graph/resolver/resolver.go +++ b/http/graph/resolver/resolver.go @@ -22,23 +22,23 @@ type Resolver struct { LogBuffer log.BufferWriter } -func (r *queryResolver) getProcess(id string) (*models.Process, error) { - process, err := r.Restream.GetProcess(id) +func (r *queryResolver) getProcess(id, user, group string) (*models.Process, error) { + process, err := r.Restream.GetProcess(id, user, group) if err != nil { return nil, err } - state, err := r.Restream.GetProcessState(id) + state, err := r.Restream.GetProcessState(id, user, group) if err != nil { return nil, err } - report, err := r.Restream.GetProcessLog(id) + report, err := r.Restream.GetProcessLog(id, user, group) if err != nil { return nil, err } - m, err := r.Restream.GetProcessMetadata(id, "") + m, err := r.Restream.GetProcessMetadata(id, user, group, "") if err != nil { return nil, err } diff --git a/http/graph/resolver/schema.resolvers.go b/http/graph/resolver/schema.resolvers.go index 60cebcea..89a0b922 100644 --- a/http/graph/resolver/schema.resolvers.go +++ b/http/graph/resolver/schema.resolvers.go @@ -9,10 +9,12 @@ import ( "github.com/datarhei/core/v16/http/graph/graph" ) +// Ping is the resolver for the ping field. func (r *mutationResolver) Ping(ctx context.Context) (string, error) { return "pong", nil } +// Ping is the resolver for the ping field. func (r *queryResolver) Ping(ctx context.Context) (string, error) { return "pong", nil } diff --git a/http/handler/api/graph.go b/http/handler/api/graph.go index 9d9a0a13..64bd5440 100644 --- a/http/handler/api/graph.go +++ b/http/handler/api/graph.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "github.com/datarhei/core/v16/http/graph/graph" @@ -18,7 +19,7 @@ type GraphHandler struct { playgroundHandler http.HandlerFunc } -// NewRestream return a new Restream type. You have to provide a valid Restreamer instance. +// NewGraph return a new GraphHandler type. You have to provide a valid Restreamer instance. func NewGraph(resolver resolver.Resolver, path string) *GraphHandler { g := &GraphHandler{ resolver: resolver, @@ -43,7 +44,12 @@ func NewGraph(resolver resolver.Resolver, path string) *GraphHandler { // @Security ApiKeyAuth // @Router /api/graph/query [post] func (g *GraphHandler) Query(c echo.Context) error { - g.queryHandler.ServeHTTP(c.Response(), c.Request()) + user, _ := c.Get("user").(string) + + r := c.Request() + ctx := context.WithValue(r.Context(), "user", user) + + g.queryHandler.ServeHTTP(c.Response(), r.WithContext(ctx)) return nil } diff --git a/http/handler/api/playout.go b/http/handler/api/playout.go index cc073001..3c30daff 100644 --- a/http/handler/api/playout.go +++ b/http/handler/api/playout.go @@ -44,8 +44,10 @@ func NewPlayout(restream restream.Restreamer) *PlayoutHandler { func (h *PlayoutHandler) Status(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -102,8 +104,10 @@ func (h *PlayoutHandler) Keyframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") name := util.PathWildcardParam(c) + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -149,8 +153,10 @@ func (h *PlayoutHandler) Keyframe(c echo.Context) error { func (h *PlayoutHandler) EncodeErrorframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -193,8 +199,10 @@ func (h *PlayoutHandler) EncodeErrorframe(c echo.Context) error { func (h *PlayoutHandler) SetErrorframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -238,8 +246,10 @@ func (h *PlayoutHandler) SetErrorframe(c echo.Context) error { func (h *PlayoutHandler) ReopenInput(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -281,8 +291,10 @@ func (h *PlayoutHandler) ReopenInput(c echo.Context) error { func (h *PlayoutHandler) SetStream(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + group := util.DefaultQuery(c, "group", "") - addr, err := h.restream.GetPlayout(id, inputid) + addr, err := h.restream.GetPlayout(id, user, group, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } diff --git a/http/handler/api/restream.go b/http/handler/api/restream.go index ac750c23..9d473808 100644 --- a/http/handler/api/restream.go +++ b/http/handler/api/restream.go @@ -37,6 +37,8 @@ func NewRestream(restream restream.Restreamer) *RestreamHandler { // @Security ApiKeyAuth // @Router /api/v3/process [post] func (h *RestreamHandler) Add(c echo.Context) error { + user := util.DefaultContext(c, "user", "") + process := api.ProcessConfig{ ID: shortuuid.New(), Type: "ffmpeg", @@ -56,6 +58,7 @@ func (h *RestreamHandler) Add(c echo.Context) error { } config := process.Marshal() + config.Owner = user if err := h.restream.AddProcess(config); err != nil { return api.Err(http.StatusBadRequest, "Invalid process config", "%s", err.Error()) @@ -210,6 +213,7 @@ func (h *RestreamHandler) Update(c echo.Context) error { } config := process.Marshal() + config.Owner = user if err := h.restream.UpdateProcess(id, user, group, config); err != nil { if err == restream.ErrUnknownProcess { diff --git a/http/handler/api/restream_test.go b/http/handler/api/restream_test.go index 516db9ce..c06d1797 100644 --- a/http/handler/api/restream_test.go +++ b/http/handler/api/restream_test.go @@ -8,9 +8,9 @@ import ( "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/mock" - "github.com/stretchr/testify/require" "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" ) type Response struct { diff --git a/http/handler/api/widget.go b/http/handler/api/widget.go index bb4688f2..a6a6550e 100644 --- a/http/handler/api/widget.go +++ b/http/handler/api/widget.go @@ -43,17 +43,18 @@ func NewWidget(config WidgetConfig) *WidgetHandler { // @Router /api/v3/widget/process/{id} [get] func (w *WidgetHandler) Get(c echo.Context) error { id := util.PathParam(c, "id") + group := util.DefaultQuery(c, "group", "") if w.restream == nil { return api.Err(http.StatusNotFound, "Unknown process ID") } - process, err := w.restream.GetProcess(id) + process, err := w.restream.GetProcess(id, "", group) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } - state, err := w.restream.GetProcessState(id) + state, err := w.restream.GetProcessState(id, "", group) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } diff --git a/http/middleware/iam/iam.go b/http/middleware/iam/iam.go index 0bfa05a8..e06eb924 100644 --- a/http/middleware/iam/iam.go +++ b/http/middleware/iam/iam.go @@ -177,17 +177,13 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { c.Set("user", username) - if identity != nil && identity.IsSuperuser() { - username = "$superuser" - } - if len(domain) == 0 { domain = "$none" } action := c.Request().Method - if ok, _ := config.IAM.Enforce(username, domain, resource, action); !ok { + if !config.IAM.Enforce(username, domain, resource, action) { return api.Err(http.StatusForbidden, "Forbidden", "access denied") } diff --git a/http/mock/mock.go b/http/mock/mock.go index 621204a7..926340e3 100644 --- a/http/mock/mock.go +++ b/http/mock/mock.go @@ -16,6 +16,7 @@ import ( "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/errorhandler" "github.com/datarhei/core/v16/http/validator" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/internal/testhelper" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/restream" @@ -52,9 +53,23 @@ func DummyRestreamer(pathPrefix string) (restream.Restreamer, error) { return nil, err } + iam, err := iam.NewIAM(iam.Config{ + FS: memfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + + iam.AddPolicy("$anon", "$none", "api:/**", "ANY") + iam.AddPolicy("$anon", "$none", "fs:/**", "ANY") + rs, err := restream.New(restream.Config{ Store: store, FFmpeg: ffmpeg, + IAM: iam, }) if err != nil { return nil, err diff --git a/iam/access.go b/iam/access.go index f5f9269f..13c57bcf 100644 --- a/iam/access.go +++ b/iam/access.go @@ -108,20 +108,7 @@ func (am *access) HasGroup(name string) bool { } func (am *access) Enforce(name, domain, resource, action string) (bool, string) { - l := am.logger.Debug().WithFields(log.Fields{ - "subject": name, - "domain": domain, - "resource": resource, - "action": action, - }) - ok, rule, _ := am.enforcer.EnforceEx(name, domain, resource, action) - if !ok { - l.Log("no match") - } else { - l.WithField("rule", strings.Join(rule, ", ")).Log("match") - } - return ok, strings.Join(rule, ", ") } diff --git a/iam/casbin.go b/iam/functions.go similarity index 79% rename from iam/casbin.go rename to iam/functions.go index 9a402d2b..26261904 100644 --- a/iam/casbin.go +++ b/iam/functions.go @@ -13,21 +13,11 @@ func resourceMatch(request, domain, policy string) bool { if reqPrefix != polPrefix { return false } - /* - fmt.Printf("prefix: %s\n", reqPrefix) - fmt.Printf("requested resource: %s\n", reqResource) - fmt.Printf("requested domain: %s\n", domain) - fmt.Printf("policy resource: %s\n", polResource) - */ + var match bool var err error - if reqPrefix == "processid" { - match, err = globMatch(polResource, reqResource) - if err != nil { - return false - } - } else if reqPrefix == "api" { + if reqPrefix == "api" { match, err = globMatch(polResource, reqResource, rune('/')) if err != nil { return false @@ -38,12 +28,12 @@ func resourceMatch(request, domain, policy string) bool { return false } } else if reqPrefix == "rtmp" { - match, err = globMatch(polResource, reqResource) + match, err = globMatch(polResource, reqResource, rune('/')) if err != nil { return false } } else if reqPrefix == "srt" { - match, err = globMatch(polResource, reqResource) + match, err = globMatch(polResource, reqResource, rune('/')) if err != nil { return false } @@ -54,8 +44,6 @@ func resourceMatch(request, domain, policy string) bool { } } - //fmt.Printf("match: %v\n", match) - return match } @@ -74,6 +62,10 @@ func actionMatch(request string, policy string) bool { return false } + if len(actions) == 1 && actions[0] == "ANY" { + return true + } + for _, a := range actions { if request == a { return true diff --git a/iam/iam.go b/iam/iam.go index ae57a47a..806066fc 100644 --- a/iam/iam.go +++ b/iam/iam.go @@ -6,7 +6,7 @@ import ( ) type IAM interface { - Enforce(user, domain, resource, action string) (bool, string) + Enforce(user, domain, resource, action string) bool IsDomain(domain string) bool AddPolicy(username, domain, resource, actions string) bool @@ -26,6 +26,8 @@ type IAM interface { type iam struct { im IdentityManager am AccessManager + + logger log.Logger } type Config struct { @@ -56,10 +58,17 @@ func NewIAM(config Config) (IAM, error) { return nil, err } - return &iam{ - im: im, - am: am, - }, nil + iam := &iam{ + im: im, + am: am, + logger: config.Logger, + } + + if iam.logger == nil { + iam.logger = log.New("") + } + + return iam, nil } func (i *iam) Close() { @@ -67,12 +76,38 @@ func (i *iam) Close() { i.im = nil i.am = nil - - return } -func (i *iam) Enforce(user, domain, resource, action string) (bool, string) { - return i.am.Enforce(user, domain, resource, action) +func (i *iam) Enforce(user, domain, resource, action string) bool { + superuser := false + + if identity, err := i.im.GetVerifier(user); err == nil { + if identity.IsSuperuser() { + superuser = true + } + } + + l := i.logger.Debug().WithFields(log.Fields{ + "subject": user, + "domain": domain, + "resource": resource, + "action": action, + "superuser": superuser, + }) + + if superuser { + user = "$superuser" + } + + ok, rule := i.am.Enforce(user, domain, resource, action) + + if !ok { + l.Log("no match") + } else { + l.WithField("rule", rule).Log("match") + } + + return ok } func (i *iam) GetIdentity(name string) (IdentityVerifier, error) { diff --git a/iam/identity.go b/iam/identity.go index 39c80ff8..2fa9feef 100644 --- a/iam/identity.go +++ b/iam/identity.go @@ -86,6 +86,23 @@ func (u *User) marshalIdentity() *identity { return i } +type IdentityVerifier interface { + Name() string + + VerifyJWT(jwt string) (bool, error) + + VerifyAPIPassword(password string) (bool, error) + VerifyAPIAuth0(jwt string) (bool, error) + + VerifyServiceBasicAuth(password string) (bool, error) + VerifyServiceToken(token string) (bool, error) + + GetServiceBasicAuth() string + GetServiceToken() string + + IsSuperuser() bool +} + type identity struct { user User @@ -269,6 +286,21 @@ func (i *identity) VerifyServiceBasicAuth(password string) (bool, error) { return i.user.Auth.Services.Basic.Password == password, nil } +func (i *identity) GetServiceBasicAuth() string { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return "" + } + + if !i.user.Auth.Services.Basic.Enable { + return "" + } + + return i.user.Auth.Services.Basic.Password +} + func (i *identity) VerifyServiceToken(token string) (bool, error) { i.lock.RLock() defer i.lock.RUnlock() @@ -286,6 +318,21 @@ func (i *identity) VerifyServiceToken(token string) (bool, error) { return false, nil } +func (i *identity) GetServiceToken() string { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return "" + } + + if len(i.user.Auth.Services.Token) == 0 { + return "" + } + + return i.Name() + ":" + i.user.Auth.Services.Token[0] +} + func (i *identity) isValid() bool { return i.valid } @@ -297,24 +344,9 @@ func (i *identity) IsSuperuser() bool { return i.user.Superuser } -type IdentityVerifier interface { - Name() string - - VerifyJWT(jwt string) (bool, error) - - VerifyAPIPassword(password string) (bool, error) - VerifyAPIAuth0(jwt string) (bool, error) - - VerifyServiceBasicAuth(password string) (bool, error) - VerifyServiceToken(token string) (bool, error) - - IsSuperuser() bool -} - type IdentityManager interface { Create(identity User) error Remove(name string) error - Get(name string) (User, error) GetVerifier(name string) (IdentityVerifier, error) GetVerifierByAuth0(name string) (IdentityVerifier, error) GetDefaultVerifier() (IdentityVerifier, error) @@ -404,8 +436,6 @@ func (im *identityManager) Close() { } im.tenants = map[string]*auth0Tenant{} - - return } func (im *identityManager) Create(u User) error { @@ -485,7 +515,7 @@ func (im *identityManager) getIdentity(name string) (*identity, error) { if im.root.user.Name == name { identity = im.root } else { - identity, _ = im.identities[name] + identity = im.identities[name] } @@ -499,18 +529,6 @@ func (im *identityManager) getIdentity(name string) (*identity, error) { return identity, nil } -func (im *identityManager) Get(name string) (User, error) { - im.lock.RLock() - defer im.lock.RUnlock() - - identity, err := im.getIdentity(name) - if err != nil { - return User{}, fmt.Errorf("not found") - } - - return identity.user, nil -} - func (im *identityManager) GetVerifier(name string) (IdentityVerifier, error) { im.lock.RLock() defer im.lock.RUnlock() diff --git a/iam/identity_test.go b/iam/identity_test.go new file mode 100644 index 00000000..7c04b17b --- /dev/null +++ b/iam/identity_test.go @@ -0,0 +1 @@ +package iam diff --git a/net/url/url.go b/net/url/url.go index d99b9ebd..1d13e38c 100644 --- a/net/url/url.go +++ b/net/url/url.go @@ -18,8 +18,6 @@ func Validate(address string) error { // Parse parses an URL into its components. Returns a net/url.URL or // an error if the URL couldn't be parsed. func Parse(address string) (*url.URL, error) { - address = reScheme.ReplaceAllString(address, "//") - u, err := url.Parse(address) return u, err diff --git a/restream/restream.go b/restream/restream.go index 847225cf..6ad4aba4 100644 --- a/restream/restream.go +++ b/restream/restream.go @@ -24,6 +24,7 @@ import ( "github.com/datarhei/core/v16/restream/app" rfs "github.com/datarhei/core/v16/restream/fs" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" "github.com/datarhei/core/v16/restream/store" "github.com/Masterminds/semver/v3" @@ -67,6 +68,7 @@ type Config struct { Store store.Store Filesystems []fs.Filesystem Replace replace.Replacer + Rewrite rewrite.Rewriter FFmpeg ffmpeg.FFmpeg MaxProcesses int64 Logger log.Logger @@ -112,6 +114,7 @@ type restream struct { stopObserver context.CancelFunc } replace replace.Replacer + rewrite rewrite.Rewriter tasks map[string]*task logger log.Logger metadata map[string]interface{} @@ -132,6 +135,7 @@ func New(config Config) (Restreamer, error) { createdAt: time.Now(), store: config.Store, replace: config.Replace, + rewrite: config.Rewrite, logger: config.Logger, iam: config.IAM, } @@ -418,16 +422,21 @@ func (r *restream) save() { func (r *restream) enforce(name, group, processid, action string) bool { if len(name) == 0 { - name = "$anon" + // This is for backwards compatibility. Existing processes don't have an owner. + // All processes that will be added later will have an owner ($anon, ...). + identity, err := r.iam.GetDefaultIdentity() + if err != nil { + name = "$anon" + } else { + name = identity.Name() + } } if len(group) == 0 { group = "$none" } - ok, _ := r.iam.Enforce(name, group, "process:"+processid, action) - - return ok + return r.iam.Enforce(name, group, "process:"+processid, action) } func (r *restream) ID() string { @@ -878,37 +887,138 @@ func (r *restream) resolveAddresses(tasks map[string]*task, config *app.Config) } func (r *restream) resolveAddress(tasks map[string]*task, id, address string) (string, error) { - re := regexp.MustCompile(`^#(.+):output=(.+)`) - - if len(address) == 0 { - return address, fmt.Errorf("empty address") + matches, err := parseAddressReference(address) + if err != nil { + return address, err } - if address[0] != '#' { + // Address is not a reference + if _, ok := matches["address"]; ok { return address, nil } - matches := re.FindStringSubmatch(address) - if matches == nil { - return address, fmt.Errorf("invalid format (%s)", address) + if matches["id"] == id { + return address, fmt.Errorf("self-reference is not allowed (%s)", address) } - if matches[1] == id { - return address, fmt.Errorf("self-reference not possible (%s)", address) - } + var t *task = nil - task, ok := tasks[matches[1]] - if !ok { - return address, fmt.Errorf("unknown process '%s' (%s)", matches[1], address) - } - - for _, x := range task.config.Output { - if x.ID == matches[2] { - return x.Address, nil + for _, tsk := range tasks { + if tsk.id == matches["id"] && tsk.group == matches["group"] { + t = tsk + break } } - return address, fmt.Errorf("the process '%s' has no outputs with the ID '%s' (%s)", matches[1], matches[2], address) + if t == nil { + return address, fmt.Errorf("unknown process '%s' in group '%s' (%s)", matches["id"], matches["group"], address) + } + + identity, _ := r.iam.GetIdentity(t.config.Owner) + + teeOptions := regexp.MustCompile(`^\[[^\]]*\]`) + + for _, x := range t.config.Output { + if x.ID != matches["output"] { + continue + } + + // Check for non-tee output + if !strings.Contains(x.Address, "|") && !strings.HasPrefix(x.Address, "[") { + return r.rewrite.RewriteAddress(x.Address, identity, rewrite.READ), nil + } + + // Split tee output in its individual addresses + + addresses := strings.Split(x.Address, "|") + if len(addresses) == 0 { + return x.Address, nil + } + + // Remove tee options + for i, a := range addresses { + addresses[i] = teeOptions.ReplaceAllString(a, "") + } + + if len(matches["source"]) == 0 { + return r.rewrite.RewriteAddress(addresses[0], identity, rewrite.READ), nil + } + + for _, a := range addresses { + u, err := url.Parse(a) + if err != nil { + // Ignore invalid addresses + continue + } + + if matches["source"] == "hls" { + if (u.Scheme == "http" || u.Scheme == "https") && strings.HasSuffix(u.Path, ".m3u8") { + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } else if matches["source"] == "rtmp" { + if u.Scheme == "rtmp" { + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } else if matches["source"] == "srt" { + if u.Scheme == "srt" { + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } + } + + // If none of the sources matched, return the first address + return r.rewrite.RewriteAddress(addresses[0], identity, rewrite.READ), nil + } + + return address, fmt.Errorf("the process '%s' in group '%s' has no outputs with the ID '%s' (%s)", matches["id"], matches["group"], matches["output"], address) +} + +func parseAddressReference(address string) (map[string]string, error) { + if len(address) == 0 { + return nil, fmt.Errorf("empty address") + } + + if address[0] != '#' { + return map[string]string{ + "address": address, + }, nil + } + + re := regexp.MustCompile(`:(output|group|source)=(.+)`) + + results := map[string]string{} + + idEnd := -1 + value := address + key := "" + + for { + matches := re.FindStringSubmatchIndex(value) + if matches == nil { + break + } + + if idEnd < 0 { + idEnd = matches[2] - 1 + } + + if len(key) != 0 { + results[key] = value[:matches[2]-1] + } + + key = value[matches[2]:matches[3]] + value = value[matches[4]:matches[5]] + + results[key] = value + } + + if idEnd < 0 { + return nil, fmt.Errorf("invalid format (%s)", address) + } + + results["id"] = address[1:idEnd] + + return results, nil } func (r *restream) UpdateProcess(id, user, group string, config *app.Config) error { @@ -1661,6 +1771,7 @@ func (r *restream) GetMetadata(key string) (interface{}, error) { func resolvePlaceholders(config *app.Config, r replace.Replacer) { vars := map[string]string{ "processid": config.ID, + "owner": config.Owner, "reference": config.Reference, "group": config.Group, } diff --git a/restream/restream_test.go b/restream/restream_test.go index ca256fe6..71a174e8 100644 --- a/restream/restream_test.go +++ b/restream/restream_test.go @@ -12,6 +12,7 @@ import ( "github.com/datarhei/core/v16/net" "github.com/datarhei/core/v16/restream/app" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" "github.com/stretchr/testify/require" ) @@ -49,9 +50,15 @@ func getDummyRestreamer(portrange net.Portranger, validatorIn, validatorOut ffmp iam.AddPolicy("$anon", "$none", "process:*", "CREATE|GET|DELETE|UPDATE|COMMAND|PROBE|METADATA|PLAYOUT") + rewriter, err := rewrite.New(rewrite.Config{}) + if err != nil { + return nil, err + } + rs, err := New(Config{ FFmpeg: ffmpeg, Replace: replacer, + Rewrite: rewriter, IAM: iam, }) if err != nil { @@ -528,6 +535,39 @@ func TestPlayoutRange(t *testing.T) { require.Equal(t, "127.0.0.1:3000", addr, "the playout address should be 127.0.0.1:3000") } +func TestParseAddressReference(t *testing.T) { + matches, err := parseAddressReference("foobar") + require.NoError(t, err) + require.Equal(t, "foobar", matches["address"]) + + _, err = parseAddressReference("#foobar") + require.Error(t, err) + + _, err = parseAddressReference("#foobar:nothing=foo") + require.Error(t, err) + + matches, err = parseAddressReference("#foobar:output=foo") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["output"]) + + matches, err = parseAddressReference("#foobar:group=foo") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["group"]) + + matches, err = parseAddressReference("#foobar:nothing=foo:output=bar") + require.NoError(t, err) + require.Equal(t, "foobar:nothing=foo", matches["id"]) + require.Equal(t, "bar", matches["output"]) + + matches, err = parseAddressReference("#foobar:output=foo:group=bar") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["output"]) + require.Equal(t, "bar", matches["group"]) +} + func TestAddressReference(t *testing.T) { rs, err := getDummyRestreamer(nil, nil, nil, nil) require.NoError(t, err) @@ -559,6 +599,44 @@ func TestAddressReference(t *testing.T) { require.Equal(t, nil, err, "should resolve reference") } +func TestTeeAddressReference(t *testing.T) { + rs, err := getDummyRestreamer(nil, nil, nil, nil) + require.NoError(t, err) + + process1 := getDummyProcess() + process2 := getDummyProcess() + process3 := getDummyProcess() + process4 := getDummyProcess() + + process1.Output[0].Address = "[f=hls]http://example.com/live.m3u8|[f=flv]rtmp://example.com/live.stream?token=123" + process2.ID = "process2" + process3.ID = "process3" + process4.ID = "process4" + + rs.AddProcess(process1) + + process2.Input[0].Address = "#process:output=out" + + err = rs.AddProcess(process2) + require.Equal(t, nil, err, "should resolve reference") + + process3.Input[0].Address = "#process:output=out:source=hls" + + err = rs.AddProcess(process3) + require.Equal(t, nil, err, "should resolve reference") + + process4.Input[0].Address = "#process:output=out:source=rtmp" + + err = rs.AddProcess(process4) + require.Equal(t, nil, err, "should resolve reference") + + r := rs.(*restream) + + require.Equal(t, "http://example.com/live.m3u8", r.tasks["process2~"].config.Input[0].Address) + require.Equal(t, "http://example.com/live.m3u8", r.tasks["process3~"].config.Input[0].Address) + require.Equal(t, "rtmp://example.com/live.stream?token=123", r.tasks["process4~"].config.Input[0].Address) +} + func TestConfigValidation(t *testing.T) { rsi, err := getDummyRestreamer(nil, nil, nil, nil) require.NoError(t, err) @@ -863,5 +941,8 @@ func TestReplacer(t *testing.T) { StaleTimeout: 0, } - require.Equal(t, process, rs.tasks["314159265359"].config) + task, ok := rs.tasks["314159265359~"] + require.True(t, ok) + + require.Equal(t, process, task.config) } diff --git a/restream/rewrite/rewrite.go b/restream/rewrite/rewrite.go new file mode 100644 index 00000000..1112ce23 --- /dev/null +++ b/restream/rewrite/rewrite.go @@ -0,0 +1,156 @@ +// Package rewrite provides facilities for rewriting a local HLS, RTMP, and SRT address. +package rewrite + +import ( + "fmt" + "net/url" + + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/rtmp" + srturl "github.com/datarhei/core/v16/srt/url" +) + +type Access string + +var ( + READ Access = "read" + WRITE Access = "write" +) + +type Config struct { + HTTPBase string + RTMPBase string + SRTBase string +} + +// to a new identity, i.e. adjusting the credentials to the given identity. +type Rewriter interface { + RewriteAddress(address string, identity iam.IdentityVerifier, mode Access) string +} + +type rewrite struct { + httpBase string + rtmpBase string + srtBase string +} + +func New(config Config) (Rewriter, error) { + r := &rewrite{ + httpBase: config.HTTPBase, + rtmpBase: config.RTMPBase, + srtBase: config.SRTBase, + } + + return r, nil +} + +func (g *rewrite) RewriteAddress(address string, identity iam.IdentityVerifier, mode Access) string { + u, err := url.Parse(address) + if err != nil { + return address + } + + // Decide whether this is our local server + if !g.isLocal(u) { + return address + } + + if identity == nil { + return address + } + + if u.Scheme == "http" || u.Scheme == "https" { + return g.httpURL(u, mode, identity) + } else if u.Scheme == "rtmp" { + return g.rtmpURL(u, mode, identity) + } else if u.Scheme == "srt" { + return g.srtURL(u, mode, identity) + } + + return address +} + +func (g *rewrite) isLocal(u *url.URL) bool { + var base *url.URL + var err error + + if u.Scheme == "http" || u.Scheme == "https" { + base, err = url.Parse(g.httpBase) + } else if u.Scheme == "rtmp" { + base, err = url.Parse(g.rtmpBase) + } else if u.Scheme == "srt" { + base, err = url.Parse(g.srtBase) + } else { + err = fmt.Errorf("unsupported scheme") + } + + if err != nil { + return false + } + + hostname := u.Hostname() + port := u.Port() + + if base.Hostname() == "localhost" { + if hostname != "localhost" && hostname != "127.0.0.1" && hostname != "::1" { + return false + } + + hostname = "localhost" + } + + host := hostname + ":" + port + + return host == base.Host +} + +func (g *rewrite) httpURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + password := identity.GetServiceBasicAuth() + + if len(password) == 0 { + u.User = nil + } else { + u.User = url.UserPassword(identity.Name(), password) + } + + return u.String() +} + +func (g *rewrite) rtmpURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + token := identity.GetServiceToken() + + // Remove the existing token from the path + path, _ := rtmp.GetToken(u) + u.Path = path + + q := u.Query() + q.Set("token", token) + + u.RawQuery = q.Encode() + + return u.String() +} + +func (g *rewrite) srtURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + token := identity.GetServiceToken() + + q := u.Query() + + streamInfo, err := srturl.ParseStreamId(q.Get("streamid")) + if err != nil { + return u.String() + } + + streamInfo.Token = token + + if mode == WRITE { + streamInfo.Mode = "publish" + } else { + streamInfo.Mode = "request" + } + + q.Set("streamid", streamInfo.String()) + u.RawQuery = q.Encode() + + return u.String() +} diff --git a/restream/rewrite/rewrite_test.go b/restream/rewrite/rewrite_test.go new file mode 100644 index 00000000..3e0e5fe9 --- /dev/null +++ b/restream/rewrite/rewrite_test.go @@ -0,0 +1,156 @@ +package rewrite + +import ( + "net/url" + "testing" + + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/io/fs" + + "github.com/stretchr/testify/require" +) + +func getIdentityManager(enableBasic bool) iam.IdentityManager { + dummyfs, _ := fs.NewMemFilesystem(fs.MemConfig{}) + + im, _ := iam.NewIdentityManager(iam.IdentityConfig{ + FS: dummyfs, + Superuser: iam.User{ + Name: "foobar", + Superuser: false, + Auth: iam.UserAuth{ + API: iam.UserAuthAPI{}, + Services: iam.UserAuthServices{ + Basic: iam.UserAuthPassword{ + Enable: enableBasic, + Password: "basicauthpassword", + }, + Token: []string{"servicetoken"}, + }, + }, + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + + return im +} + +func TestRewriteHTTP(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + HTTPBase: "http://localhost:8080/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"http://example.com/live/stream.m3u8", "read", "http://example.com/live/stream.m3u8"}, + {"http://example.com/live/stream.m3u8", "write", "http://example.com/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "read", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "write", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteHTTPPassword(t *testing.T) { + im := getIdentityManager(true) + + rewrite, err := New(Config{ + HTTPBase: "http://localhost:8080/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"http://example.com/live/stream.m3u8", "read", "http://example.com/live/stream.m3u8"}, + {"http://example.com/live/stream.m3u8", "write", "http://example.com/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "read", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "write", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "read", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "write", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteRTMP(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + RTMPBase: "rtmp://localhost:1935/live", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"rtmp://example.com/live/stream", "read", "rtmp://example.com/live/stream"}, + {"rtmp://example.com/live/stream", "write", "rtmp://example.com/live/stream"}, + {"rtmp://localhost:1936/live/stream/token", "read", "rtmp://localhost:1936/live/stream/token"}, + {"rtmp://localhost:1936/live/stream?token=token", "write", "rtmp://localhost:1936/live/stream?token=token"}, + {"rtmp://localhost:1935/live/stream?token=token", "read", "rtmp://localhost:1935/live/stream?token=" + url.QueryEscape("foobar:servicetoken")}, + {"rtmp://localhost:1935/live/stream/token", "write", "rtmp://localhost:1935/live/stream?token=" + url.QueryEscape("foobar:servicetoken")}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteSRT(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + SRTBase: "srt://localhost:6000/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"srt://example.com/?streamid=stream", "read", "srt://example.com/?streamid=stream"}, + {"srt://example.com/?streamid=stream", "write", "srt://example.com/?streamid=stream"}, + {"srt://localhost:1936/?streamid=live/stream", "read", "srt://localhost:1936/?streamid=live/stream"}, + {"srt://localhost:1936/?streamid=live/stream", "write", "srt://localhost:1936/?streamid=live/stream"}, + {"srt://localhost:6000/?streamid=live/stream,mode:publish,token:token", "read", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=live/stream,mode:publish,token:token", "write", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,mode:publish,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=" + url.QueryEscape("#!:r=live/stream,m=publish,token=token"), "read", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=" + url.QueryEscape("#!:r=live/stream,m=publish,token=token"), "write", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,mode:publish,token:foobar:servicetoken")}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} diff --git a/restream/store/data.go b/restream/store/data.go index a93dbc7d..c644d806 100644 --- a/restream/store/data.go +++ b/restream/store/data.go @@ -16,7 +16,7 @@ type StoreData struct { func NewStoreData() StoreData { c := StoreData{ - Version: 4, + Version: version, } c.Process = make(map[string]*app.Process) diff --git a/restream/store/json.go b/restream/store/json.go index 36e5720e..52615c38 100644 --- a/restream/store/json.go +++ b/restream/store/json.go @@ -26,7 +26,10 @@ type jsonStore struct { lock sync.RWMutex } -var version uint64 = 4 +// version 4 -> 5: +// process groups have been added. the indices for the maps are only the process IDs in version 4. +// version 5 adds the group name as suffix to the process ID with a "~". +var version uint64 = 5 func NewJSON(config JSONConfig) (Store, error) { s := &jsonStore{ @@ -123,12 +126,29 @@ func (s *jsonStore) load(filepath string, version uint64) (StoreData, error) { return r, json.FormatError(jsondata, err) } - if db.Version != version { - return r, fmt.Errorf("unsupported version of the DB file (want: %d, have: %d)", version, db.Version) - } + if db.Version == 4 { + rold := NewStoreData() + if err = gojson.Unmarshal(jsondata, &rold); err != nil { + return r, json.FormatError(jsondata, err) + } - if err = gojson.Unmarshal(jsondata, &r); err != nil { - return r, json.FormatError(jsondata, err) + for id, p := range rold.Process { + r.Process[id+"~"] = p + } + + for key, p := range rold.Metadata.System { + r.Metadata.System[key] = p + } + + for id, p := range rold.Metadata.Process { + r.Metadata.Process[id+"~"] = p + } + } else if db.Version == version { + if err = gojson.Unmarshal(jsondata, &r); err != nil { + return r, json.FormatError(jsondata, err) + } + } else { + return r, fmt.Errorf("unsupported version of the DB file (want: %d, have: %d)", version, db.Version) } s.logger.WithField("file", filepath).Debug().Log("Read data") diff --git a/restream/store/json_test.go b/restream/store/json_test.go index 8b2c4698..c9e0c843 100644 --- a/restream/store/json_test.go +++ b/restream/store/json_test.go @@ -76,11 +76,11 @@ func TestNotExists(t *testing.T) { func TestStore(t *testing.T) { fs := getFS(t) - fs.Remove("./fixtures/v4_store.json") + fs.Remove("./fixtures/v5_store.json") store, err := NewJSON(JSONConfig{ Filesystem: fs, - Filepath: "./fixtures/v4_store.json", + Filepath: "./fixtures/v5_store.json", }) require.NoError(t, err) @@ -90,13 +90,14 @@ func TestStore(t *testing.T) { data.Metadata.System["somedata"] = "foobar" - store.Store(data) + err = store.Store(data) + require.NoError(t, err) data2, err := store.Load() require.NoError(t, err) require.Equal(t, data, data2) - fs.Remove("./fixtures/v4_store.json") + fs.Remove("./fixtures/v5_store.json") } func TestInvalidVersion(t *testing.T) { diff --git a/rtmp/channel.go b/rtmp/channel.go new file mode 100644 index 00000000..9ea68fe3 --- /dev/null +++ b/rtmp/channel.go @@ -0,0 +1,164 @@ +package rtmp + +import ( + "context" + "net" + "net/url" + "sync" + "time" + + "github.com/datarhei/core/v16/session" + "github.com/datarhei/joy4/av" + "github.com/datarhei/joy4/av/pubsub" + "github.com/datarhei/joy4/format/rtmp" +) + +type client struct { + conn connection + id string + createdAt time.Time + + txbytes uint64 + rxbytes uint64 + + collector session.Collector + + cancel context.CancelFunc +} + +func newClient(conn connection, id string, collector session.Collector) *client { + c := &client{ + conn: conn, + id: id, + createdAt: time.Now(), + + collector: collector, + } + + var ctx context.Context + ctx, c.cancel = context.WithCancel(context.Background()) + + go c.ticker(ctx) + + return c +} + +func (c *client) ticker(ctx context.Context) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + txbytes := c.conn.TxBytes() + rxbytes := c.conn.RxBytes() + + c.collector.Ingress(c.id, int64(rxbytes-c.rxbytes)) + c.collector.Egress(c.id, int64(txbytes-c.txbytes)) + + c.txbytes = txbytes + c.rxbytes = rxbytes + } + } +} + +func (c *client) Close() { + c.cancel() + c.conn.Close() +} + +// channel represents a stream that is sent to the server +type channel struct { + // The packet queue for the stream + queue *pubsub.Queue + + // The metadata of the stream + streams []av.CodecData + + // Whether the stream has an audio track + hasAudio bool + + // Whether the stream has a video track + hasVideo bool + + collector session.Collector + path string + reference string + + publisher *client + subscriber map[string]*client + lock sync.RWMutex + + isProxy bool +} + +func newChannel(conn connection, u *url.URL, reference string, remote net.Addr, streams []av.CodecData, isProxy bool, collector session.Collector) *channel { + ch := &channel{ + path: u.Path, + reference: reference, + publisher: newClient(conn, u.Path, collector), + subscriber: make(map[string]*client), + collector: collector, + streams: streams, + queue: pubsub.NewQueue(), + isProxy: isProxy, + } + + ch.queue.WriteHeader(streams) + + addr := remote.String() + ip, _, _ := net.SplitHostPort(addr) + + if collector.IsCollectableIP(ip) { + collector.RegisterAndActivate(ch.path, ch.reference, "publish:"+ch.path, addr) + } + + return ch +} + +func (ch *channel) Close() { + if ch.publisher == nil { + return + } + + ch.publisher.Close() + ch.publisher = nil + + ch.queue.Close() +} + +func (ch *channel) AddSubscriber(conn *rtmp.Conn) string { + addr := conn.NetConn().RemoteAddr().String() + ip, _, _ := net.SplitHostPort(addr) + + client := newClient(conn, addr, ch.collector) + + if ch.collector.IsCollectableIP(ip) { + ch.collector.RegisterAndActivate(addr, ch.reference, "play:"+conn.URL.Path, addr) + } + + ch.lock.Lock() + ch.subscriber[addr] = client + ch.lock.Unlock() + + return addr +} + +func (ch *channel) RemoveSubscriber(id string) { + ch.lock.Lock() + defer ch.lock.Unlock() + + client := ch.subscriber[id] + if client != nil { + delete(ch.subscriber, id) + client.Close() + } + + // If this is a proxied channel and the last subscriber leaves, + // close the channel. + if len(ch.subscriber) == 0 && ch.isProxy { + ch.Close() + } +} diff --git a/rtmp/connection.go b/rtmp/connection.go new file mode 100644 index 00000000..43cee496 --- /dev/null +++ b/rtmp/connection.go @@ -0,0 +1,104 @@ +package rtmp + +import ( + "fmt" + + "github.com/datarhei/joy4/av" +) + +type connection interface { + av.MuxCloser + av.DemuxCloser + TxBytes() uint64 + RxBytes() uint64 +} + +// conn implements the connection interface +type conn struct { + muxer av.MuxCloser + demuxer av.DemuxCloser + + txbytes uint64 + rxbytes uint64 +} + +// Make sure that conn implements the connection interface +var _ connection = &conn{} + +func newConnectionFromDemuxer(m av.DemuxCloser) connection { + c := &conn{ + demuxer: m, + } + + return c +} + +func (c *conn) TxBytes() uint64 { + return c.txbytes +} + +func (c *conn) RxBytes() uint64 { + return c.rxbytes +} + +func (c *conn) ReadPacket() (av.Packet, error) { + if c.demuxer != nil { + p, err := c.demuxer.ReadPacket() + if err == nil { + c.rxbytes += uint64(len(p.Data)) + } + + return p, err + } + + return av.Packet{}, fmt.Errorf("no demuxer available") +} + +func (c *conn) Streams() ([]av.CodecData, error) { + if c.demuxer != nil { + return c.demuxer.Streams() + } + + return nil, fmt.Errorf("no demuxer available") +} + +func (c *conn) WritePacket(p av.Packet) error { + if c.muxer != nil { + err := c.muxer.WritePacket(p) + if err == nil { + c.txbytes += uint64(len(p.Data)) + } + + return err + } + + return fmt.Errorf("no muxer available") +} + +func (c *conn) WriteHeader(streams []av.CodecData) error { + if c.muxer != nil { + return c.muxer.WriteHeader(streams) + } + + return fmt.Errorf("no muxer available") +} + +func (c *conn) WriteTrailer() error { + if c.muxer != nil { + return c.muxer.WriteTrailer() + } + + return fmt.Errorf("no muxer available") +} + +func (c *conn) Close() error { + if c.muxer != nil { + return c.muxer.Close() + } + + if c.demuxer != nil { + return c.demuxer.Close() + } + + return nil +} diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index aa0b80ad..3331fdbf 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -2,7 +2,6 @@ package rtmp import ( - "context" "crypto/tls" "fmt" "net" @@ -10,7 +9,6 @@ import ( "path/filepath" "strings" "sync" - "time" "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" @@ -18,9 +16,7 @@ import ( "github.com/datarhei/joy4/av/avutil" "github.com/datarhei/joy4/av/pktque" - "github.com/datarhei/joy4/av/pubsub" "github.com/datarhei/joy4/format" - "github.com/datarhei/joy4/format/flv/flvio" "github.com/datarhei/joy4/format/rtmp" ) @@ -32,142 +28,6 @@ func init() { format.RegisterAll() } -type client struct { - conn *rtmp.Conn - id string - createdAt time.Time - - txbytes uint64 - rxbytes uint64 - - collector session.Collector - - cancel context.CancelFunc -} - -func newClient(conn *rtmp.Conn, id string, collector session.Collector) *client { - c := &client{ - conn: conn, - id: id, - createdAt: time.Now(), - - collector: collector, - } - - var ctx context.Context - ctx, c.cancel = context.WithCancel(context.Background()) - - go c.ticker(ctx) - - return c -} - -func (c *client) ticker(ctx context.Context) { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - txbytes := c.conn.TxBytes() - rxbytes := c.conn.RxBytes() - - c.collector.Ingress(c.id, int64(rxbytes-c.rxbytes)) - c.collector.Egress(c.id, int64(txbytes-c.txbytes)) - - c.txbytes = txbytes - c.rxbytes = rxbytes - } - } -} - -func (c *client) Close() { - c.cancel() -} - -// channel represents a stream that is sent to the server -type channel struct { - // The packet queue for the stream - queue *pubsub.Queue - - // The metadata of the stream - metadata flvio.AMFMap - - // Whether the stream has an audio track - hasAudio bool - - // Whether the stream has a video track - hasVideo bool - - collector session.Collector - path string - reference string - - publisher *client - subscriber map[string]*client - lock sync.RWMutex -} - -func newChannel(conn *rtmp.Conn, reference string, collector session.Collector) *channel { - ch := &channel{ - path: conn.URL.Path, - reference: reference, - publisher: newClient(conn, conn.URL.Path, collector), - subscriber: make(map[string]*client), - collector: collector, - } - - addr := conn.NetConn().RemoteAddr().String() - ip, _, _ := net.SplitHostPort(addr) - - if collector.IsCollectableIP(ip) { - collector.RegisterAndActivate(ch.path, ch.reference, "publish:"+ch.path, addr) - } - - return ch -} - -func (ch *channel) Close() { - if ch.publisher == nil { - return - } - - ch.publisher.Close() - ch.publisher = nil - - ch.queue.Close() -} - -func (ch *channel) AddSubscriber(conn *rtmp.Conn) string { - addr := conn.NetConn().RemoteAddr().String() - ip, _, _ := net.SplitHostPort(addr) - - client := newClient(conn, addr, ch.collector) - - if ch.collector.IsCollectableIP(ip) { - ch.collector.RegisterAndActivate(addr, ch.reference, "play:"+ch.path, addr) - } - - ch.lock.Lock() - ch.subscriber[addr] = client - ch.lock.Unlock() - - return addr -} - -func (ch *channel) RemoveSubscriber(id string) { - ch.lock.Lock() - defer ch.lock.Unlock() - - client := ch.subscriber[id] - if client != nil { - delete(ch.subscriber, id) - client.Close() - } -} - // Config for a new RTMP server type Config struct { // Logger. Optional. @@ -333,17 +193,15 @@ func (s *server) log(who, action, path, message string, client net.Addr) { }).Log(message) } -// getToken returns the path and the token found in the URL. If the token +// GetToken returns the path and the token found in the URL. If the token // was part of the path, the token is removed from the path. The token in // the query string takes precedence. The token in the path is assumed to // be the last path element. -func getToken(u *url.URL) (string, string) { +func GetToken(u *url.URL) (string, string) { q := u.Query() - token := q.Get("token") - - if len(token) != 0 { + if q.Has("token") { // The token was in the query. Return the unmomdified path and the token - return u.Path, token + return u.Path, q.Get("token") } pathElements := strings.Split(u.EscapedPath(), "/") @@ -359,35 +217,24 @@ func getToken(u *url.URL) (string, string) { // handlePlay is called when a RTMP client wants to play a stream func (s *server) handlePlay(conn *rtmp.Conn) { - client := conn.NetConn().RemoteAddr() - defer conn.Close() - playPath, token := getToken(conn.URL) + remote := conn.NetConn().RemoteAddr() + playPath, token := GetToken(conn.URL) identity, err := s.findIdentityFromStreamKey(token) if err != nil { - s.logger.Debug().WithError(err).Log("no valid identity found") - s.log("PLAY", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", client) + s.logger.Debug().WithError(err).Log("invalid streamkey") + s.log("PLAY", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", remote) return } domain := s.findDomainFromPlaypath(playPath) resource := "rtmp:" + playPath - l := s.logger.Debug().WithFields(log.Fields{ - "name": identity.Name(), - "domain": domain, - "resource": resource, - "action": "PLAY", - }) - - if ok, rule := s.iam.Enforce(identity.Name(), domain, resource, "PLAY"); !ok { - l.Log("access denied") - s.log("PLAY", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", client) + if !s.iam.Enforce(identity, domain, resource, "PLAY") { + s.log("PLAY", "FORBIDDEN", playPath, "access denied", remote) return - } else { - l.Log(rule) } /* @@ -415,10 +262,10 @@ func (s *server) handlePlay(conn *rtmp.Conn) { s.lock.RUnlock() if ch != nil { - // Set the metadata for the client - conn.SetMetaData(ch.metadata) + // Send the metadata to the client + conn.WriteHeader(ch.streams) - s.log("PLAY", "START", playPath, "", client) + s.log("PLAY", "START", conn.URL.Path, "", remote) // Get a cursor and apply filters cursor := ch.queue.Oldest() @@ -440,75 +287,68 @@ func (s *server) handlePlay(conn *rtmp.Conn) { id := ch.AddSubscriber(conn) - // Transfer the data + // Transfer the data, blocks until done avutil.CopyFile(conn, demuxer) ch.RemoveSubscriber(id) - s.log("PLAY", "STOP", playPath, "", client) + s.log("PLAY", "STOP", playPath, "", remote) } else { - s.log("PLAY", "NOTFOUND", playPath, "", client) + s.log("PLAY", "NOTFOUND", playPath, "", remote) } } // handlePublish is called when a RTMP client wants to publish a stream func (s *server) handlePublish(conn *rtmp.Conn) { - client := conn.NetConn().RemoteAddr() - defer conn.Close() - playPath, token := getToken(conn.URL) + remote := conn.NetConn().RemoteAddr() + playPath, token := GetToken(conn.URL) // Check the app patch if !strings.HasPrefix(playPath, s.app) { - s.log("PUBLISH", "FORBIDDEN", conn.URL.Path, "invalid app", client) + s.log("PUBLISH", "FORBIDDEN", conn.URL.Path, "invalid app", remote) return } identity, err := s.findIdentityFromStreamKey(token) if err != nil { - s.logger.Debug().WithError(err).Log("no valid identity found") - s.log("PUBLISH", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", client) + s.logger.Debug().WithError(err).Log("invalid streamkey") + s.log("PUBLISH", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", remote) return } domain := s.findDomainFromPlaypath(playPath) resource := "rtmp:" + playPath - l := s.logger.Debug().WithFields(log.Fields{ - "name": identity.Name(), - "domain": domain, - "resource": resource, - "action": "PUBLISH", - }) - - if ok, rule := s.iam.Enforce(identity.Name(), domain, "rtmp:"+playPath, "PUBLISH"); !ok { - l.Log("access denied") - s.log("PUBLISH", "FORBIDDEN", playPath, "invalid streamkey ("+token+")", client) + if !s.iam.Enforce(identity, domain, resource, "PUBLISH") { + s.log("PUBLISH", "FORBIDDEN", playPath, "access denied", remote) return - } else { - l.Log(rule) } - // Check the stream if it contains any valid/known streams - streams, _ := conn.Streams() + err = s.publish(conn, conn.URL, remote, false) + if err != nil { + s.logger.WithField("path", conn.URL.Path).WithError(err).Log("") + } +} + +func (s *server) publish(src connection, u *url.URL, remote net.Addr, isProxy bool) error { + // Check the streams if it contains any valid/known streams + streams, _ := src.Streams() if len(streams) == 0 { - s.log("PUBLISH", "INVALID", playPath, "no streams available", client) - return + s.log("PUBLISH", "INVALID", u.Path, "no streams available", remote) + return fmt.Errorf("no streams are available") } s.lock.Lock() - ch := s.channels[conn.URL.Path] + ch := s.channels[u.Path] if ch == nil { - reference := strings.TrimPrefix(strings.TrimSuffix(playPath, filepath.Ext(playPath)), s.app+"/") + reference := strings.TrimPrefix(strings.TrimSuffix(u.Path, filepath.Ext(u.Path)), s.app+"/") // Create a new channel - ch = newChannel(conn, reference, s.collector) - ch.metadata = conn.GetMetaData() - ch.queue = pubsub.NewQueue() - ch.queue.WriteHeader(streams) + ch = newChannel(src, u, reference, remote, streams, isProxy, s.collector) for _, stream := range streams { typ := stream.Type() @@ -521,7 +361,7 @@ func (s *server) handlePublish(conn *rtmp.Conn) { } } - s.channels[playPath] = ch + s.channels[u.Path] = ch } else { ch = nil } @@ -529,48 +369,58 @@ func (s *server) handlePublish(conn *rtmp.Conn) { s.lock.Unlock() if ch == nil { - s.log("PUBLISH", "CONFLICT", playPath, "already publishing", client) - return + s.log("PUBLISH", "CONFLICT", u.Path, "already publishing", remote) + return fmt.Errorf("already publishing") } - s.log("PUBLISH", "START", playPath, "", client) + s.log("PUBLISH", "START", u.Path, "", remote) for _, stream := range streams { - s.log("PUBLISH", "STREAM", playPath, stream.Type().String(), client) + s.log("PUBLISH", "STREAM", u.Path, stream.Type().String(), remote) } - // Ingest the data - avutil.CopyPackets(ch.queue, conn) + // Ingest the data, blocks until done + avutil.CopyPackets(ch.queue, src) s.lock.Lock() - delete(s.channels, playPath) + delete(s.channels, u.Path) s.lock.Unlock() ch.Close() - s.log("PUBLISH", "STOP", playPath, "", client) + s.log("PUBLISH", "STOP", u.Path, "", remote) + + return nil } -func (s *server) findIdentityFromStreamKey(key string) (iam.IdentityVerifier, error) { +func (s *server) findIdentityFromStreamKey(key string) (string, error) { + if len(key) == 0 { + return "$anon", nil + } + var identity iam.IdentityVerifier var err error + var token string + elements := strings.Split(key, ":") if len(elements) == 1 { identity, err = s.iam.GetDefaultIdentity() + token = elements[0] } else { identity, err = s.iam.GetIdentity(elements[0]) + token = elements[1] } if err != nil { - return nil, fmt.Errorf("invalid token: %w", err) + return "$anon", nil } - if ok, err := identity.VerifyServiceToken(elements[1]); !ok { - return nil, fmt.Errorf("invalid token: %w", err) + if ok, err := identity.VerifyServiceToken(token); !ok { + return "$anon", fmt.Errorf("invalid token: %w", err) } - return identity, nil + return identity.Name(), nil } func (s *server) findDomainFromPlaypath(path string) string { @@ -578,7 +428,7 @@ func (s *server) findDomainFromPlaypath(path string) string { elements := strings.Split(path, "/") if len(elements) == 1 { - return "" + return "$none" } domain := elements[0] @@ -587,5 +437,5 @@ func (s *server) findDomainFromPlaypath(path string) string { return domain } - return "" + return "$none" } diff --git a/rtmp/rtmp_test.go b/rtmp/rtmp_test.go index 20bb5274..37848afc 100644 --- a/rtmp/rtmp_test.go +++ b/rtmp/rtmp_test.go @@ -18,7 +18,7 @@ func TestToken(t *testing.T) { u, err := url.Parse(d[0]) require.NoError(t, err) - path, token := getToken(u) + path, token := GetToken(u) require.Equal(t, d[1], path, "url=%s", u.String()) require.Equal(t, d[2], token, "url=%s", u.String()) diff --git a/srt/channel.go b/srt/channel.go new file mode 100644 index 00000000..ee2fbad9 --- /dev/null +++ b/srt/channel.go @@ -0,0 +1,147 @@ +package srt + +import ( + "context" + "net" + "sync" + "time" + + "github.com/datarhei/core/v16/session" + srt "github.com/datarhei/gosrt" +) + +type client struct { + conn srt.Conn + id string + createdAt time.Time + + txbytes uint64 + rxbytes uint64 + + collector session.Collector + + cancel context.CancelFunc +} + +func newClient(conn srt.Conn, id string, collector session.Collector) *client { + c := &client{ + conn: conn, + id: id, + createdAt: time.Now(), + + collector: collector, + } + + var ctx context.Context + ctx, c.cancel = context.WithCancel(context.Background()) + + go c.ticker(ctx) + + return c +} + +func (c *client) ticker(ctx context.Context) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + stats := &srt.Statistics{} + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.conn.Stats(stats) + + rxbytes := stats.Accumulated.ByteRecv + txbytes := stats.Accumulated.ByteSent + + c.collector.Ingress(c.id, int64(rxbytes-c.rxbytes)) + c.collector.Egress(c.id, int64(txbytes-c.txbytes)) + + c.txbytes = txbytes + c.rxbytes = rxbytes + } + } +} + +func (c *client) Close() { + c.cancel() + c.conn.Close() +} + +// channel represents a stream that is sent to the server +type channel struct { + pubsub srt.PubSub + collector session.Collector + path string + + publisher *client + subscriber map[string]*client + lock sync.RWMutex + + isProxy bool +} + +func newChannel(conn srt.Conn, resource string, isProxy bool, collector session.Collector) *channel { + ch := &channel{ + pubsub: srt.NewPubSub(srt.PubSubConfig{}), + path: resource, + publisher: newClient(conn, resource, collector), + subscriber: make(map[string]*client), + collector: collector, + isProxy: isProxy, + } + + addr := conn.RemoteAddr().String() + ip, _, _ := net.SplitHostPort(addr) + + if collector.IsCollectableIP(ip) { + collector.RegisterAndActivate(resource, resource, "publish:"+resource, addr) + } + + return ch +} + +func (ch *channel) Close() { + if ch.publisher == nil { + return + } + + ch.publisher.Close() + ch.publisher = nil +} + +func (ch *channel) AddSubscriber(conn srt.Conn, resource string) string { + addr := conn.RemoteAddr().String() + ip, _, _ := net.SplitHostPort(addr) + + client := newClient(conn, addr, ch.collector) + + if ch.collector.IsCollectableIP(ip) { + ch.collector.RegisterAndActivate(addr, resource, "play:"+resource, addr) + } + + ch.lock.Lock() + ch.subscriber[addr] = client + ch.lock.Unlock() + + return addr +} + +func (ch *channel) RemoveSubscriber(id string) { + ch.lock.Lock() + defer ch.lock.Unlock() + + client := ch.subscriber[id] + if client != nil { + delete(ch.subscriber, id) + client.Close() + } + + // If this is a proxied channel and the last subscriber leaves, + // close the channel. + if len(ch.subscriber) == 0 && ch.isProxy { + ch.Close() + } +} diff --git a/srt/srt.go b/srt/srt.go index e3620e5a..e1c2cbb3 100644 --- a/srt/srt.go +++ b/srt/srt.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "net" - "regexp" "strings" "sync" "time" @@ -13,6 +12,7 @@ import ( "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/session" + "github.com/datarhei/core/v16/srt/url" srt "github.com/datarhei/gosrt" ) @@ -20,132 +20,6 @@ import ( // has been closed regularly with the Close() function. var ErrServerClosed = srt.ErrServerClosed -type client struct { - conn srt.Conn - id string - createdAt time.Time - - txbytes uint64 - rxbytes uint64 - - collector session.Collector - - cancel context.CancelFunc -} - -func newClient(conn srt.Conn, id string, collector session.Collector) *client { - c := &client{ - conn: conn, - id: id, - createdAt: time.Now(), - - collector: collector, - } - - var ctx context.Context - ctx, c.cancel = context.WithCancel(context.Background()) - - go c.ticker(ctx) - - return c -} - -func (c *client) ticker(ctx context.Context) { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - stats := &srt.Statistics{} - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - c.conn.Stats(stats) - - rxbytes := stats.Accumulated.ByteRecv - txbytes := stats.Accumulated.ByteSent - - c.collector.Ingress(c.id, int64(rxbytes-c.rxbytes)) - c.collector.Egress(c.id, int64(txbytes-c.txbytes)) - - c.txbytes = txbytes - c.rxbytes = rxbytes - } - } -} - -func (c *client) Close() { - c.cancel() -} - -// channel represents a stream that is sent to the server -type channel struct { - pubsub srt.PubSub - collector session.Collector - path string - - publisher *client - subscriber map[string]*client - lock sync.RWMutex -} - -func newChannel(conn srt.Conn, resource string, collector session.Collector) *channel { - ch := &channel{ - pubsub: srt.NewPubSub(srt.PubSubConfig{}), - path: resource, - publisher: newClient(conn, resource, collector), - subscriber: make(map[string]*client), - collector: collector, - } - - addr := conn.RemoteAddr().String() - ip, _, _ := net.SplitHostPort(addr) - - if collector.IsCollectableIP(ip) { - collector.RegisterAndActivate(resource, resource, "publish:"+resource, addr) - } - - return ch -} - -func (ch *channel) Close() { - if ch.publisher == nil { - return - } - - ch.publisher.Close() - ch.publisher = nil -} - -func (ch *channel) AddSubscriber(conn srt.Conn, resource string) string { - addr := conn.RemoteAddr().String() - ip, _, _ := net.SplitHostPort(addr) - - client := newClient(conn, addr, ch.collector) - - if ch.collector.IsCollectableIP(ip) { - ch.collector.RegisterAndActivate(addr, resource, "play:"+resource, addr) - } - - ch.lock.Lock() - ch.subscriber[addr] = client - ch.lock.Unlock() - - return addr -} - -func (ch *channel) RemoveSubscriber(id string) { - ch.lock.Lock() - defer ch.lock.Unlock() - - client := ch.subscriber[id] - if client != nil { - delete(ch.subscriber, id) - client.Close() - } -} - // Config for a new SRT server type Config struct { // The address the SRT server should listen on, e.g. ":1935" @@ -369,165 +243,64 @@ func (s *server) log(handler, action, resource, message string, client net.Addr) }).Log(message) } -type streamInfo struct { - mode string - resource string - token string -} - -func parseStreamId(streamid string) (streamInfo, error) { - si := streamInfo{} - - if strings.HasPrefix(streamid, "#!:") { - return parseOldStreamId(streamid) - } - - re := regexp.MustCompile(`,(token|mode):(.+)`) - - results := map[string]string{} - - idEnd := -1 - value := streamid - key := "" - - for { - matches := re.FindStringSubmatchIndex(value) - if matches == nil { - break - } - - if idEnd < 0 { - idEnd = matches[2] - 1 - } - - if len(key) != 0 { - results[key] = value[:matches[2]-1] - } - - key = value[matches[2]:matches[3]] - value = value[matches[4]:matches[5]] - - results[key] = value - } - - if idEnd < 0 { - idEnd = len(streamid) - } - - si.resource = streamid[:idEnd] - if token, ok := results["token"]; ok { - si.token = token - } - - if mode, ok := results["mode"]; ok { - si.mode = mode - } else { - si.mode = "request" - } - - return si, nil -} - -func parseOldStreamId(streamid string) (streamInfo, error) { - si := streamInfo{} - - if !strings.HasPrefix(streamid, "#!:") { - return si, fmt.Errorf("unknown streamid format") - } - - streamid = strings.TrimPrefix(streamid, "#!:") - - kvs := strings.Split(streamid, ",") - - splitFn := func(s, sep string) (string, string, error) { - splitted := strings.SplitN(s, sep, 2) - - if len(splitted) != 2 { - return "", "", fmt.Errorf("invalid key/value pair") - } - - return splitted[0], splitted[1], nil - } - - for _, kv := range kvs { - key, value, err := splitFn(kv, "=") - if err != nil { - continue - } - - switch key { - case "m": - si.mode = value - case "r": - si.resource = value - case "token": - si.token = value - default: - } - } - - return si, nil -} - func (s *server) handleConnect(req srt.ConnRequest) srt.ConnType { mode := srt.REJECT client := req.RemoteAddr() streamId := req.StreamId() - si, err := parseStreamId(streamId) + si, err := url.ParseStreamId(streamId) if err != nil { s.log("CONNECT", "INVALID", "", err.Error(), client) return srt.REJECT } - if len(si.resource) == 0 { + if len(si.Resource) == 0 { s.log("CONNECT", "INVALID", "", "stream resource not provided", client) return srt.REJECT } - if si.mode == "publish" { + if si.Mode == "publish" { mode = srt.PUBLISH - } else if si.mode == "request" { + } else if si.Mode == "request" { mode = srt.SUBSCRIBE } else { - s.log("CONNECT", "INVALID", si.resource, "invalid connection mode", client) + s.log("CONNECT", "INVALID", si.Resource, "invalid connection mode", client) return srt.REJECT } if len(s.passphrase) != 0 { if !req.IsEncrypted() { - s.log("CONNECT", "FORBIDDEN", si.resource, "connection has to be encrypted", client) + s.log("CONNECT", "FORBIDDEN", si.Resource, "connection has to be encrypted", client) return srt.REJECT } if err := req.SetPassphrase(s.passphrase); err != nil { - s.log("CONNECT", "FORBIDDEN", si.resource, err.Error(), client) + s.log("CONNECT", "FORBIDDEN", si.Resource, err.Error(), client) return srt.REJECT } } else { if req.IsEncrypted() { - s.log("CONNECT", "INVALID", si.resource, "connection must not be encrypted", client) + s.log("CONNECT", "INVALID", si.Resource, "connection must not be encrypted", client) return srt.REJECT } } - // Check the token - if len(s.token) != 0 && s.token != si.token { - s.log("CONNECT", "FORBIDDEN", si.resource, "invalid token ("+si.token+")", client) + identity, err := s.findIdentityFromToken(si.Token) + if err != nil { + s.logger.Debug().WithError(err).Log("invalid token") + s.log("PUBLISH", "FORBIDDEN", si.Resource, "invalid token", client) return srt.REJECT } - s.lock.RLock() - ch := s.channels[si.resource] - s.lock.RUnlock() - - if mode == srt.PUBLISH && ch != nil { - s.log("CONNECT", "CONFLICT", si.resource, "already publishing", client) - return srt.REJECT + domain := s.findDomainFromPlaypath(si.Resource) + resource := "srt:" + si.Resource + action := "PLAY" + if mode == srt.PUBLISH { + action = "PUBLISH" } - if mode == srt.SUBSCRIBE && ch == nil { - s.log("CONNECT", "NOTFOUND", si.resource, "no publisher for this resource found", client) + if !s.iam.Enforce(identity, domain, resource, action) { + s.log("PUBLISH", "FORBIDDEN", si.Resource, "access denied", client) return srt.REJECT } @@ -538,61 +311,36 @@ func (s *server) handlePublish(conn srt.Conn) { streamId := conn.StreamId() client := conn.RemoteAddr() - si, _ := parseStreamId(streamId) - - identity, err := s.findIdentityFromToken(si.token) - if err != nil { - s.logger.Debug().WithError(err).Log("no valid identity found") - s.log("PUBLISH", "FORBIDDEN", si.resource, "invalid token", client) - return - } - - domain := s.findDomainFromPlaypath(si.resource) - resource := "srt:" + si.resource - - l := s.logger.Debug().WithFields(log.Fields{ - "name": identity.Name(), - "domain": domain, - "resource": resource, - "action": "PUBLISH", - }) - - if ok, rule := s.iam.Enforce(identity.Name(), domain, resource, "PUBLISH"); !ok { - l.Log("access denied") - s.log("PUBLISH", "FORBIDDEN", si.resource, "invalid token", client) - return - } else { - l.Log(rule) - } + si, _ := url.ParseStreamId(streamId) // Look for the stream s.lock.Lock() - ch := s.channels[si.resource] + ch := s.channels[si.Resource] if ch == nil { - ch = newChannel(conn, si.resource, s.collector) - s.channels[si.resource] = ch + ch = newChannel(conn, si.Resource, false, s.collector) + s.channels[si.Resource] = ch } else { ch = nil } s.lock.Unlock() if ch == nil { - s.log("PUBLISH", "CONFLICT", si.resource, "already publishing", client) + s.log("PUBLISH", "CONFLICT", si.Resource, "already publishing", client) conn.Close() return } - s.log("PUBLISH", "START", si.resource, "", client) + s.log("PUBLISH", "START", si.Resource, "", client) ch.pubsub.Publish(conn) s.lock.Lock() - delete(s.channels, si.resource) + delete(s.channels, si.Resource) s.lock.Unlock() ch.Close() - s.log("PUBLISH", "STOP", si.resource, "", client) + s.log("PUBLISH", "STOP", si.Resource, "", client) conn.Close() } @@ -601,83 +349,66 @@ func (s *server) handleSubscribe(conn srt.Conn) { streamId := conn.StreamId() client := conn.RemoteAddr() - si, _ := parseStreamId(streamId) - - identity, err := s.findIdentityFromToken(si.token) - if err != nil { - s.logger.Debug().WithError(err).Log("no valid identity found") - s.log("SUBSCRIBE", "FORBIDDEN", si.resource, "invalid token", client) - return - } - - domain := s.findDomainFromPlaypath(si.resource) - resource := "srt:" + si.resource - - l := s.logger.Debug().WithFields(log.Fields{ - "name": identity.Name(), - "domain": domain, - "resource": resource, - "action": "PLAY", - }) - - if ok, rule := s.iam.Enforce(identity.Name(), domain, resource, "PLAY"); !ok { - l.Log("access denied") - s.log("SUBSCRIBE", "FORBIDDEN", si.resource, "invalid token", client) - return - } else { - l.Log(rule) - } + si, _ := url.ParseStreamId(streamId) // Look for the stream s.lock.RLock() - ch := s.channels[si.resource] + ch := s.channels[si.Resource] s.lock.RUnlock() if ch == nil { - s.log("SUBSCRIBE", "NOTFOUND", si.resource, "no publisher for this resource found", client) + s.log("SUBSCRIBE", "NOTFOUND", si.Resource, "no publisher for this resource found", client) conn.Close() return } - s.log("SUBSCRIBE", "START", si.resource, "", client) + s.log("SUBSCRIBE", "START", si.Resource, "", client) - id := ch.AddSubscriber(conn, si.resource) + id := ch.AddSubscriber(conn, si.Resource) ch.pubsub.Subscribe(conn) - s.log("SUBSCRIBE", "STOP", si.resource, "", client) + s.log("SUBSCRIBE", "STOP", si.Resource, "", client) ch.RemoveSubscriber(id) conn.Close() } -func (s *server) findIdentityFromToken(key string) (iam.IdentityVerifier, error) { +func (s *server) findIdentityFromToken(key string) (string, error) { + if len(key) == 0 { + return "$anon", nil + } + var identity iam.IdentityVerifier var err error + var token string + elements := strings.Split(key, ":") if len(elements) == 1 { identity, err = s.iam.GetDefaultIdentity() + token = elements[0] } else { identity, err = s.iam.GetIdentity(elements[0]) + token = elements[1] } if err != nil { - return nil, fmt.Errorf("invalid token: %w", err) + return "$anon", nil } - if ok, err := identity.VerifyServiceToken(elements[1]); !ok { - return nil, fmt.Errorf("invalid token: %w", err) + if ok, err := identity.VerifyServiceToken(token); !ok { + return "$anon", fmt.Errorf("invalid token: %w", err) } - return identity, nil + return identity.Name(), nil } func (s *server) findDomainFromPlaypath(path string) string { elements := strings.Split(path, "/") if len(elements) == 1 { - return "" + return "$none" } domain := elements[0] @@ -686,5 +417,5 @@ func (s *server) findDomainFromPlaypath(path string) string { return domain } - return "" + return "$none" } diff --git a/srt/srt_test.go b/srt/srt_test.go deleted file mode 100644 index 91ae7ed1..00000000 --- a/srt/srt_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package srt - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseStreamId(t *testing.T) { - streamids := map[string]streamInfo{ - "bla": {resource: "bla", mode: "request"}, - "bla,mode:publish": {resource: "bla", mode: "publish"}, - "123456789": {resource: "123456789", mode: "request"}, - "bla,token:foobar": {resource: "bla", token: "foobar", mode: "request"}, - "bla,token:foo,bar": {resource: "bla", token: "foo,bar", mode: "request"}, - "123456789,mode:publish,token:foobar": {resource: "123456789", token: "foobar", mode: "publish"}, - "mode:publish": {resource: "mode:publish", mode: "request"}, - } - - for streamid, wantsi := range streamids { - si, err := parseStreamId(streamid) - - require.NoError(t, err) - require.Equal(t, wantsi, si) - } -} - -func TestParseOldStreamId(t *testing.T) { - streamids := map[string]streamInfo{ - "#!:": {}, - "#!:key=value": {}, - "#!:m=publish": {mode: "publish"}, - "#!:r=123456789": {resource: "123456789"}, - "#!:token=foobar": {token: "foobar"}, - "#!:token=foo,bar": {token: "foo"}, - "#!:m=publish,r=123456789,token=foobar": {mode: "publish", resource: "123456789", token: "foobar"}, - } - - for streamid, wantsi := range streamids { - si, _ := parseOldStreamId(streamid) - - require.Equal(t, wantsi, si) - } -} diff --git a/srt/url/url.go b/srt/url/url.go new file mode 100644 index 00000000..0ec890b1 --- /dev/null +++ b/srt/url/url.go @@ -0,0 +1,206 @@ +package url + +import ( + "fmt" + neturl "net/url" + "regexp" + "strings" +) + +type URL struct { + Scheme string + Host string + StreamId string + Options neturl.Values +} + +func Parse(srturl string) (*URL, error) { + u, err := neturl.Parse(srturl) + if err != nil { + return nil, err + } + + if u.Scheme != "srt" { + return nil, fmt.Errorf("invalid SRT url") + } + + options := u.Query() + streamid := options.Get("streamid") + options.Del("streamid") + + su := &URL{ + Scheme: "srt", + Host: u.Host, + StreamId: streamid, + Options: options, + } + + return su, nil +} + +func (su *URL) String() string { + options, _ := neturl.ParseQuery(su.Options.Encode()) + options.Set("streamid", su.StreamId) + + u := neturl.URL{ + Scheme: su.Scheme, + Host: su.Host, + RawQuery: options.Encode(), + } + + return u.String() +} + +func (su *URL) StreamInfo() (*StreamInfo, error) { + s, err := ParseStreamId(su.StreamId) + if err != nil { + return nil, err + } + + return &s, nil +} + +func (su *URL) SetStreamInfo(si *StreamInfo) { + su.StreamId = si.String() +} + +func (su *URL) Hostname() string { + u := neturl.URL{ + Host: su.Host, + } + + return u.Hostname() +} + +func (su *URL) Port() string { + u := neturl.URL{ + Host: su.Host, + } + + return u.Port() +} + +type StreamInfo struct { + Mode string + Resource string + Token string +} + +func (si *StreamInfo) String() string { + streamid := si.Resource + + if si.Mode != "request" { + streamid += ",mode:" + si.Mode + } + + if len(si.Token) != 0 { + streamid += ",token:" + si.Token + } + + return streamid +} + +// ParseStreamId parses a streamid. If the streamid is in the old format +// it is detected and parsed accordingly. Otherwith the new simplified +// format will be assumed. +// +// resource[,token:{token}]?[,mode:(publish|*request)]? +// +// If the mode is not provided, "request" will be assumed. +func ParseStreamId(streamid string) (StreamInfo, error) { + si := StreamInfo{} + + if strings.HasPrefix(streamid, "#!:") { + return ParseDeprecatedStreamId(streamid) + } + + re := regexp.MustCompile(`,(token|mode):(.+)`) + + results := map[string]string{} + + idEnd := -1 + value := streamid + key := "" + + for { + matches := re.FindStringSubmatchIndex(value) + if matches == nil { + break + } + + if idEnd < 0 { + idEnd = matches[2] - 1 + } + + if len(key) != 0 { + results[key] = value[:matches[2]-1] + } + + key = value[matches[2]:matches[3]] + value = value[matches[4]:matches[5]] + + results[key] = value + } + + if idEnd < 0 { + idEnd = len(streamid) + } + + si.Resource = streamid[:idEnd] + if token, ok := results["token"]; ok { + si.Token = token + } + + if mode, ok := results["mode"]; ok { + si.Mode = mode + } else { + si.Mode = "request" + } + + return si, nil +} + +// ParseDeprecatedStreamId parses a streamid in the old format. The old format +// is based on the recommendation of the SRT specs, but with the special +// character it contains it can cause some trouble in clients (e.g. kiloview +// doesn't like the = character). +func ParseDeprecatedStreamId(streamid string) (StreamInfo, error) { + si := StreamInfo{Mode: "request"} + + if !strings.HasPrefix(streamid, "#!:") { + return si, fmt.Errorf("unknown streamid format") + } + + streamid = strings.TrimPrefix(streamid, "#!:") + + kvs := strings.Split(streamid, ",") + + split := func(s, sep string) (string, string, error) { + splitted := strings.SplitN(s, sep, 2) + + if len(splitted) != 2 { + return "", "", fmt.Errorf("invalid key/value pair") + } + + return splitted[0], splitted[1], nil + } + + for _, kv := range kvs { + key, value, err := split(kv, "=") + if err != nil { + continue + } + + switch key { + case "m": + si.Mode = value + case "r": + si.Resource = value + case "token": + si.Token = value + default: + } + } + + return si, nil +} diff --git a/srt/url/url_test.go b/srt/url/url_test.go new file mode 100644 index 00000000..4118b638 --- /dev/null +++ b/srt/url/url_test.go @@ -0,0 +1,67 @@ +package url + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + srturl := "srt://127.0.0.1:6000?mode=caller&passphrase=foobar&streamid=" + url.QueryEscape("#!:m=publish,r=123456,token=bla") + + u, err := Parse(srturl) + require.NoError(t, err) + + require.Equal(t, "srt", u.Scheme) + require.Equal(t, "127.0.0.1:6000", u.Host) + require.Equal(t, "#!:m=publish,r=123456,token=bla", u.StreamId) + + si, err := u.StreamInfo() + require.NoError(t, err) + require.Equal(t, "publish", si.Mode) + require.Equal(t, "123456", si.Resource) + require.Equal(t, "bla", si.Token) + + require.Equal(t, srturl, u.String()) + + srturl = "srt://127.0.0.1:6000?mode=caller&passphrase=foobar&streamid=" + url.QueryEscape("123456,mode:publish,token:bla") + + u, err = Parse(srturl) + require.NoError(t, err) + + require.Equal(t, "srt", u.Scheme) + require.Equal(t, "127.0.0.1:6000", u.Host) + require.Equal(t, "123456,mode:publish,token:bla", u.StreamId) + + si, err = u.StreamInfo() + require.NoError(t, err) + require.Equal(t, "publish", si.Mode) + require.Equal(t, "123456", si.Resource) + require.Equal(t, "bla", si.Token) + + require.Equal(t, srturl, u.String()) +} + +func TestParseStreamId(t *testing.T) { + streamids := map[string]StreamInfo{ + "": {Mode: "request"}, + "bla": {Mode: "request", Resource: "bla"}, + "bla,token=foobar": {Mode: "request", Resource: "bla,token=foobar"}, + "bla,token:foobar": {Mode: "request", Resource: "bla", Token: "foobar"}, + "bla,token:foobar,mode:publish": {Mode: "publish", Resource: "bla", Token: "foobar"}, + "#!:": {Mode: "request"}, + "#!:key=value": {Mode: "request"}, + "#!:m=publish": {Mode: "publish"}, + "#!:r=123456789": {Mode: "request", Resource: "123456789"}, + "#!:token=foobar": {Mode: "request", Token: "foobar"}, + "#!:token=foo,bar": {Mode: "request", Token: "foo"}, + "#!:m=publish,r=123456789,token=foobar": {Mode: "publish", Resource: "123456789", Token: "foobar"}, + } + + for streamid, wantsi := range streamids { + si, err := ParseStreamId(streamid) + require.NoError(t, err) + require.Equal(t, wantsi, si, streamid) + } +}