diff --git a/app/api/api.go b/app/api/api.go index 8fc0cfa6..09d09f9d 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -25,8 +25,8 @@ import ( "github.com/datarhei/core/v16/http" "github.com/datarhei/core/v16/http/cache" httpfs "github.com/datarhei/core/v16/http/fs" - "github.com/datarhei/core/v16/http/jwt" "github.com/datarhei/core/v16/http/router" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/math/rand" @@ -36,7 +36,9 @@ import ( "github.com/datarhei/core/v16/restream" restreamapp "github.com/datarhei/core/v16/restream/app" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" restreamstore "github.com/datarhei/core/v16/restream/store" + restreamjsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/datarhei/core/v16/rtmp" "github.com/datarhei/core/v16/service" "github.com/datarhei/core/v16/session" @@ -83,10 +85,10 @@ type api struct { cache cache.Cacher mainserver *gohttp.Server sidecarserver *gohttp.Server - httpjwt jwt.JWT update update.Checker replacer replace.Replacer cluster cluster.Cluster + iam iam.IAM errorChan chan error @@ -387,6 +389,166 @@ func (a *api) start() error { a.sessions = sessions } + { + superuser := iam.User{ + Name: cfg.API.Auth.Username, + Superuser: true, + Auth: iam.UserAuth{ + API: iam.UserAuthAPI{ + Auth0: iam.UserAuthAPIAuth0{}, + }, + Services: iam.UserAuthServices{ + Token: []string{ + cfg.RTMP.Token, + cfg.SRT.Token, + }, + }, + }, + } + + if cfg.API.Auth.Enable { + superuser.Auth.API.Password = cfg.API.Auth.Password + } + + if cfg.API.Auth.Auth0.Enable { + superuser.Auth.API.Auth0.User = cfg.API.Auth.Auth0.Tenants[0].Users[0] + superuser.Auth.API.Auth0.Tenant = iam.Auth0Tenant{ + Domain: cfg.API.Auth.Auth0.Tenants[0].Domain, + Audience: cfg.API.Auth.Auth0.Tenants[0].Audience, + ClientID: cfg.API.Auth.Auth0.Tenants[0].ClientID, + } + } + + fs, err := 
fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{ + Root: filepath.Join(cfg.DB.Dir, "iam"), + }) + if err != nil { + return err + } + + secret := rand.String(32) + if len(cfg.API.Auth.JWT.Secret) != 0 { + secret = cfg.API.Auth.Username + cfg.API.Auth.Password + cfg.API.Auth.JWT.Secret + } + + manager, err := iam.NewIAM(iam.Config{ + FS: fs, + Superuser: superuser, + JWTRealm: "datarhei-core", + JWTSecret: secret, + Logger: a.log.logger.core.WithComponent("IAM"), + }) + if err != nil { + return fmt.Errorf("iam: %w", err) + } + + // Check if there are already file created by IAM. If not, create policies + // and users based on the config in order to mimic the behaviour before IAM. + if len(fs.List("/", "/*.json")) == 0 { + policies := []iam.Policy{ + { + Name: "$anon", + Domain: "$none", + Resource: "fs:/**", + Actions: []string{"GET", "HEAD", "OPTIONS"}, + }, + { + Name: "$anon", + Domain: "$none", + Resource: "api:/api", + Actions: []string{"GET", "HEAD", "OPTIONS"}, + }, + { + Name: "$anon", + Domain: "$none", + Resource: "api:/api/v3/widget/process/**", + Actions: []string{"GET", "HEAD", "OPTIONS"}, + }, + } + + users := map[string]iam.User{} + + if cfg.Storage.Memory.Auth.Enable && cfg.Storage.Memory.Auth.Username != superuser.Name { + users[cfg.Storage.Memory.Auth.Username] = iam.User{ + Name: cfg.Storage.Memory.Auth.Username, + Auth: iam.UserAuth{ + Services: iam.UserAuthServices{ + Basic: []string{cfg.Storage.Memory.Auth.Password}, + }, + }, + } + + policies = append(policies, iam.Policy{ + Name: cfg.Storage.Memory.Auth.Username, + Domain: "$none", + Resource: "fs:/memfs/**", + Actions: []string{"ANY"}, + }) + } + + for _, s := range cfg.Storage.S3 { + if s.Auth.Enable && s.Auth.Username != superuser.Name { + user, ok := users[s.Auth.Username] + if !ok { + users[s.Auth.Username] = iam.User{ + Name: s.Auth.Username, + Auth: iam.UserAuth{ + Services: iam.UserAuthServices{ + Basic: []string{s.Auth.Password}, + }, + }, + } + } else { + user.Auth.Services.Basic = 
append(user.Auth.Services.Basic, s.Auth.Password) + users[s.Auth.Username] = user + } + + policies = append(policies, iam.Policy{ + Name: s.Auth.Username, + Domain: "$none", + Resource: "fs:" + s.Mountpoint + "/**", + Actions: []string{"ANY"}, + }) + } + } + + if cfg.RTMP.Enable && len(cfg.RTMP.Token) == 0 { + policies = append(policies, iam.Policy{ + Name: "$anon", + Domain: "$none", + Resource: "rtmp:/**", + Actions: []string{"ANY"}, + }) + } + + if cfg.SRT.Enable && len(cfg.SRT.Token) == 0 { + policies = append(policies, iam.Policy{ + Name: "$anon", + Domain: "$none", + Resource: "srt:**", + Actions: []string{"ANY"}, + }) + } + + for _, user := range users { + if _, err := manager.GetIdentity(user.Name); err == nil { + continue + } + + err := manager.CreateIdentity(user) + if err != nil { + return fmt.Errorf("iam: %w", err) + } + } + + for _, policy := range policies { + manager.AddPolicy(policy.Name, policy.Domain, policy.Resource, policy.Actions) + } + } + + a.iam = manager + } + diskfs, err := fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{ Root: cfg.Storage.Disk.Dir, Logger: a.log.logger.core.WithComponent("DiskFS"), @@ -511,6 +673,35 @@ func (a *api) start() error { a.ffmpeg = ffmpeg + var rw rewrite.Rewriter + + { + baseAddress := func(address string) string { + var base string + host, port, _ := gonet.SplitHostPort(address) + if len(host) == 0 { + base = "localhost:" + port + } else { + base = address + } + + return base + } + + httpBase := baseAddress(cfg.Address) + rtmpBase := baseAddress(cfg.RTMP.Address) + cfg.RTMP.App + srtBase := baseAddress(cfg.SRT.Address) + + rw, err = rewrite.New(rewrite.Config{ + HTTPBase: "http://" + httpBase, + RTMPBase: "rtmp://" + rtmpBase, + SRTBase: "srt://" + srtBase, + }) + if err != nil { + return fmt.Errorf("unable to create url rewriter: %w", err) + } + } + a.replacer = replace.New() { @@ -546,8 +737,16 @@ func (a *api) start() error { } template += "/{name}" - if len(cfg.RTMP.Token) != 0 { - template += "?token=" + 
cfg.RTMP.Token + var identity iam.IdentityVerifier = nil + + if len(config.Owner) == 0 { + identity = a.iam.GetDefaultVerifier() + } else { + identity, _ = a.iam.GetVerifier(config.Owner) + } + + if identity != nil { + template += "/" + identity.GetServiceToken() } return template @@ -562,14 +761,22 @@ func (a *api) start() error { template := "srt://" + host + ":" + port + "?mode=caller&transtype=live&latency={latency}&streamid={name}" if section == "output" { template += ",mode:publish" + } + + var identity iam.IdentityVerifier = nil + + if len(config.Owner) == 0 { + identity = a.iam.GetDefaultVerifier() } else { - template += ",mode:request" + identity, _ = a.iam.GetVerifier(config.Owner) } - if len(cfg.SRT.Token) != 0 { - template += ",token:" + cfg.SRT.Token + + if identity != nil { + template += ",token:" + identity.GetServiceToken() } + if len(cfg.SRT.Passphrase) != 0 { - template += "&passphrase=" + cfg.SRT.Passphrase + template += "&passphrase=" + url.QueryEscape(cfg.SRT.Passphrase) } return template @@ -596,7 +803,7 @@ func (a *api) start() error { if err != nil { return err } - store, err = restreamstore.NewJSON(restreamstore.JSONConfig{ + store, err = restreamjsonstore.New(restreamjsonstore.Config{ Filesystem: fs, Filepath: "/db.json", Logger: a.log.logger.core.WithComponent("ProcessStore"), @@ -612,8 +819,10 @@ func (a *api) start() error { Store: store, Filesystems: filesystems, Replace: a.replacer, + Rewrite: rw, FFmpeg: a.ffmpeg, MaxProcesses: cfg.FFmpeg.MaxProcesses, + IAM: a.iam, Logger: a.log.logger.core.WithComponent("Process"), }) @@ -685,48 +894,6 @@ func (a *api) start() error { a.cluster = cluster } - var httpjwt jwt.JWT - - if cfg.API.Auth.Enable { - secret := rand.String(32) - if len(cfg.API.Auth.JWT.Secret) != 0 { - secret = cfg.API.Auth.Username + cfg.API.Auth.Password + cfg.API.Auth.JWT.Secret - } - - var err error - httpjwt, err = jwt.New(jwt.Config{ - Realm: app.Name, - Secret: secret, - SkipLocalhost: cfg.API.Auth.DisableLocalhost, - 
}) - - if err != nil { - return fmt.Errorf("unable to create JWT provider: %w", err) - } - - if validator, err := jwt.NewLocalValidator(cfg.API.Auth.Username, cfg.API.Auth.Password); err == nil { - if err := httpjwt.AddValidator(app.Name, validator); err != nil { - return fmt.Errorf("unable to add local JWT validator: %w", err) - } - } else { - return fmt.Errorf("unable to create local JWT validator: %w", err) - } - - if cfg.API.Auth.Auth0.Enable { - for _, t := range cfg.API.Auth.Auth0.Tenants { - if validator, err := jwt.NewAuth0Validator(t.Domain, t.Audience, t.ClientID, t.Users); err == nil { - if err := httpjwt.AddValidator("https://"+t.Domain+"/", validator); err != nil { - return fmt.Errorf("unable to add Auth0 JWT validator: %w", err) - } - } else { - return fmt.Errorf("unable to create Auth0 JWT validator: %w", err) - } - } - } - } - - a.httpjwt = httpjwt - metrics, err := monitor.NewHistory(monitor.HistoryConfig{ Enable: cfg.Metrics.Enable, Timerange: time.Duration(cfg.Metrics.Range) * time.Second, @@ -948,6 +1115,7 @@ func (a *api) start() error { Token: cfg.RTMP.Token, Logger: a.log.logger.rtmp, Collector: a.sessions.Collector("rtmp"), + IAM: a.iam, } if a.cluster != nil { @@ -980,6 +1148,7 @@ func (a *api) start() error { Token: cfg.SRT.Token, Logger: a.log.logger.core.WithComponent("SRT").WithField("address", cfg.SRT.Address), Collector: a.sessions.Collector("srt"), + IAM: a.iam, } if a.cluster != nil { @@ -1089,12 +1258,28 @@ func (a *api) start() error { }, RTMP: a.rtmpserver, SRT: a.srtserver, - JWT: a.httpjwt, Config: a.config.store, Sessions: a.sessions, Router: router, ReadOnly: cfg.API.ReadOnly, Cluster: a.cluster, + IAM: a.iam, + IAMSkipper: func(ip string) bool { + if !cfg.API.Auth.Enable { + return true + } else { + isLocalhost := false + if ip == "127.0.0.1" || ip == "::1" { + isLocalhost = true + } + + if isLocalhost && cfg.API.Auth.DisableLocalhost { + return true + } + } + + return false + }, } mainserverhandler, err := 
http.NewServer(serverConfig) @@ -1379,9 +1564,8 @@ func (a *api) stop() { a.cluster.Shutdown() } - // Stop JWT authentication - if a.httpjwt != nil { - a.httpjwt.ClearValidators() + if a.iam != nil { + a.iam.Close() } if a.update != nil { diff --git a/app/ffmigrate/main.go b/app/ffmigrate/main.go index 5f3b4996..4d959b6f 100644 --- a/app/ffmigrate/main.go +++ b/app/ffmigrate/main.go @@ -11,7 +11,7 @@ import ( "github.com/datarhei/core/v16/io/file" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/log" - "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/Masterminds/semver/v3" _ "github.com/joho/godotenv/autoload" @@ -120,7 +120,7 @@ func doMigration(logger log.Logger, configstore cfgstore.Store) error { logger.Info().WithField("backup", backupFilepath).Log("Backup created") // Load the existing DB - datastore, err := store.NewJSON(store.JSONConfig{ + datastore, err := jsonstore.New(jsonstore.Config{ Filepath: cfg.DB.Dir + "/db.json", }) if err != nil { @@ -135,31 +135,33 @@ func doMigration(logger log.Logger, configstore cfgstore.Store) error { logger.Info().Log("Migrating processes ...") - // Migrate the processes to version 5 + // Migrate the processes to FFmpeg version 5 // Only this happens: // - for RTSP inputs, replace -stimeout with -timeout reRTSP := regexp.MustCompile(`^rtsps?://`) - for id, p := range data.Process { - logger.Info().WithField("processid", p.ID).Log("") + for name, domain := range data.Process { + for id, p := range domain { + logger.Info().WithField("processid", p.Process.ID).Log("") - for index, input := range p.Config.Input { - if !reRTSP.MatchString(input.Address) { - continue - } - - for i, o := range input.Options { - if o != "-stimeout" { + for index, input := range p.Process.Config.Input { + if !reRTSP.MatchString(input.Address) { continue } - input.Options[i] = "-timeout" - } + for i, o := range input.Options { + if o != "-stimeout" { + continue 
+ } - p.Config.Input[index] = input + input.Options[i] = "-timeout" + } + + p.Process.Config.Input[index] = input + } + p.Process.Config.FFVersion = version.String() + data.Process[name][id] = p } - p.Config.FFVersion = version.String() - data.Process[id] = p } logger.Info().Log("Migrating processes done") diff --git a/app/import/import.go b/app/import/import.go index 5899c350..e72975ce 100644 --- a/app/import/import.go +++ b/app/import/import.go @@ -17,10 +17,12 @@ import ( "github.com/datarhei/core/v16/encoding/json" "github.com/datarhei/core/v16/ffmpeg" "github.com/datarhei/core/v16/ffmpeg/skills" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/restream" "github.com/datarhei/core/v16/restream/app" "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/google/uuid" ) @@ -496,12 +498,12 @@ type importConfigAudio struct { sampling string } -func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.StoreData, error) { +func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.Data, error) { if len(cfg.id) == 0 { cfg.id = uuid.New().String() } - r := store.NewStoreData() + r := store.NewData() jsondata, err := fs.ReadFile(path) if err != nil { @@ -1187,10 +1189,20 @@ func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.StoreData, config.Output = append(config.Output, output) process.Config = config - r.Process[process.ID] = process - r.Metadata.Process["restreamer-ui:ingest:"+cfg.id] = make(map[string]interface{}) - r.Metadata.Process["restreamer-ui:ingest:"+cfg.id]["restreamer-ui"] = ui + p := store.Process{ + Process: process, + Metadata: map[string]interface{}{}, + } + + if metadata, err := gojson.Marshal(ui); err == nil { + m := map[string]interface{}{} + gojson.Unmarshal(metadata, &m) + p.Metadata["restreamer-ui"] = m + } + + r.Process[""] = map[string]store.Process{} + r.Process[""][process.ID] = p // 
Snapshot @@ -1240,9 +1252,13 @@ func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.StoreData, snapshotConfig.Output = append(snapshotConfig.Output, snapshotOutput) snapshotProcess.Config = snapshotConfig - r.Process[snapshotProcess.ID] = snapshotProcess - r.Metadata.Process["restreamer-ui:ingest:"+cfg.id+"_snapshot"] = nil + p := store.Process{ + Process: snapshotProcess, + Metadata: nil, + } + + r.Process[""][snapshotProcess.ID] = p } // Optional publication @@ -1401,10 +1417,19 @@ func importV1(fs fs.Filesystem, path string, cfg importConfig) (store.StoreData, config.Output = append(config.Output, output) process.Config = config - r.Process[process.ID] = process - r.Metadata.Process[egressId] = make(map[string]interface{}) - r.Metadata.Process[egressId]["restreamer-ui"] = egress + p := store.Process{ + Process: process, + Metadata: map[string]interface{}{}, + } + + if metadata, err := gojson.Marshal(egress); err == nil { + m := map[string]interface{}{} + gojson.Unmarshal(metadata, &m) + p.Metadata["restreamer-ui"] = m + } + + r.Process[""][process.ID] = p } return r, nil @@ -1419,7 +1444,7 @@ func probeInput(binary string, config app.Config) app.Probe { } dummyfs, _ := fs.NewMemFilesystem(fs.MemConfig{}) - store, err := store.NewJSON(store.JSONConfig{ + store, err := jsonstore.New(jsonstore.Config{ Filesystem: dummyfs, Filepath: "/", Logger: nil, @@ -1428,17 +1453,32 @@ func probeInput(binary string, config app.Config) app.Probe { return app.Probe{} } + iam, _ := iam.NewIAM(iam.Config{ + FS: dummyfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + + iam.AddPolicy("$anon", "$none", "process:*", []string{"CREATE", "GET", "DELETE", "PROBE"}) + rs, err := restream.New(restream.Config{ FFmpeg: ffmpeg, Store: store, + IAM: iam, }) if err != nil { return app.Probe{} } rs.AddProcess(&config) - probe := rs.Probe(config.ID) - rs.DeleteProcess(config.ID) + + id := restream.TaskID{ID: config.ID} + probe 
:= rs.Probe(id) + rs.DeleteProcess(id) return probe } diff --git a/app/import/import_test.go b/app/import/import_test.go index 8322c0eb..99957b85 100644 --- a/app/import/import_test.go +++ b/app/import/import_test.go @@ -1,13 +1,11 @@ package main import ( - gojson "encoding/json" "os" "testing" - "github.com/datarhei/core/v16/encoding/json" "github.com/datarhei/core/v16/io/fs" - "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/stretchr/testify/require" ) @@ -42,37 +40,28 @@ func testV1Import(t *testing.T, v1Fixture, v4Fixture string, config importConfig }) require.NoError(t, err) + store, err := jsonstore.New(jsonstore.Config{ + Filesystem: diskfs, + Filepath: v4Fixture, + }) + require.NoError(t, err) + // Import v1 database v4, err := importV1(diskfs, v1Fixture, config) - require.Equal(t, nil, err) + require.NoError(t, err) // Reset variants - for n := range v4.Process { - v4.Process[n].CreatedAt = 0 + for m, domain := range v4.Process { + for n := range domain { + v4.Process[m][n].Process.CreatedAt = 0 + } } - // Convert to JSON - datav4, err := gojson.MarshalIndent(&v4, "", " ") - require.Equal(t, nil, err) - // Read the wanted result - wantdatav4, err := diskfs.ReadFile(v4Fixture) - require.Equal(t, nil, err) + wantv4, err := store.Load() + require.NoError(t, err) - var wantv4 store.StoreData - - err = gojson.Unmarshal(wantdatav4, &wantv4) - require.Equal(t, nil, err, json.FormatError(wantdatav4, err)) - - // Convert to JSON - wantdatav4, err = gojson.MarshalIndent(&wantv4, "", " ") - require.Equal(t, nil, err) - - // Re-convert both to golang type - gojson.Unmarshal(wantdatav4, &wantv4) - gojson.Unmarshal(datav4, &v4) - - require.Equal(t, wantv4, v4) + require.Equal(t, wantv4, v4, v4Fixture) } func TestV1Import(t *testing.T) { diff --git a/app/import/main.go b/app/import/main.go index 2d641caf..71a99341 100644 --- a/app/import/main.go +++ b/app/import/main.go @@ -8,7 +8,7 @@ import ( 
cfgvars "github.com/datarhei/core/v16/config/vars" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/log" - "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" _ "github.com/joho/godotenv/autoload" ) @@ -87,7 +87,7 @@ func doImport(logger log.Logger, fs fs.Filesystem, configstore cfgstore.Store) e logger.Info().Log("Found database") // Load an existing DB - datastore, err := store.NewJSON(store.JSONConfig{ + datastore, err := jsonstore.New(jsonstore.Config{ Filesystem: fs, Filepath: cfg.DB.Dir + "/db.json", }) diff --git a/config/data.go b/config/data.go index 8899e418..a3770929 100644 --- a/config/data.go +++ b/config/data.go @@ -82,10 +82,10 @@ type Data struct { } `json:"disk"` Memory struct { Auth struct { - Enable bool `json:"enable"` - Username string `json:"username"` - Password string `json:"password"` - } `json:"auth"` + Enable bool `json:"enable"` // Deprecated, use IAM + Username string `json:"username"` // Deprecated, use IAM + Password string `json:"password"` // Deprecated, use IAM + } `json:"auth"` // Deprecated, use IAM Size int64 `json:"max_size_mbytes" format:"int64"` Purge bool `json:"purge"` } `json:"memory"` @@ -101,13 +101,13 @@ type Data struct { Address string `json:"address"` AddressTLS string `json:"address_tls"` App string `json:"app"` - Token string `json:"token"` + Token string `json:"token"` // Deprecated, use IAM } `json:"rtmp"` SRT struct { Enable bool `json:"enable"` Address string `json:"address"` Passphrase string `json:"passphrase"` - Token string `json:"token"` + Token string `json:"token"` // Deprecated, use IAM Log struct { Enable bool `json:"enable"` Topics []string `json:"topics"` diff --git a/config/value/s3.go b/config/value/s3.go index 7d80c193..52e5c934 100644 --- a/config/value/s3.go +++ b/config/value/s3.go @@ -13,7 +13,7 @@ import ( type S3Storage struct { Name string `json:"name"` Mountpoint string `json:"mountpoint"` - Auth S3StorageAuth 
`json:"auth"` + Auth S3StorageAuth `json:"auth"` // Deprecated, use IAM Endpoint string `json:"endpoint"` AccessKeyID string `json:"access_key_id"` SecretAccessKey string `json:"secret_access_key"` @@ -23,9 +23,9 @@ type S3Storage struct { } type S3StorageAuth struct { - Enable bool `json:"enable"` - Username string `json:"username"` - Password string `json:"password"` + Enable bool `json:"enable"` // Deprecated, use IAM + Username string `json:"username"` // Deprecated, use IAM + Password string `json:"password"` // Deprecated, use IAM } func (t *S3Storage) String() string { diff --git a/docs/docs.go b/docs/docs.go index 2015a527..7da578b0 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -109,87 +109,6 @@ const docTemplate = `{ } } }, - "/api/login": { - "post": { - "security": [ - { - "Auth0KeyAuth": [] - } - ], - "description": "Retrieve valid JWT access and refresh tokens to use for accessing the API. Login either by username/password or Auth0 token", - "produces": [ - "application/json" - ], - "summary": "Retrieve an access and a refresh token", - "operationId": "jwt-login", - "parameters": [ - { - "description": "Login data", - "name": "data", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/api.Login" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/api.JWT" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/api.Error" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/api.Error" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/api.Error" - } - } - } - } - }, - "/api/login/refresh": { - "get": { - "security": [ - { - "ApiRefreshKeyAuth": [] - } - ], - "description": "Retrieve a new access token by providing the refresh token", - "produces": [ - "application/json" - ], - "summary": "Retrieve a new access token", - "operationId": "jwt-refresh", - 
"responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/api.JWTRefresh" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/api.Error" - } - } - } - } - }, "/api/swagger": { "get": { "description": "Swagger UI for this API", @@ -913,6 +832,308 @@ const docTemplate = `{ } } }, + "/api/v3/iam/user": { + "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Add a new user", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "Add a new user", + "operationId": "iam-3-add-user", + "parameters": [ + { + "description": "User definition", + "name": "config", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, + "/api/v3/iam/user/{name}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "List aa user by its name", + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" 
+ ], + "summary": "List an user by its name", + "operationId": "iam-3-get-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Replace an existing user.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "Replace an existing user", + "operationId": "iam-3-update-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + }, + { + "description": "User definition", + "name": "user", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Delete an user by its name", + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" 
+ ], + "summary": "Delete an user by its name", + "operationId": "iam-3-delete-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "string" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, + "/api/v3/iam/user/{name}/policy": { + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Replace policies of an user", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "Replace policies of an user", + "operationId": "iam-3-update-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + }, + { + "description": "Policy definitions", + "name": "user", + "in": "body", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, "/api/v3/log": { "get": { "security": [ @@ -3246,22 +3467,115 @@ const docTemplate = `{ } } }, - 
"api.JWT": { + "api.IAMAuth0Tenant": { "type": "object", "properties": { - "access_token": { + "audience": { "type": "string" }, - "refresh_token": { + "client_id": { + "type": "string" + }, + "domain": { "type": "string" } } }, - "api.JWTRefresh": { + "api.IAMPolicy": { "type": "object", "properties": { - "access_token": { + "actions": { + "type": "array", + "items": { + "type": "string" + } + }, + "domain": { "type": "string" + }, + "resource": { + "type": "string" + } + } + }, + "api.IAMUser": { + "type": "object", + "properties": { + "auth": { + "$ref": "#/definitions/api.IAMUserAuth" + }, + "name": { + "type": "string" + }, + "policies": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + }, + "superuser": { + "type": "boolean" + } + } + }, + "api.IAMUserAuth": { + "type": "object", + "properties": { + "api": { + "$ref": "#/definitions/api.IAMUserAuthAPI" + }, + "services": { + "$ref": "#/definitions/api.IAMUserAuthServices" + } + } + }, + "api.IAMUserAuthAPI": { + "type": "object", + "properties": { + "auth0": { + "$ref": "#/definitions/api.IAMUserAuthAPIAuth0" + }, + "userpass": { + "$ref": "#/definitions/api.IAMUserAuthPassword" + } + } + }, + "api.IAMUserAuthAPIAuth0": { + "type": "object", + "properties": { + "enable": { + "type": "boolean" + }, + "tenant": { + "$ref": "#/definitions/api.IAMAuth0Tenant" + }, + "user": { + "type": "string" + } + } + }, + "api.IAMUserAuthPassword": { + "type": "object", + "properties": { + "enable": { + "type": "boolean" + }, + "password": { + "type": "string" + } + } + }, + "api.IAMUserAuthServices": { + "type": "object", + "properties": { + "basic": { + "$ref": "#/definitions/api.IAMUserAuthPassword" + }, + "token": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -3269,21 +3583,6 @@ const docTemplate = `{ "type": "object", "additionalProperties": true }, - "api.Login": { - "type": "object", - "required": [ - "password", - "username" - ], - "properties": { - "password": { - 
"type": "string" - }, - "username": { - "type": "string" - } - } - }, "api.MetricsDescription": { "type": "object", "properties": { @@ -3603,6 +3902,9 @@ const docTemplate = `{ "autostart": { "type": "boolean" }, + "group": { + "type": "string" + }, "id": { "type": "string" }, diff --git a/docs/swagger.json b/docs/swagger.json index bfdf9149..77ca7754 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -102,87 +102,6 @@ } } }, - "/api/login": { - "post": { - "security": [ - { - "Auth0KeyAuth": [] - } - ], - "description": "Retrieve valid JWT access and refresh tokens to use for accessing the API. Login either by username/password or Auth0 token", - "produces": [ - "application/json" - ], - "summary": "Retrieve an access and a refresh token", - "operationId": "jwt-login", - "parameters": [ - { - "description": "Login data", - "name": "data", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/api.Login" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/api.JWT" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/api.Error" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/api.Error" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/api.Error" - } - } - } - } - }, - "/api/login/refresh": { - "get": { - "security": [ - { - "ApiRefreshKeyAuth": [] - } - ], - "description": "Retrieve a new access token by providing the refresh token", - "produces": [ - "application/json" - ], - "summary": "Retrieve a new access token", - "operationId": "jwt-refresh", - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/api.JWTRefresh" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/api.Error" - } - } - } - } - }, "/api/swagger": { "get": { "description": "Swagger UI for this API", @@ -906,6 
+825,308 @@ } } }, + "/api/v3/iam/user": { + "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Add a new user", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "Add a new user", + "operationId": "iam-3-add-user", + "parameters": [ + { + "description": "User definition", + "name": "config", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, + "/api/v3/iam/user/{name}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "List aa user by its name", + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "List an user by its name", + "operationId": "iam-3-get-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + }, + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Replace an existing user.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" 
+ ], + "summary": "Replace an existing user", + "operationId": "iam-3-update-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + }, + { + "description": "User definition", + "name": "user", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/api.IAMUser" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Delete an user by its name", + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" + ], + "summary": "Delete an user by its name", + "operationId": "iam-3-delete-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "string" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, + "/api/v3/iam/user/{name}/policy": { + "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Replace policies of an user", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "v16.?.?" 
+ ], + "summary": "Replace policies of an user", + "operationId": "iam-3-update-user", + "parameters": [ + { + "type": "string", + "description": "Username", + "name": "name", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Domain of the acting user", + "name": "domain", + "in": "query" + }, + { + "description": "Policy definitions", + "name": "user", + "in": "body", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/api.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/api.Error" + } + } + } + } + }, "/api/v3/log": { "get": { "security": [ @@ -3239,22 +3460,115 @@ } } }, - "api.JWT": { + "api.IAMAuth0Tenant": { "type": "object", "properties": { - "access_token": { + "audience": { "type": "string" }, - "refresh_token": { + "client_id": { + "type": "string" + }, + "domain": { "type": "string" } } }, - "api.JWTRefresh": { + "api.IAMPolicy": { "type": "object", "properties": { - "access_token": { + "actions": { + "type": "array", + "items": { + "type": "string" + } + }, + "domain": { "type": "string" + }, + "resource": { + "type": "string" + } + } + }, + "api.IAMUser": { + "type": "object", + "properties": { + "auth": { + "$ref": "#/definitions/api.IAMUserAuth" + }, + "name": { + "type": "string" + }, + "policies": { + "type": "array", + "items": { + "$ref": "#/definitions/api.IAMPolicy" + } + }, + "superuser": { + "type": "boolean" + } + } + }, + "api.IAMUserAuth": { + "type": "object", + "properties": { + "api": { + "$ref": "#/definitions/api.IAMUserAuthAPI" + }, + "services": { 
+ "$ref": "#/definitions/api.IAMUserAuthServices" + } + } + }, + "api.IAMUserAuthAPI": { + "type": "object", + "properties": { + "auth0": { + "$ref": "#/definitions/api.IAMUserAuthAPIAuth0" + }, + "userpass": { + "$ref": "#/definitions/api.IAMUserAuthPassword" + } + } + }, + "api.IAMUserAuthAPIAuth0": { + "type": "object", + "properties": { + "enable": { + "type": "boolean" + }, + "tenant": { + "$ref": "#/definitions/api.IAMAuth0Tenant" + }, + "user": { + "type": "string" + } + } + }, + "api.IAMUserAuthPassword": { + "type": "object", + "properties": { + "enable": { + "type": "boolean" + }, + "password": { + "type": "string" + } + } + }, + "api.IAMUserAuthServices": { + "type": "object", + "properties": { + "basic": { + "$ref": "#/definitions/api.IAMUserAuthPassword" + }, + "token": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -3262,21 +3576,6 @@ "type": "object", "additionalProperties": true }, - "api.Login": { - "type": "object", - "required": [ - "password", - "username" - ], - "properties": { - "password": { - "type": "string" - }, - "username": { - "type": "string" - } - } - }, "api.MetricsDescription": { "type": "object", "properties": { @@ -3596,6 +3895,9 @@ "autostart": { "type": "boolean" }, + "group": { + "type": "string" + }, "id": { "type": "string" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index b05033eb..cd42f4ed 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -595,31 +595,81 @@ definitions: items: {} type: array type: object - api.JWT: + api.IAMAuth0Tenant: properties: - access_token: + audience: type: string - refresh_token: + client_id: + type: string + domain: type: string type: object - api.JWTRefresh: + api.IAMPolicy: properties: - access_token: + actions: + items: + type: string + type: array + domain: type: string + resource: + type: string + type: object + api.IAMUser: + properties: + auth: + $ref: '#/definitions/api.IAMUserAuth' + name: + type: string + policies: + items: + $ref: 
'#/definitions/api.IAMPolicy' + type: array + superuser: + type: boolean + type: object + api.IAMUserAuth: + properties: + api: + $ref: '#/definitions/api.IAMUserAuthAPI' + services: + $ref: '#/definitions/api.IAMUserAuthServices' + type: object + api.IAMUserAuthAPI: + properties: + auth0: + $ref: '#/definitions/api.IAMUserAuthAPIAuth0' + userpass: + $ref: '#/definitions/api.IAMUserAuthPassword' + type: object + api.IAMUserAuthAPIAuth0: + properties: + enable: + type: boolean + tenant: + $ref: '#/definitions/api.IAMAuth0Tenant' + user: + type: string + type: object + api.IAMUserAuthPassword: + properties: + enable: + type: boolean + password: + type: string + type: object + api.IAMUserAuthServices: + properties: + basic: + $ref: '#/definitions/api.IAMUserAuthPassword' + token: + items: + type: string + type: array type: object api.LogEvent: additionalProperties: true type: object - api.Login: - properties: - password: - type: string - username: - type: string - required: - - password - - username - type: object api.MetricsDescription: properties: description: @@ -835,6 +885,8 @@ definitions: properties: autostart: type: boolean + group: + type: string id: type: string input: @@ -2144,58 +2196,6 @@ paths: security: - ApiKeyAuth: [] summary: Query the GraphAPI - /api/login: - post: - description: Retrieve valid JWT access and refresh tokens to use for accessing - the API. 
Login either by username/password or Auth0 token - operationId: jwt-login - parameters: - - description: Login data - in: body - name: data - required: true - schema: - $ref: '#/definitions/api.Login' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/api.JWT' - "400": - description: Bad Request - schema: - $ref: '#/definitions/api.Error' - "403": - description: Forbidden - schema: - $ref: '#/definitions/api.Error' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/api.Error' - security: - - Auth0KeyAuth: [] - summary: Retrieve an access and a refresh token - /api/login/refresh: - get: - description: Retrieve a new access token by providing the refresh token - operationId: jwt-refresh - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/api.JWTRefresh' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/api.Error' - security: - - ApiRefreshKeyAuth: [] - summary: Retrieve a new access token /api/swagger: get: description: Swagger UI for this API @@ -2661,6 +2661,201 @@ paths: summary: Add a file to a filesystem tags: - v16.7.2 + /api/v3/iam/user: + post: + consumes: + - application/json + description: Add a new user + operationId: iam-3-add-user + parameters: + - description: User definition + in: body + name: config + required: true + schema: + $ref: '#/definitions/api.IAMUser' + - description: Domain of the acting user + in: query + name: domain + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/api.IAMUser' + "400": + description: Bad Request + schema: + $ref: '#/definitions/api.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/api.Error' + security: + - ApiKeyAuth: [] + summary: Add a new user + tags: + - v16.?.? 
+ /api/v3/iam/user/{name}: + delete: + description: Delete an user by its name + operationId: iam-3-delete-user + parameters: + - description: Username + in: path + name: name + required: true + type: string + - description: Domain of the acting user + in: query + name: domain + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + type: string + "404": + description: Not Found + schema: + $ref: '#/definitions/api.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/api.Error' + security: + - ApiKeyAuth: [] + summary: Delete an user by its name + tags: + - v16.?.? + get: + description: List aa user by its name + operationId: iam-3-get-user + parameters: + - description: Username + in: path + name: name + required: true + type: string + - description: Domain of the acting user + in: query + name: domain + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/api.IAMUser' + "404": + description: Not Found + schema: + $ref: '#/definitions/api.Error' + security: + - ApiKeyAuth: [] + summary: List an user by its name + tags: + - v16.?.? + put: + consumes: + - application/json + description: Replace an existing user. 
+ operationId: iam-3-update-user + parameters: + - description: Username + in: path + name: name + required: true + type: string + - description: Domain of the acting user + in: query + name: domain + type: string + - description: User definition + in: body + name: user + required: true + schema: + $ref: '#/definitions/api.IAMUser' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/api.IAMUser' + "400": + description: Bad Request + schema: + $ref: '#/definitions/api.Error' + "404": + description: Not Found + schema: + $ref: '#/definitions/api.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/api.Error' + security: + - ApiKeyAuth: [] + summary: Replace an existing user + tags: + - v16.?.? + /api/v3/iam/user/{name}/policy: + put: + consumes: + - application/json + description: Replace policies of an user + operationId: iam-3-update-user + parameters: + - description: Username + in: path + name: name + required: true + type: string + - description: Domain of the acting user + in: query + name: domain + type: string + - description: Policy definitions + in: body + name: user + required: true + schema: + items: + $ref: '#/definitions/api.IAMPolicy' + type: array + produces: + - application/json + responses: + "200": + description: OK + schema: + items: + $ref: '#/definitions/api.IAMPolicy' + type: array + "400": + description: Bad Request + schema: + $ref: '#/definitions/api.Error' + "404": + description: Not Found + schema: + $ref: '#/definitions/api.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/api.Error' + security: + - ApiKeyAuth: [] + summary: Replace policies of an user + tags: + - v16.?.? 
/api/v3/log: get: description: Get the last log lines of the Restreamer application diff --git a/go.mod b/go.mod index ec1c4e54..48dbf9ec 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/Masterminds/semver/v3 v3.2.1 github.com/atrox/haikunatorgo/v2 v2.0.1 github.com/caddyserver/certmagic v0.17.2 + github.com/casbin/casbin/v2 v2.60.0 github.com/datarhei/core-client-go/v16 v16.11.1-0.20230512155342-18a7ac72df3a github.com/datarhei/gosrt v0.3.1 github.com/datarhei/joy4 v0.0.0-20230505074825-fde05957445a @@ -39,6 +40,7 @@ require ( //replace github.com/datarhei/core-client-go/v16 => ../core-client-go require ( + github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect github.com/armon/go-metrics v0.4.1 // indirect diff --git a/go.sum b/go.sum index 793433aa..77e59efd 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= @@ -40,6 +42,8 @@ github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx2 
github.com/caddyserver/certmagic v0.17.2 h1:o30seC1T/dBqBCNNGNHWwj2i5/I/FMjBbTAhjADP3nE= github.com/caddyserver/certmagic v0.17.2/go.mod h1:ouWUuC490GOLJzkyN35eXfV8bSbwMwSf4bdhkIxtdQE= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/casbin/casbin/v2 v2.60.0 h1:ZmC0/t4wolfEsDpDxTEsu2z6dfbMNpc11F52ceLs2Eo= +github.com/casbin/casbin/v2 v2.60.0/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= @@ -100,6 +104,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -459,6 +465,7 @@ golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 
+golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= diff --git a/http/api/iam.go b/http/api/iam.go new file mode 100644 index 00000000..aea4f50b --- /dev/null +++ b/http/api/iam.go @@ -0,0 +1,109 @@ +package api + +import "github.com/datarhei/core/v16/iam" + +type IAMUser struct { + Name string `json:"name"` + Superuser bool `json:"superuser"` + Auth IAMUserAuth `json:"auth"` + Policies []IAMPolicy `json:"policies"` +} + +func (u *IAMUser) Marshal(user iam.User, policies []iam.Policy) { + u.Name = user.Name + u.Superuser = user.Superuser + u.Auth = IAMUserAuth{ + API: IAMUserAuthAPI{ + Password: user.Auth.API.Password, + Auth0: IAMUserAuthAPIAuth0{ + User: user.Auth.API.Auth0.User, + Tenant: IAMAuth0Tenant{ + Domain: user.Auth.API.Auth0.Tenant.Domain, + Audience: user.Auth.API.Auth0.Tenant.Audience, + ClientID: user.Auth.API.Auth0.Tenant.ClientID, + }, + }, + }, + Services: IAMUserAuthServices{ + Basic: user.Auth.Services.Basic, + Token: user.Auth.Services.Token, + }, + } + + for _, p := range policies { + u.Policies = append(u.Policies, IAMPolicy{ + Domain: p.Domain, + Resource: p.Resource, + Actions: p.Actions, + }) + } +} + +func (u *IAMUser) Unmarshal() (iam.User, []iam.Policy) { + iamuser := iam.User{ + Name: u.Name, + Superuser: u.Superuser, + Auth: iam.UserAuth{ + API: iam.UserAuthAPI{ + Password: u.Auth.API.Password, + Auth0: iam.UserAuthAPIAuth0{ + User: u.Auth.API.Auth0.User, + Tenant: iam.Auth0Tenant{ + Domain: u.Auth.API.Auth0.Tenant.Domain, + Audience: u.Auth.API.Auth0.Tenant.Audience, + ClientID: u.Auth.API.Auth0.Tenant.ClientID, + }, + }, + }, + Services: iam.UserAuthServices{ + Basic: 
u.Auth.Services.Basic, + Token: u.Auth.Services.Token, + }, + }, + } + + iampolicies := []iam.Policy{} + + for _, p := range u.Policies { + iampolicies = append(iampolicies, iam.Policy{ + Name: u.Name, + Domain: p.Domain, + Resource: p.Resource, + Actions: p.Actions, + }) + } + + return iamuser, iampolicies +} + +type IAMUserAuth struct { + API IAMUserAuthAPI `json:"api"` + Services IAMUserAuthServices `json:"services"` +} + +type IAMUserAuthAPI struct { + Password string `json:"userpass"` + Auth0 IAMUserAuthAPIAuth0 `json:"auth0"` +} + +type IAMUserAuthAPIAuth0 struct { + User string `json:"user"` + Tenant IAMAuth0Tenant `json:"tenant"` +} + +type IAMUserAuthServices struct { + Basic []string `json:"basic"` + Token []string `json:"token"` +} + +type IAMAuth0Tenant struct { + Domain string `json:"domain"` + Audience string `json:"audience"` + ClientID string `json:"client_id"` +} + +type IAMPolicy struct { + Domain string `json:"domain"` + Resource string `json:"resource"` + Actions []string `json:"actions"` +} diff --git a/http/api/process.go b/http/api/process.go index 98472459..4d7bcd2d 100644 --- a/http/api/process.go +++ b/http/api/process.go @@ -45,6 +45,7 @@ type ProcessConfigLimits struct { // ProcessConfig represents the configuration of an ffmpeg process type ProcessConfig struct { ID string `json:"id"` + Group string `json:"group"` Type string `json:"type" validate:"oneof='ffmpeg' ''" jsonschema:"enum=ffmpeg,enum="` Reference string `json:"reference"` Input []ProcessConfigIO `json:"input" validate:"required"` @@ -61,6 +62,7 @@ type ProcessConfig struct { func (cfg *ProcessConfig) Marshal() *app.Config { p := &app.Config{ ID: cfg.ID, + Domain: cfg.Group, Reference: cfg.Reference, Options: cfg.Options, Reconnect: cfg.Reconnect, @@ -140,6 +142,7 @@ func (cfg *ProcessConfig) Unmarshal(c *app.Config) { } cfg.ID = c.ID + cfg.Group = c.Domain cfg.Reference = c.Reference cfg.Type = "ffmpeg" cfg.Reconnect = c.Reconnect diff --git a/http/graph/graph/graph.go 
b/http/graph/graph/graph.go index 332d2fdf..57d0564e 100644 --- a/http/graph/graph/graph.go +++ b/http/graph/graph/graph.go @@ -127,8 +127,10 @@ type ComplexityRoot struct { Process struct { Config func(childComplexity int) int CreatedAt func(childComplexity int) int + Domain func(childComplexity int) int ID func(childComplexity int) int Metadata func(childComplexity int) int + Owner func(childComplexity int) int Reference func(childComplexity int) int Report func(childComplexity int) int State func(childComplexity int) int @@ -137,11 +139,13 @@ type ComplexityRoot struct { ProcessConfig struct { Autostart func(childComplexity int) int + Domain func(childComplexity int) int ID func(childComplexity int) int Input func(childComplexity int) int Limits func(childComplexity int) int Options func(childComplexity int) int Output func(childComplexity int) int + Owner func(childComplexity int) int Reconnect func(childComplexity int) int ReconnectDelaySeconds func(childComplexity int) int Reference func(childComplexity int) int @@ -236,10 +240,10 @@ type ComplexityRoot struct { Log func(childComplexity int) int Metrics func(childComplexity int, query models.MetricsInput) int Ping func(childComplexity int) int - PlayoutStatus func(childComplexity int, id string, input string) int - Probe func(childComplexity int, id string) int - Process func(childComplexity int, id string) int - Processes func(childComplexity int) int + PlayoutStatus func(childComplexity int, id string, domain string, input string) int + Probe func(childComplexity int, id string, domain string) int + Process func(childComplexity int, id string, domain string) int + Processes func(childComplexity int, idpattern *string, refpattern *string, domainpattern *string) int } RawAVstream struct { @@ -283,10 +287,10 @@ type QueryResolver interface { About(ctx context.Context) (*models.About, error) Log(ctx context.Context) ([]string, error) Metrics(ctx context.Context, query models.MetricsInput) (*models.Metrics, 
error) - PlayoutStatus(ctx context.Context, id string, input string) (*models.RawAVstream, error) - Processes(ctx context.Context) ([]*models.Process, error) - Process(ctx context.Context, id string) (*models.Process, error) - Probe(ctx context.Context, id string) (*models.Probe, error) + PlayoutStatus(ctx context.Context, id string, domain string, input string) (*models.RawAVstream, error) + Processes(ctx context.Context, idpattern *string, refpattern *string, domainpattern *string) ([]*models.Process, error) + Process(ctx context.Context, id string, domain string) (*models.Process, error) + Probe(ctx context.Context, id string, domain string) (*models.Probe, error) } type executableSchema struct { @@ -675,6 +679,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Process.CreatedAt(childComplexity), true + case "Process.domain": + if e.complexity.Process.Domain == nil { + break + } + + return e.complexity.Process.Domain(childComplexity), true + case "Process.id": if e.complexity.Process.ID == nil { break @@ -689,6 +700,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Process.Metadata(childComplexity), true + case "Process.owner": + if e.complexity.Process.Owner == nil { + break + } + + return e.complexity.Process.Owner(childComplexity), true + case "Process.reference": if e.complexity.Process.Reference == nil { break @@ -724,6 +742,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ProcessConfig.Autostart(childComplexity), true + case "ProcessConfig.domain": + if e.complexity.ProcessConfig.Domain == nil { + break + } + + return e.complexity.ProcessConfig.Domain(childComplexity), true + case "ProcessConfig.id": if e.complexity.ProcessConfig.ID == nil { break @@ -759,6 +784,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 
e.complexity.ProcessConfig.Output(childComplexity), true + case "ProcessConfig.owner": + if e.complexity.ProcessConfig.Owner == nil { + break + } + + return e.complexity.ProcessConfig.Owner(childComplexity), true + case "ProcessConfig.reconnect": if e.complexity.ProcessConfig.Reconnect == nil { break @@ -1243,7 +1275,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.PlayoutStatus(childComplexity, args["id"].(string), args["input"].(string)), true + return e.complexity.Query.PlayoutStatus(childComplexity, args["id"].(string), args["domain"].(string), args["input"].(string)), true case "Query.probe": if e.complexity.Query.Probe == nil { @@ -1255,7 +1287,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Probe(childComplexity, args["id"].(string)), true + return e.complexity.Query.Probe(childComplexity, args["id"].(string), args["domain"].(string)), true case "Query.process": if e.complexity.Query.Process == nil { @@ -1267,14 +1299,19 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Process(childComplexity, args["id"].(string)), true + return e.complexity.Query.Process(childComplexity, args["id"].(string), args["domain"].(string)), true case "Query.processes": if e.complexity.Query.Processes == nil { break } - return e.complexity.Query.Processes(childComplexity), true + args, err := ec.field_Query_processes_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Processes(childComplexity, args["idpattern"].(*string), args["refpattern"].(*string), args["domainpattern"].(*string)), true case "RawAVstream.aqueue": if e.complexity.RawAVstream.Aqueue == nil { @@ -1561,7 +1598,7 @@ type Metric { } `, BuiltIn: false}, {Name: "../playout.graphqls", Input: `extend type Query { - 
playoutStatus(id: ID!, input: ID!): RawAVstream + playoutStatus(id: ID!, domain: String!, input: ID!): RawAVstream } type RawAVstreamIO { @@ -1597,9 +1634,13 @@ type RawAVstream { } `, BuiltIn: false}, {Name: "../process.graphqls", Input: `extend type Query { - processes: [Process!]! - process(id: ID!): Process - probe(id: ID!): Probe! + processes( + idpattern: String + refpattern: String + domainpattern: String + ): [Process!]! + process(id: ID!, domain: String!): Process + probe(id: ID!, domain: String!): Probe! } type ProcessConfigIO { @@ -1616,6 +1657,8 @@ type ProcessConfigLimits { type ProcessConfig { id: String! + owner: String! + domain: String! type: String! reference: String! input: [ProcessConfigIO!]! @@ -1666,6 +1709,8 @@ type ProcessReport implements IProcessReportHistoryEntry { type Process { id: String! + owner: String! + domain: String! type: String! reference: String! created_at: Time! @@ -1841,14 +1886,23 @@ func (ec *executionContext) field_Query_playoutStatus_args(ctx context.Context, } args["id"] = arg0 var arg1 string - if tmp, ok := rawArgs["input"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) - arg1, err = ec.unmarshalNID2string(ctx, tmp) + if tmp, ok := rawArgs["domain"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("domain")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) if err != nil { return nil, err } } - args["input"] = arg1 + args["domain"] = arg1 + var arg2 string + if tmp, ok := rawArgs["input"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) + arg2, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg2 return args, nil } @@ -1864,6 +1918,15 @@ func (ec *executionContext) field_Query_probe_args(ctx context.Context, rawArgs } } args["id"] = arg0 + var arg1 string + if tmp, ok := rawArgs["domain"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("domain")) + arg1, err = 
ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["domain"] = arg1 return args, nil } @@ -1879,6 +1942,48 @@ func (ec *executionContext) field_Query_process_args(ctx context.Context, rawArg } } args["id"] = arg0 + var arg1 string + if tmp, ok := rawArgs["domain"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("domain")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["domain"] = arg1 + return args, nil +} + +func (ec *executionContext) field_Query_processes_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *string + if tmp, ok := rawArgs["idpattern"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("idpattern")) + arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["idpattern"] = arg0 + var arg1 *string + if tmp, ok := rawArgs["refpattern"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("refpattern")) + arg1, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["refpattern"] = arg1 + var arg2 *string + if tmp, ok := rawArgs["domainpattern"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("domainpattern")) + arg2, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["domainpattern"] = arg2 return args, nil } @@ -4275,6 +4380,94 @@ func (ec *executionContext) fieldContext_Process_id(ctx context.Context, field g return fc, nil } +func (ec *executionContext) _Process_owner(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Process_owner(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + 
ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Owner, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Process_owner(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Process", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Process_domain(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Process_domain(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Domain, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Process_domain(ctx context.Context, field 
graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Process", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Process_type(ctx context.Context, field graphql.CollectedField, obj *models.Process) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Process_type(ctx, field) if err != nil { @@ -4448,6 +4641,10 @@ func (ec *executionContext) fieldContext_Process_config(ctx context.Context, fie switch field.Name { case "id": return ec.fieldContext_ProcessConfig_id(ctx, field) + case "owner": + return ec.fieldContext_ProcessConfig_owner(ctx, field) + case "domain": + return ec.fieldContext_ProcessConfig_domain(ctx, field) case "type": return ec.fieldContext_ProcessConfig_type(ctx, field) case "reference": @@ -4678,6 +4875,94 @@ func (ec *executionContext) fieldContext_ProcessConfig_id(ctx context.Context, f return fc, nil } +func (ec *executionContext) _ProcessConfig_owner(ctx context.Context, field graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ProcessConfig_owner(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Owner, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return 
ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ProcessConfig_owner(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ProcessConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _ProcessConfig_domain(ctx context.Context, field graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ProcessConfig_domain(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Domain, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ProcessConfig_domain(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ProcessConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _ProcessConfig_type(ctx context.Context, field 
graphql.CollectedField, obj *models.ProcessConfig) (ret graphql.Marshaler) { fc, err := ec.fieldContext_ProcessConfig_type(ctx, field) if err != nil { @@ -8071,7 +8356,7 @@ func (ec *executionContext) _Query_playoutStatus(ctx context.Context, field grap }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().PlayoutStatus(rctx, fc.Args["id"].(string), fc.Args["input"].(string)) + return ec.resolvers.Query().PlayoutStatus(rctx, fc.Args["id"].(string), fc.Args["domain"].(string), fc.Args["input"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -8155,7 +8440,7 @@ func (ec *executionContext) _Query_processes(ctx context.Context, field graphql. }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Processes(rctx) + return ec.resolvers.Query().Processes(rctx, fc.Args["idpattern"].(*string), fc.Args["refpattern"].(*string), fc.Args["domainpattern"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -8182,6 +8467,10 @@ func (ec *executionContext) fieldContext_Query_processes(ctx context.Context, fi switch field.Name { case "id": return ec.fieldContext_Process_id(ctx, field) + case "owner": + return ec.fieldContext_Process_owner(ctx, field) + case "domain": + return ec.fieldContext_Process_domain(ctx, field) case "type": return ec.fieldContext_Process_type(ctx, field) case "reference": @@ -8200,6 +8489,17 @@ func (ec *executionContext) fieldContext_Query_processes(ctx context.Context, fi return nil, fmt.Errorf("no field named %q was found under type Process", field.Name) }, } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_processes_args(ctx, field.ArgumentMap(ec.Variables)); err 
!= nil { + ec.Error(ctx, err) + return + } return fc, nil } @@ -8217,7 +8517,7 @@ func (ec *executionContext) _Query_process(ctx context.Context, field graphql.Co }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Process(rctx, fc.Args["id"].(string)) + return ec.resolvers.Query().Process(rctx, fc.Args["id"].(string), fc.Args["domain"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -8241,6 +8541,10 @@ func (ec *executionContext) fieldContext_Query_process(ctx context.Context, fiel switch field.Name { case "id": return ec.fieldContext_Process_id(ctx, field) + case "owner": + return ec.fieldContext_Process_owner(ctx, field) + case "domain": + return ec.fieldContext_Process_domain(ctx, field) case "type": return ec.fieldContext_Process_type(ctx, field) case "reference": @@ -8287,7 +8591,7 @@ func (ec *executionContext) _Query_probe(ctx context.Context, field graphql.Coll }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Probe(rctx, fc.Args["id"].(string)) + return ec.resolvers.Query().Probe(rctx, fc.Args["id"].(string), fc.Args["domain"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -11282,7 +11586,12 @@ func (ec *executionContext) unmarshalInputMetricInput(ctx context.Context, obj i asMap[k] = v } - for k, v := range asMap { + fieldsInOrder := [...]string{"name", "labels"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } switch k { case "name": var err error @@ -11313,7 +11622,12 @@ func (ec *executionContext) unmarshalInputMetricsInput(ctx context.Context, obj asMap[k] = v } - for k, v := range asMap { + fieldsInOrder := [...]string{"timerange_seconds", "interval_seconds", "metrics"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + 
continue + } switch k { case "timerange_seconds": var err error @@ -11938,6 +12252,20 @@ func (ec *executionContext) _Process(ctx context.Context, sel ast.SelectionSet, out.Values[i] = ec._Process_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "owner": + + out.Values[i] = ec._Process_owner(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "domain": + + out.Values[i] = ec._Process_domain(ctx, field, obj) + if out.Values[i] == graphql.Null { invalids++ } @@ -12012,6 +12340,20 @@ func (ec *executionContext) _ProcessConfig(ctx context.Context, sel ast.Selectio out.Values[i] = ec._ProcessConfig_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "owner": + + out.Values[i] = ec._ProcessConfig_owner(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "domain": + + out.Values[i] = ec._ProcessConfig_domain(ctx, field, obj) + if out.Values[i] == graphql.Null { invalids++ } diff --git a/http/graph/models/models_gen.go b/http/graph/models/models_gen.go index 7cda1413..56bb286b 100644 --- a/http/graph/models/models_gen.go +++ b/http/graph/models/models_gen.go @@ -13,6 +13,9 @@ import ( type IProcessReportHistoryEntry interface { IsIProcessReportHistoryEntry() + GetCreatedAt() time.Time + GetPrelude() []string + GetLog() []*ProcessReportLogEntry } type AVStream struct { @@ -102,6 +105,8 @@ type ProbeIo struct { type Process struct { ID string `json:"id"` + Owner string `json:"owner"` + Domain string `json:"domain"` Type string `json:"type"` Reference string `json:"reference"` CreatedAt time.Time `json:"created_at"` @@ -113,6 +118,8 @@ type Process struct { type ProcessConfig struct { ID string `json:"id"` + Owner string `json:"owner"` + Domain string `json:"domain"` Type string `json:"type"` Reference string `json:"reference"` Input []*ProcessConfigIo `json:"input"` @@ -145,6 +152,27 @@ type ProcessReport struct { } func (ProcessReport) 
IsIProcessReportHistoryEntry() {} +func (this ProcessReport) GetCreatedAt() time.Time { return this.CreatedAt } +func (this ProcessReport) GetPrelude() []string { + if this.Prelude == nil { + return nil + } + interfaceSlice := make([]string, 0, len(this.Prelude)) + for _, concrete := range this.Prelude { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} +func (this ProcessReport) GetLog() []*ProcessReportLogEntry { + if this.Log == nil { + return nil + } + interfaceSlice := make([]*ProcessReportLogEntry, 0, len(this.Log)) + for _, concrete := range this.Log { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} type ProcessReportHistoryEntry struct { CreatedAt time.Time `json:"created_at"` @@ -153,6 +181,27 @@ type ProcessReportHistoryEntry struct { } func (ProcessReportHistoryEntry) IsIProcessReportHistoryEntry() {} +func (this ProcessReportHistoryEntry) GetCreatedAt() time.Time { return this.CreatedAt } +func (this ProcessReportHistoryEntry) GetPrelude() []string { + if this.Prelude == nil { + return nil + } + interfaceSlice := make([]string, 0, len(this.Prelude)) + for _, concrete := range this.Prelude { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} +func (this ProcessReportHistoryEntry) GetLog() []*ProcessReportLogEntry { + if this.Log == nil { + return nil + } + interfaceSlice := make([]*ProcessReportLogEntry, 0, len(this.Log)) + for _, concrete := range this.Log { + interfaceSlice = append(interfaceSlice, concrete) + } + return interfaceSlice +} type ProcessReportLogEntry struct { Timestamp time.Time `json:"timestamp"` diff --git a/http/graph/playout.graphqls b/http/graph/playout.graphqls index 1ac70973..4fe1e721 100644 --- a/http/graph/playout.graphqls +++ b/http/graph/playout.graphqls @@ -1,5 +1,5 @@ extend type Query { - playoutStatus(id: ID!, input: ID!): RawAVstream + playoutStatus(id: ID!, domain: String!, input: ID!): RawAVstream } type RawAVstreamIO { 
diff --git a/http/graph/process.graphqls b/http/graph/process.graphqls index 1e4fff7b..a8b7a70e 100644 --- a/http/graph/process.graphqls +++ b/http/graph/process.graphqls @@ -1,7 +1,11 @@ extend type Query { - processes: [Process!]! - process(id: ID!): Process - probe(id: ID!): Probe! + processes( + idpattern: String + refpattern: String + domainpattern: String + ): [Process!]! + process(id: ID!, domain: String!): Process + probe(id: ID!, domain: String!): Probe! } type ProcessConfigIO { @@ -18,6 +22,8 @@ type ProcessConfigLimits { type ProcessConfig { id: String! + owner: String! + domain: String! type: String! reference: String! input: [ProcessConfigIO!]! @@ -68,6 +74,8 @@ type ProcessReport implements IProcessReportHistoryEntry { type Process { id: String! + owner: String! + domain: String! type: String! reference: String! created_at: Time! diff --git a/http/graph/resolver/about.resolvers.go b/http/graph/resolver/about.resolvers.go index 37453620..d62771b5 100644 --- a/http/graph/resolver/about.resolvers.go +++ b/http/graph/resolver/about.resolvers.go @@ -12,6 +12,7 @@ import ( "github.com/datarhei/core/v16/http/graph/scalars" ) +// About is the resolver for the about field. func (r *queryResolver) About(ctx context.Context) (*models.About, error) { createdAt := r.Restream.CreatedAt() diff --git a/http/graph/resolver/log.resolvers.go b/http/graph/resolver/log.resolvers.go index 50006348..18aa2281 100644 --- a/http/graph/resolver/log.resolvers.go +++ b/http/graph/resolver/log.resolvers.go @@ -10,6 +10,7 @@ import ( "github.com/datarhei/core/v16/log" ) +// Log is the resolver for the log field. 
func (r *queryResolver) Log(ctx context.Context) ([]string, error) { if r.LogBuffer == nil { r.LogBuffer = log.NewBufferWriter(log.Lsilent, 1) diff --git a/http/graph/resolver/metrics.resolvers.go b/http/graph/resolver/metrics.resolvers.go index 6a78567e..e9ae7d3e 100644 --- a/http/graph/resolver/metrics.resolvers.go +++ b/http/graph/resolver/metrics.resolvers.go @@ -12,6 +12,7 @@ import ( "github.com/datarhei/core/v16/monitor/metric" ) +// Metrics is the resolver for the metrics field. func (r *queryResolver) Metrics(ctx context.Context, query models.MetricsInput) (*models.Metrics, error) { patterns := []metric.Pattern{} diff --git a/http/graph/resolver/playout.resolvers.go b/http/graph/resolver/playout.resolvers.go index 8ec88ba1..6387aa20 100644 --- a/http/graph/resolver/playout.resolvers.go +++ b/http/graph/resolver/playout.resolvers.go @@ -11,10 +11,23 @@ import ( "github.com/datarhei/core/v16/http/graph/models" "github.com/datarhei/core/v16/playout" + "github.com/datarhei/core/v16/restream" ) -func (r *queryResolver) PlayoutStatus(ctx context.Context, id string, input string) (*models.RawAVstream, error) { - addr, err := r.Restream.GetPlayout(id, input) +// PlayoutStatus is the resolver for the playoutStatus field. 
+func (r *queryResolver) PlayoutStatus(ctx context.Context, id string, domain string, input string) (*models.RawAVstream, error) { + user, _ := ctx.Value(GraphKey("user")).(string) + + if !r.IAM.Enforce(user, domain, "process:"+id, "read") { + return nil, fmt.Errorf("forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := r.Restream.GetPlayout(tid, input) if err != nil { return nil, fmt.Errorf("unknown process or input: %w", err) } diff --git a/http/graph/resolver/process.resolvers.go b/http/graph/resolver/process.resolvers.go index d8f1ee33..f420bd99 100644 --- a/http/graph/resolver/process.resolvers.go +++ b/http/graph/resolver/process.resolvers.go @@ -5,16 +5,35 @@ package resolver import ( "context" + "fmt" "github.com/datarhei/core/v16/http/graph/models" + "github.com/datarhei/core/v16/restream" ) -func (r *queryResolver) Processes(ctx context.Context) ([]*models.Process, error) { - ids := r.Restream.GetProcessIDs("", "") +// Processes is the resolver for the processes field. +func (r *queryResolver) Processes(ctx context.Context, idpattern *string, refpattern *string, domainpattern *string) ([]*models.Process, error) { + user, _ := ctx.Value(GraphKey("user")).(string) + // The pattern arguments are optional in the schema and arrive as nil when omitted. + idp, refp, domp := "", "", "" + if idpattern != nil { + idp = *idpattern + } + if refpattern != nil { + refp = *refpattern + } + if domainpattern != nil { + domp = *domainpattern + } + ids := r.Restream.GetProcessIDs(idp, refp, "", domp) procs := []*models.Process{} for _, id := range ids { + if !r.IAM.Enforce(user, id.Domain, "process:"+id.ID, "read") { + continue + } + p, err := r.getProcess(id) if err != nil { return nil, err @@ -26,12 +45,36 @@ func (r *queryResolver) Processes(ctx context.Context) ([]*models.Process, error return procs, nil } -func (r *queryResolver) Process(ctx context.Context, id string) (*models.Process, error) { - return r.getProcess(id) +// Process is the resolver for the process field.
+func (r *queryResolver) Process(ctx context.Context, id string, domain string) (*models.Process, error) { + user, _ := ctx.Value(GraphKey("user")).(string) + + if !r.IAM.Enforce(user, domain, "process:"+id, "read") { + return nil, fmt.Errorf("forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + return r.getProcess(tid) } -func (r *queryResolver) Probe(ctx context.Context, id string) (*models.Probe, error) { - probe := r.Restream.Probe(id) +// Probe is the resolver for the probe field. +func (r *queryResolver) Probe(ctx context.Context, id string, domain string) (*models.Probe, error) { + user, _ := ctx.Value(GraphKey("user")).(string) + + if !r.IAM.Enforce(user, domain, "process:"+id, "write") { + return nil, fmt.Errorf("forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + probe := r.Restream.Probe(tid) p := &models.Probe{} p.UnmarshalRestream(probe) diff --git a/http/graph/resolver/resolver.go b/http/graph/resolver/resolver.go index 8705fe80..be658679 100644 --- a/http/graph/resolver/resolver.go +++ b/http/graph/resolver/resolver.go @@ -7,6 +7,7 @@ import ( "time" "github.com/datarhei/core/v16/http/graph/models" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/monitor" "github.com/datarhei/core/v16/restream" @@ -20,9 +21,10 @@ type Resolver struct { Restream restream.Restreamer Monitor monitor.HistoryReader LogBuffer log.BufferWriter + IAM iam.IAM } -func (r *queryResolver) getProcess(id string) (*models.Process, error) { +func (r *queryResolver) getProcess(id restream.TaskID) (*models.Process, error) { process, err := r.Restream.GetProcess(id) if err != nil { return nil, err @@ -86,3 +88,5 @@ func (r *queryResolver) playoutRequest(method, addr, path, contentType string, d return data, nil } + +type GraphKey string diff --git a/http/graph/resolver/schema.resolvers.go b/http/graph/resolver/schema.resolvers.go index 60cebcea..89a0b922 100644 --- 
a/http/graph/resolver/schema.resolvers.go +++ b/http/graph/resolver/schema.resolvers.go @@ -9,10 +9,12 @@ import ( "github.com/datarhei/core/v16/http/graph/graph" ) +// Ping is the resolver for the ping field. func (r *mutationResolver) Ping(ctx context.Context) (string, error) { return "pong", nil } +// Ping is the resolver for the ping field. func (r *queryResolver) Ping(ctx context.Context) (string, error) { return "pong", nil } diff --git a/http/handler/api/about.go b/http/handler/api/about.go index 77c4d7db..a9add8ca 100644 --- a/http/handler/api/about.go +++ b/http/handler/api/about.go @@ -15,11 +15,11 @@ import ( // about the API version and build infos. type AboutHandler struct { restream restream.Restreamer - auths []string + auths func() []string } // NewAbout returns a new About type -func NewAbout(restream restream.Restreamer, auths []string) *AboutHandler { +func NewAbout(restream restream.Restreamer, auths func() []string) *AboutHandler { return &AboutHandler{ restream: restream, auths: auths, @@ -35,12 +35,24 @@ func NewAbout(restream restream.Restreamer, auths []string) *AboutHandler { // @Security ApiKeyAuth // @Router /api [get] func (p *AboutHandler) About(c echo.Context) error { + user, _ := c.Get("user").(string) + + if user == "$anon" { + return c.JSON(http.StatusOK, api.MinimalAbout{ + App: app.Name, + Auths: p.auths(), + Version: api.VersionMinimal{ + Number: app.Version.MajorString(), + }, + }) + } + createdAt := p.restream.CreatedAt() about := api.About{ App: app.Name, Name: p.restream.Name(), - Auths: p.auths, + Auths: p.auths(), ID: p.restream.ID(), CreatedAt: createdAt.Format(time.RFC3339), Uptime: uint64(time.Since(createdAt).Seconds()), diff --git a/http/handler/api/about_test.go b/http/handler/api/about_test.go index dfdb4673..a2f039b5 100644 --- a/http/handler/api/about_test.go +++ b/http/handler/api/about_test.go @@ -19,7 +19,7 @@ func getDummyAboutRouter() (*echo.Echo, error) { return nil, err } - handler := NewAbout(rs, 
[]string{}) + handler := NewAbout(rs, func() []string { return []string{} }) router.Add("GET", "/", handler.About) diff --git a/http/handler/api/graph.go b/http/handler/api/graph.go index 9d9a0a13..35013e2d 100644 --- a/http/handler/api/graph.go +++ b/http/handler/api/graph.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "github.com/datarhei/core/v16/http/graph/graph" @@ -18,7 +19,7 @@ type GraphHandler struct { playgroundHandler http.HandlerFunc } -// NewRestream return a new Restream type. You have to provide a valid Restreamer instance. +// NewGraph return a new GraphHandler type. You have to provide a valid Restreamer instance. func NewGraph(resolver resolver.Resolver, path string) *GraphHandler { g := &GraphHandler{ resolver: resolver, @@ -43,7 +44,12 @@ func NewGraph(resolver resolver.Resolver, path string) *GraphHandler { // @Security ApiKeyAuth // @Router /api/graph/query [post] func (g *GraphHandler) Query(c echo.Context) error { - g.queryHandler.ServeHTTP(c.Response(), c.Request()) + user, _ := c.Get("user").(string) + + r := c.Request() + ctx := context.WithValue(r.Context(), resolver.GraphKey("user"), user) + + g.queryHandler.ServeHTTP(c.Response(), r.WithContext(ctx)) return nil } diff --git a/http/handler/api/iam.go b/http/handler/api/iam.go new file mode 100644 index 00000000..6382e73e --- /dev/null +++ b/http/handler/api/iam.go @@ -0,0 +1,331 @@ +package api + +import ( + "net/http" + + "github.com/datarhei/core/v16/http/api" + "github.com/datarhei/core/v16/http/handler/util" + "github.com/datarhei/core/v16/iam" + + "github.com/labstack/echo/v4" +) + +type IAMHandler struct { + iam iam.IAM +} + +func NewIAM(iam iam.IAM) *IAMHandler { + return &IAMHandler{ + iam: iam, + } +} + +// AddUser adds a new user +// @Summary Add a new user +// @Description Add a new user +// @Tags v16.?.? 
+// @ID iam-3-add-user +// @Accept json +// @Produce json +// @Param config body api.IAMUser true "User definition" +// @Param domain query string false "Domain of the acting user" +// @Success 200 {object} api.IAMUser +// @Failure 400 {object} api.Error +// @Failure 500 {object} api.Error +// @Security ApiKeyAuth +// @Router /api/v3/iam/user [post] +func (h *IAMHandler) AddUser(c echo.Context) error { + ctxuser := util.DefaultContext(c, "user", "") + superuser := util.DefaultContext(c, "superuser", false) + domain := util.DefaultQuery(c, "domain", "$none") + + user := api.IAMUser{} + + if err := util.ShouldBindJSON(c, &user); err != nil { + return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) + } + + iamuser, iampolicies := user.Unmarshal() + + if !h.iam.Enforce(ctxuser, domain, "iam:"+iamuser.Name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to create user '%s'", iamuser.Name) + } + + for _, p := range iampolicies { + if !h.iam.Enforce(ctxuser, p.Domain, "iam:"+iamuser.Name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to write policy: %v", p) + } + } + + if !superuser && iamuser.Superuser { + return api.Err(http.StatusForbidden, "Forbidden", "Only superusers can add superusers") + } + + err := h.iam.CreateIdentity(iamuser) + if err != nil { + return api.Err(http.StatusBadRequest, "Bad request", "%s", err) + } + + for _, p := range iampolicies { + h.iam.AddPolicy(p.Name, p.Domain, p.Resource, p.Actions) + } + + err = h.iam.SaveIdentities() + if err != nil { + return api.Err(http.StatusInternalServerError, "Internal server error", "%s", err) + } + + return c.JSON(http.StatusOK, user) +} + +// RemoveUser deletes the user with the given name +// @Summary Delete an user by its name +// @Description Delete an user by its name +// @Tags v16.?.? 
+// @ID iam-3-delete-user +// @Produce json +// @Param name path string true "Username" +// @Param domain query string false "Domain of the acting user" +// @Success 200 {string} string +// @Failure 404 {object} api.Error +// @Failure 500 {object} api.Error +// @Security ApiKeyAuth +// @Router /api/v3/iam/user/{name} [delete] +func (h *IAMHandler) RemoveUser(c echo.Context) error { + ctxuser := util.DefaultContext(c, "user", "") + superuser := util.DefaultContext(c, "superuser", false) + domain := util.DefaultQuery(c, "domain", "$none") + name := util.PathParam(c, "name") + + if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to delete this user") + } + + iamuser, err := h.iam.GetIdentity(name) + if err != nil { + return api.Err(http.StatusNotFound, "Not found", "%s", err) + } + + if !superuser && iamuser.Superuser { + return api.Err(http.StatusForbidden, "Forbidden", "Only superusers can remove superusers") + } + + // Remove the user + err = h.iam.DeleteIdentity(name) + if err != nil { + return api.Err(http.StatusBadRequest, "Bad request", "%s", err) + } + + err = h.iam.SaveIdentities() + if err != nil { + return api.Err(http.StatusInternalServerError, "Internal server error", "%s", err) + } + + // Remove all policies of that user + h.iam.RemovePolicy(name, "", "", nil) + + return c.JSON(http.StatusOK, "OK") +} + +// UpdateUser replaces an existing user +// @Summary Replace an existing user +// @Description Replace an existing user. +// @Tags v16.?.? 
+// @ID iam-3-update-user +// @Accept json +// @Produce json +// @Param name path string true "Username" +// @Param domain query string false "Domain of the acting user" +// @Param user body api.IAMUser true "User definition" +// @Success 200 {object} api.IAMUser +// @Failure 400 {object} api.Error +// @Failure 404 {object} api.Error +// @Failure 500 {object} api.Error +// @Security ApiKeyAuth +// @Router /api/v3/iam/user/{name} [put] +func (h *IAMHandler) UpdateUser(c echo.Context) error { + ctxuser := util.DefaultContext(c, "user", "") + superuser := util.DefaultContext(c, "superuser", false) + domain := util.DefaultQuery(c, "domain", "$none") + name := util.PathParam(c, "name") + + if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to modify this user") + } + + var iamuser iam.User + var err error + + if name != "$anon" { + iamuser, err = h.iam.GetIdentity(name) + if err != nil { + return api.Err(http.StatusNotFound, "Not found", "%s", err) + } + } else { + iamuser = iam.User{ + Name: "$anon", + } + } + + iampolicies := h.iam.ListPolicies(name, "", "", nil) + + user := api.IAMUser{} + user.Marshal(iamuser, iampolicies) + + if err := util.ShouldBindJSON(c, &user); err != nil { + return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) + } + + iamuser, iampolicies = user.Unmarshal() + + if !h.iam.Enforce(ctxuser, domain, "iam:"+iamuser.Name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to create user '%s'", iamuser.Name) + } + + for _, p := range iampolicies { + if !h.iam.Enforce(ctxuser, p.Domain, "iam:"+iamuser.Name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to write policy: %v", p) + } + } + + if !superuser && iamuser.Superuser { + return api.Err(http.StatusForbidden, "Forbidden", "Only superusers can modify superusers") + } + + if name != "$anon" { + err = h.iam.UpdateIdentity(name, iamuser) + if err != nil { + 
return api.Err(http.StatusBadRequest, "Bad request", "%s", err) + } + } + + h.iam.RemovePolicy(name, "", "", nil) + + for _, p := range iampolicies { + h.iam.AddPolicy(p.Name, p.Domain, p.Resource, p.Actions) + } + + err = h.iam.SaveIdentities() + if err != nil { + return api.Err(http.StatusInternalServerError, "Internal server error", "%s", err) + } + + return c.JSON(http.StatusOK, user) +} + +// UpdateUserPolicies replaces existing user policies +// @Summary Replace policies of a user +// @Description Replace policies of a user +// @Tags v16.?.? +// @ID iam-3-update-user-policies +// @Accept json +// @Produce json +// @Param name path string true "Username" +// @Param domain query string false "Domain of the acting user" +// @Param user body []api.IAMPolicy true "Policy definitions" +// @Success 200 {array} api.IAMPolicy +// @Failure 400 {object} api.Error +// @Failure 404 {object} api.Error +// @Failure 500 {object} api.Error +// @Security ApiKeyAuth +// @Router /api/v3/iam/user/{name}/policy [put] +func (h *IAMHandler) UpdateUserPolicies(c echo.Context) error { + ctxuser := util.DefaultContext(c, "user", "") + superuser := util.DefaultContext(c, "superuser", false) + domain := util.DefaultQuery(c, "domain", "$none") + name := util.PathParam(c, "name") + + if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to modify this user") + } + + var iamuser iam.User + var err error + + if name != "$anon" { + iamuser, err = h.iam.GetIdentity(name) + if err != nil { + return api.Err(http.StatusNotFound, "Not found", "%s", err) + } + } else { + iamuser = iam.User{ + Name: "$anon", + } + } + + policies := []api.IAMPolicy{} + + if err := util.ShouldBindJSON(c, &policies); err != nil { + return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) + } + + for _, p := range policies { + if !h.iam.Enforce(ctxuser, p.Domain, "iam:"+iamuser.Name, "write") { + return api.Err(http.StatusForbidden, "Forbidden", "Not 
allowed to write policy: %v", p) + } + } + + if !superuser && iamuser.Superuser { + return api.Err(http.StatusForbidden, "Forbidden", "Only superusers can modify superusers") + } + + h.iam.RemovePolicy(name, "", "", nil) + + for _, p := range policies { + h.iam.AddPolicy(iamuser.Name, p.Domain, p.Resource, p.Actions) + } + + return c.JSON(http.StatusOK, policies) +} + +// GetUser returns the user with the given name +// @Summary List a user by its name +// @Description List a user by its name +// @Tags v16.?.? +// @ID iam-3-get-user +// @Produce json +// @Param name path string true "Username" +// @Param domain query string false "Domain of the acting user" +// @Success 200 {object} api.IAMUser +// @Failure 404 {object} api.Error +// @Security ApiKeyAuth +// @Router /api/v3/iam/user/{name} [get] +func (h *IAMHandler) GetUser(c echo.Context) error { + ctxuser := util.DefaultContext(c, "user", "") + superuser := util.DefaultContext(c, "superuser", false) + domain := util.DefaultQuery(c, "domain", "$none") + name := util.PathParam(c, "name") + + if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "read") { + return api.Err(http.StatusForbidden, "Forbidden", "Not allowed to access this user") + } + + var iamuser iam.User + var err error + + if name != "$anon" { + iamuser, err = h.iam.GetIdentity(name) + if err != nil { + return api.Err(http.StatusNotFound, "Not found", "%s", err) + } + + // Redact auth details unless the requester is the user itself or a superuser. + if !superuser && ctxuser != iamuser.Name { + if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { + iamuser = iam.User{ + Name: iamuser.Name, + } + } + } + } else { + iamuser = iam.User{ + Name: "$anon", + } + } + + iampolicies := h.iam.ListPolicies(name, "", "", nil) + + user := api.IAMUser{} + user.Marshal(iamuser, iampolicies) + + return c.JSON(http.StatusOK, user) +} diff --git a/http/handler/api/jwt.go b/http/handler/api/jwt.go new file mode 100644 index 00000000..e8b3a4a3 --- /dev/null +++ b/http/handler/api/jwt.go @@ -0,0 +1,53 @@ +package api + +import ( + "net/http" + 
"github.com/datarhei/core/v16/http/api" + "github.com/datarhei/core/v16/iam" + + "github.com/labstack/echo/v4" +) + +type JWTHandler struct { + iam iam.IAM +} + +func NewJWT(iam iam.IAM) *JWTHandler { + return &JWTHandler{ + iam: iam, + } +} + +func (j *JWTHandler) Login(c echo.Context) error { + subject, ok := c.Get("user").(string) + if !ok { + return api.Err(http.StatusForbidden, "Invalid token") + } + + at, rt, err := j.iam.CreateJWT(subject) + if err != nil { + return api.Err(http.StatusInternalServerError, "Failed to create JWT", "%s", err) + } + + return c.JSON(http.StatusOK, api.JWT{ + AccessToken: at, + RefreshToken: rt, + }) +} + +func (j *JWTHandler) Refresh(c echo.Context) error { + subject, ok := c.Get("user").(string) + if !ok { + return api.Err(http.StatusForbidden, "Invalid token") + } + + at, _, err := j.iam.CreateJWT(subject) + if err != nil { + return api.Err(http.StatusInternalServerError, "Failed to create JWT", "%s", err) + } + + return c.JSON(http.StatusOK, api.JWTRefresh{ + AccessToken: at, + }) +} diff --git a/http/handler/api/playout.go b/http/handler/api/playout.go index cc073001..52c3cd9d 100644 --- a/http/handler/api/playout.go +++ b/http/handler/api/playout.go @@ -16,19 +16,7 @@ import ( "github.com/labstack/echo/v4" ) -// The PlayoutHandler type provides handlers for accessing the playout API of a process -type PlayoutHandler struct { - restream restream.Restreamer -} - -// NewPlayout returns a new Playout type. You have to provide a Restreamer instance. 
-func NewPlayout(restream restream.Restreamer) *PlayoutHandler { - return &PlayoutHandler{ - restream: restream, - } -} - -// Status return the current playout status +// PlayoutStatus return the current playout status // @Summary Get the current playout status // @Description Get the current playout status of an input of a process // @Tags v16.7.2 @@ -41,11 +29,22 @@ func NewPlayout(restream restream.Restreamer) *PlayoutHandler { // @Failure 500 {object} api.Error // @Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/status [get] -func (h *PlayoutHandler) Status(c echo.Context) error { +func (h *RestreamHandler) PlayoutStatus(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -82,7 +81,7 @@ func (h *PlayoutHandler) Status(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -// Keyframe returns the last keyframe +// PlayoutKeyframe returns the last keyframe // @Summary Get the last keyframe // @Description Get the last keyframe of an input of a process. The extension of the name determines the return type. 
// @Tags v16.7.2 @@ -98,12 +97,23 @@ func (h *PlayoutHandler) Status(c echo.Context) error { // @Failure 500 {object} api.Error // @Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/keyframe/{name} [get] -func (h *PlayoutHandler) Keyframe(c echo.Context) error { +func (h *RestreamHandler) PlayoutKeyframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") name := util.PathWildcardParam(c) + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -132,7 +142,7 @@ func (h *PlayoutHandler) Keyframe(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -// EncodeErrorframe encodes the errorframe +// PlayoutEncodeErrorframe encodes the errorframe // @Summary Encode the errorframe // @Description Immediately encode the errorframe (if available and looping) // @Tags v16.7.2 @@ -146,11 +156,22 @@ func (h *PlayoutHandler) Keyframe(c echo.Context) error { // @Failure 500 {object} api.Error // @Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/errorframe/encode [get] -func (h *PlayoutHandler) EncodeErrorframe(c echo.Context) error { +func (h *RestreamHandler) PlayoutEncodeErrorframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := 
restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -173,7 +194,7 @@ func (h *PlayoutHandler) EncodeErrorframe(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -// SetErrorframe sets an errorframe +// PlayoutSetErrorframe sets an errorframe // @Summary Upload an error frame // @Description Upload an error frame which will be encoded immediately // @Tags v16.7.2 @@ -190,11 +211,22 @@ func (h *PlayoutHandler) EncodeErrorframe(c echo.Context) error { // @Failure 500 {object} api.Error // @Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/errorframe/{name} [post] -func (h *PlayoutHandler) SetErrorframe(c echo.Context) error { +func (h *RestreamHandler) PlayoutSetErrorframe(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -222,7 +254,7 @@ func (h *PlayoutHandler) SetErrorframe(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -// ReopenInput closes the current input stream +// PlayoutReopenInput closes the current input stream // @Summary Close the current input stream // @Description Close the current input stream such that it will be automatically re-opened // @Tags v16.7.2 @@ -235,11 +267,22 @@ func (h *PlayoutHandler) SetErrorframe(c echo.Context) error { // @Failure 500 {object} api.Error // 
@Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/reopen [get] -func (h *PlayoutHandler) ReopenInput(c echo.Context) error { +func (h *RestreamHandler) PlayoutReopenInput(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -262,7 +305,7 @@ func (h *PlayoutHandler) ReopenInput(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -// SetStream replaces the current stream +// PlayoutSetStream replaces the current stream // @Summary Switch to a new stream // @Description Replace the current stream with the one from the given URL. The switch will only happen if the stream parameters match. 
// @Tags v16.7.2 @@ -278,11 +321,22 @@ func (h *PlayoutHandler) ReopenInput(c echo.Context) error { // @Failure 500 {object} api.Error // @Security ApiKeyAuth // @Router /api/v3/process/{id}/playout/{inputid}/stream [put] -func (h *PlayoutHandler) SetStream(c echo.Context) error { +func (h *RestreamHandler) PlayoutSetStream(c echo.Context) error { id := util.PathParam(c, "id") inputid := util.PathParam(c, "inputid") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - addr, err := h.restream.GetPlayout(id, inputid) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + addr, err := h.restream.GetPlayout(tid, inputid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process or input", "%s", err) } @@ -310,7 +364,7 @@ func (h *PlayoutHandler) SetStream(c echo.Context) error { return c.Blob(response.StatusCode, response.Header.Get("content-type"), data) } -func (h *PlayoutHandler) request(method, addr, path, contentType string, data []byte) (*http.Response, error) { +func (h *RestreamHandler) request(method, addr, path, contentType string, data []byte) (*http.Response, error) { endpoint := "http://" + addr + path body := bytes.NewBuffer(data) diff --git a/http/handler/api/restream.go b/http/handler/api/restream.go index 96a4f75c..dbe9674e 100644 --- a/http/handler/api/restream.go +++ b/http/handler/api/restream.go @@ -6,6 +6,7 @@ import ( "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/handler/util" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/restream" "github.com/labstack/echo/v4" @@ -15,12 +16,14 @@ import ( // The RestreamHandler type provides functions to interact with a Restreamer instance type RestreamHandler struct { restream restream.Restreamer + iam iam.Enforcer } // NewRestream return a new Restream type. 
You have to provide a valid Restreamer instance. -func NewRestream(restream restream.Restreamer) *RestreamHandler { +func NewRestream(restream restream.Restreamer, iam iam.IAM) *RestreamHandler { return &RestreamHandler{ restream: restream, + iam: iam, } } @@ -37,6 +40,8 @@ func NewRestream(restream restream.Restreamer) *RestreamHandler { // @Security ApiKeyAuth // @Router /api/v3/process [post] func (h *RestreamHandler) Add(c echo.Context) error { + user := util.DefaultContext(c, "user", "") + process := api.ProcessConfig{ ID: shortuuid.New(), Type: "ffmpeg", @@ -47,6 +52,10 @@ func (h *RestreamHandler) Add(c echo.Context) error { return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) } + if !h.iam.Enforce(user, process.Group, "process:"+process.ID, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + if process.Type != "ffmpeg" { return api.Err(http.StatusBadRequest, "Unsupported process type", "Supported process types are: ffmpeg") } @@ -56,12 +65,18 @@ func (h *RestreamHandler) Add(c echo.Context) error { } config := process.Marshal() + config.Owner = user if err := h.restream.AddProcess(config); err != nil { return api.Err(http.StatusBadRequest, "Invalid process config", "%s", err.Error()) } - p, _ := h.getProcess(config.ID, "config") + tid := restream.TaskID{ + ID: config.ID, + Domain: config.Domain, + } + + p, _ := h.getProcess(tid, config.Owner, "config") return c.JSON(http.StatusOK, p.Config) } @@ -88,14 +103,25 @@ func (h *RestreamHandler) GetAll(c echo.Context) error { }) idpattern := util.DefaultQuery(c, "idpattern", "") refpattern := util.DefaultQuery(c, "refpattern", "") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - ids := h.restream.GetProcessIDs(idpattern, refpattern) + preids := h.restream.GetProcessIDs(idpattern, refpattern, "", "") + ids := []restream.TaskID{} + + for _, id := range preids { + if !h.iam.Enforce(user, domain, "process:"+id.ID, "read") { + continue + } + + 
ids = append(ids, id) + } processes := []api.Process{} if len(wantids) == 0 || len(reference) != 0 { for _, id := range ids { - if p, err := h.getProcess(id, filter); err == nil { + if p, err := h.getProcess(id, user, filter); err == nil { if len(reference) != 0 && p.Reference != reference { continue } @@ -105,8 +131,8 @@ func (h *RestreamHandler) GetAll(c echo.Context) error { } else { for _, id := range ids { for _, wantid := range wantids { - if wantid == id { - if p, err := h.getProcess(id, filter); err == nil { + if wantid == id.ID { + if p, err := h.getProcess(id, user, filter); err == nil { processes = append(processes, p) } } @@ -132,8 +158,19 @@ func (h *RestreamHandler) GetAll(c echo.Context) error { func (h *RestreamHandler) Get(c echo.Context) error { id := util.PathParam(c, "id") filter := util.DefaultQuery(c, "filter", "") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - p, err := h.getProcess(id, filter) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + p, err := h.getProcess(tid, user, filter) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -154,12 +191,23 @@ func (h *RestreamHandler) Get(c echo.Context) error { // @Router /api/v3/process/{id} [delete] func (h *RestreamHandler) Delete(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - if err := h.restream.StopProcess(id); err != nil { + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + if err := h.restream.StopProcess(tid); err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } - if err := h.restream.DeleteProcess(id); err != nil { + if 
err := h.restream.DeleteProcess(tid); err != nil { return api.Err(http.StatusInternalServerError, "Process can't be deleted", "%s", err) } @@ -182,6 +230,8 @@ func (h *RestreamHandler) Delete(c echo.Context) error { // @Router /api/v3/process/{id} [put] func (h *RestreamHandler) Update(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") process := api.ProcessConfig{ ID: id, @@ -189,7 +239,16 @@ func (h *RestreamHandler) Update(c echo.Context) error { Autostart: true, } - current, err := h.restream.GetProcess(id) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + current, err := h.restream.GetProcess(tid) if err != nil { return api.Err(http.StatusNotFound, "Process not found", "%s", id) } @@ -202,8 +261,18 @@ func (h *RestreamHandler) Update(c echo.Context) error { } config := process.Marshal() + config.Owner = user - if err := h.restream.UpdateProcess(id, config); err != nil { + if !h.iam.Enforce(user, config.Domain, "process:"+config.ID, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid = restream.TaskID{ + ID: id, + Domain: domain, + } + + if err := h.restream.UpdateProcess(tid, config); err != nil { if err == restream.ErrUnknownProcess { return api.Err(http.StatusNotFound, "Process not found", "%s", id) } @@ -211,7 +280,12 @@ func (h *RestreamHandler) Update(c echo.Context) error { return api.Err(http.StatusBadRequest, "Process can't be updated", "%s", err) } - p, _ := h.getProcess(config.ID, "config") + tid = restream.TaskID{ + ID: config.ID, + Domain: config.Domain, + } + + p, _ := h.getProcess(tid, config.Owner, "config") return c.JSON(http.StatusOK, p.Config) } @@ -232,6 +306,12 @@ func (h *RestreamHandler) Update(c echo.Context) error { // @Router /api/v3/process/{id}/command [put] func (h *RestreamHandler) 
Command(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") + + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } var command api.Command @@ -239,15 +319,20 @@ func (h *RestreamHandler) Command(c echo.Context) error { return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) } + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + var err error if command.Command == "start" { - err = h.restream.StartProcess(id) + err = h.restream.StartProcess(tid) } else if command.Command == "stop" { - err = h.restream.StopProcess(id) + err = h.restream.StopProcess(tid) } else if command.Command == "restart" { - err = h.restream.RestartProcess(id) + err = h.restream.RestartProcess(tid) } else if command.Command == "reload" { - err = h.restream.ReloadProcess(id) + err = h.restream.ReloadProcess(tid) } else { return api.Err(http.StatusBadRequest, "Unknown command provided", "Known commands are: start, stop, reload, restart") } @@ -273,8 +358,19 @@ func (h *RestreamHandler) Command(c echo.Context) error { // @Router /api/v3/process/{id}/config [get] func (h *RestreamHandler) GetConfig(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - p, err := h.restream.GetProcess(id) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + p, err := h.restream.GetProcess(tid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -299,8 +395,19 @@ func (h *RestreamHandler) GetConfig(c echo.Context) error { // @Router /api/v3/process/{id}/state [get] func (h *RestreamHandler) GetState(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") 
+ domain := util.DefaultQuery(c, "domain", "") - s, err := h.restream.GetProcessState(id) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + s, err := h.restream.GetProcessState(tid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -325,8 +432,19 @@ func (h *RestreamHandler) GetState(c echo.Context) error { // @Router /api/v3/process/{id}/report [get] func (h *RestreamHandler) GetReport(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - l, err := h.restream.GetProcessLog(id) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + l, err := h.restream.GetProcessLog(tid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -349,8 +467,19 @@ func (h *RestreamHandler) GetReport(c echo.Context) error { // @Router /api/v3/process/{id}/probe [get] func (h *RestreamHandler) Probe(c echo.Context) error { id := util.PathParam(c, "id") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - probe := h.restream.Probe(id) + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + probe := h.restream.Probe(tid) apiprobe := api.Probe{} apiprobe.Unmarshal(&probe) @@ -411,8 +540,19 @@ func (h *RestreamHandler) ReloadSkills(c echo.Context) error { func (h *RestreamHandler) GetProcessMetadata(c echo.Context) error { id := util.PathParam(c, "id") key := util.PathParam(c, "key") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") - data, err := 
h.restream.GetProcessMetadata(id, key) + if !h.iam.Enforce(user, domain, "process:"+id, "read") { + return api.Err(http.StatusForbidden, "Forbidden") + } + + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + data, err := h.restream.GetProcessMetadata(tid, key) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -437,6 +577,12 @@ func (h *RestreamHandler) GetProcessMetadata(c echo.Context) error { func (h *RestreamHandler) SetProcessMetadata(c echo.Context) error { id := util.PathParam(c, "id") key := util.PathParam(c, "key") + user := util.DefaultContext(c, "user", "") + domain := util.DefaultQuery(c, "domain", "") + + if !h.iam.Enforce(user, domain, "process:"+id, "write") { + return api.Err(http.StatusForbidden, "Forbidden") + } if len(key) == 0 { return api.Err(http.StatusBadRequest, "Invalid key", "The key must not be of length 0") @@ -448,7 +594,12 @@ func (h *RestreamHandler) SetProcessMetadata(c echo.Context) error { return api.Err(http.StatusBadRequest, "Invalid JSON", "%s", err) } - if err := h.restream.SetProcessMetadata(id, key, data); err != nil { + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + if err := h.restream.SetProcessMetadata(tid, key, data); err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } @@ -510,7 +661,7 @@ func (h *RestreamHandler) SetMetadata(c echo.Context) error { return c.JSON(http.StatusOK, data) } -func (h *RestreamHandler) getProcess(id, filterString string) (api.Process, error) { +func (h *RestreamHandler) getProcess(id restream.TaskID, user, filterString string) (api.Process, error) { filter := strings.FieldsFunc(filterString, func(r rune) bool { return r == rune(',') }) diff --git a/http/handler/api/restream_test.go b/http/handler/api/restream_test.go index 516db9ce..6cebaff0 100644 --- a/http/handler/api/restream_test.go +++ b/http/handler/api/restream_test.go @@ -3,14 +3,17 @@ package api import ( "bytes" "encoding/json" + "fmt" 
"net/http" "testing" "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/mock" - "github.com/stretchr/testify/require" + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/io/fs" "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" ) type Response struct { @@ -25,7 +28,29 @@ func getDummyRestreamHandler() (*RestreamHandler, error) { return nil, err } - handler := NewRestream(rs) + memfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + if err != nil { + return nil, fmt.Errorf("failed to create memory filesystem: %w", err) + } + + iam, err := iam.NewIAM(iam.Config{ + FS: memfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + if err != nil { + return nil, err + } + + iam.AddPolicy("$anon", "$none", "api:/**", []string{"ANY"}) + iam.AddPolicy("$anon", "$none", "fs:/**", []string{"ANY"}) + iam.AddPolicy("$anon", "$none", "process:**", []string{"ANY"}) + + handler := NewRestream(rs, iam) return handler, nil } diff --git a/http/handler/api/widget.go b/http/handler/api/widget.go index bb4688f2..af937135 100644 --- a/http/handler/api/widget.go +++ b/http/handler/api/widget.go @@ -43,17 +43,23 @@ func NewWidget(config WidgetConfig) *WidgetHandler { // @Router /api/v3/widget/process/{id} [get] func (w *WidgetHandler) Get(c echo.Context) error { id := util.PathParam(c, "id") + domain := util.DefaultQuery(c, "domain", "$none") if w.restream == nil { return api.Err(http.StatusNotFound, "Unknown process ID") } - process, err := w.restream.GetProcess(id) + tid := restream.TaskID{ + ID: id, + Domain: domain, + } + + process, err := w.restream.GetProcess(tid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } - state, err := w.restream.GetProcessState(id) + state, err := w.restream.GetProcessState(tid) if err != nil { return api.Err(http.StatusNotFound, "Unknown process ID", "%s", err) } diff --git a/http/handler/util/util.go 
b/http/handler/util/util.go index fbbcd2fd..d9b99b4b 100644 --- a/http/handler/util/util.go +++ b/http/handler/util/util.go @@ -76,3 +76,12 @@ func DefaultQuery(c echo.Context, name, defValue string) string { return param } + +func DefaultContext[T any](c echo.Context, name string, defValue T) T { + value, ok := c.Get(name).(T) + if !ok { + return defValue + } + + return value +} diff --git a/http/jwt/jwt.go b/http/jwt/jwt.go deleted file mode 100644 index cbfad3a7..00000000 --- a/http/jwt/jwt.go +++ /dev/null @@ -1,333 +0,0 @@ -package jwt - -import ( - "errors" - "fmt" - "net/http" - "sync" - "time" - - "github.com/datarhei/core/v16/app" - "github.com/datarhei/core/v16/http/api" - - jwtgo "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" -) - -// The Config type holds information that is required to create a new JWT provider -type Config struct { - Realm string - Secret string - SkipLocalhost bool -} - -// JWT provides access to a JWT provider -type JWT interface { - AddValidator(iss string, issuer Validator) error - - ClearValidators() - - Validators() []string - - // Middleware returns an echo middleware - AccessMiddleware() echo.MiddlewareFunc - RefreshMiddleware() echo.MiddlewareFunc - - // LoginHandler is an echo route handler for retrieving a JWT - LoginHandler(c echo.Context) error - - // RefreshHandle is an echo route handler for refreshing a JWT - RefreshHandler(c echo.Context) error -} - -type jwt struct { - realm string - skipLocalhost bool - secret []byte - accessValidFor time.Duration - accessConfig middleware.JWTConfig - accessMiddleware echo.MiddlewareFunc - refreshValidFor time.Duration - refreshConfig middleware.JWTConfig - refreshMiddleware echo.MiddlewareFunc - // Validators is a map of all recognized issuers to their specific validators. The key is the value of - // the "iss" field in the claims. Somewhat required because otherwise the token cannot be verified. 
- validators map[string]Validator - lock sync.RWMutex -} - -// New returns a new JWT provider -func New(config Config) (JWT, error) { - j := &jwt{ - realm: config.Realm, - skipLocalhost: config.SkipLocalhost, - secret: []byte(config.Secret), - accessValidFor: time.Minute * 10, - refreshValidFor: time.Hour * 24, - } - - if len(j.secret) == 0 { - return nil, fmt.Errorf("the JWT secret must not be empty") - } - - skipperFunc := func(c echo.Context) bool { - if j.skipLocalhost { - ip := c.RealIP() - - if ip == "127.0.0.1" || ip == "::1" { - return true - } - } - - return false - } - - j.accessConfig = middleware.JWTConfig{ - Skipper: skipperFunc, - SigningMethod: middleware.AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: jwtgo.MapClaims{}, - ErrorHandlerWithContext: j.ErrorHandler, - ParseTokenFunc: j.parseToken("access"), - } - - j.refreshConfig = middleware.JWTConfig{ - Skipper: skipperFunc, - SigningMethod: middleware.AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: jwtgo.MapClaims{}, - ErrorHandlerWithContext: j.ErrorHandler, - ParseTokenFunc: j.parseToken("refresh"), - } - - return j, nil -} - -func (j *jwt) parseToken(use string) func(auth string, c echo.Context) (interface{}, error) { - keyFunc := func(*jwtgo.Token) (interface{}, error) { return j.secret, nil } - - return func(auth string, c echo.Context) (interface{}, error) { - var token *jwtgo.Token - var err error - - token, err = jwtgo.Parse(auth, keyFunc) - if err != nil { - return nil, err - } - - if !token.Valid { - return nil, errors.New("invalid token") - } - - if _, ok := token.Claims.(jwtgo.MapClaims)["usefor"]; !ok { - return nil, fmt.Errorf("usefor claim is required") - } - - claimuse := token.Claims.(jwtgo.MapClaims)["usefor"].(string) - - if claimuse != use { - return nil, fmt.Errorf("invalid token claim") - } - - return token, nil - } -} - -func 
(j *jwt) Validators() []string { - j.lock.RLock() - defer j.lock.RUnlock() - - values := []string{} - - for _, v := range j.validators { - values = append(values, v.String()) - } - - return values -} - -func (j *jwt) AddValidator(iss string, issuer Validator) error { - j.lock.Lock() - defer j.lock.Unlock() - - if j.validators == nil { - j.validators = make(map[string]Validator) - } - - if _, ok := j.validators[iss]; ok { - return fmt.Errorf("a validator for %s is already registered", iss) - } - - j.validators[iss] = issuer - - return nil -} - -func (j *jwt) ClearValidators() { - j.lock.Lock() - defer j.lock.Unlock() - - if j.validators == nil { - return - } - - for _, v := range j.validators { - v.Cancel() - } - - j.validators = nil -} - -func (j *jwt) ErrorHandler(err error, c echo.Context) error { - if c.Request().URL.Path == "/api" { - return c.JSON(http.StatusOK, api.MinimalAbout{ - App: app.Name, - Auths: j.Validators(), - Version: api.VersionMinimal{ - Number: app.Version.MajorString(), - }, - }) - } - - return api.Err(http.StatusUnauthorized, "Missing or invalid JWT token") -} - -func (j *jwt) AccessMiddleware() echo.MiddlewareFunc { - if j.accessMiddleware == nil { - j.accessMiddleware = middleware.JWTWithConfig(j.accessConfig) - } - - return j.accessMiddleware -} - -func (j *jwt) RefreshMiddleware() echo.MiddlewareFunc { - if j.refreshMiddleware == nil { - j.refreshMiddleware = middleware.JWTWithConfig(j.refreshConfig) - } - - return j.refreshMiddleware -} - -// LoginHandler returns an access token and a refresh token -// @Summary Retrieve an access and a refresh token -// @Description Retrieve valid JWT access and refresh tokens to use for accessing the API. 
Login either by username/password or Auth0 token -// @ID jwt-login -// @Produce json -// @Param data body api.Login true "Login data" -// @Success 200 {object} api.JWT -// @Failure 400 {object} api.Error -// @Failure 403 {object} api.Error -// @Failure 500 {object} api.Error -// @Security Auth0KeyAuth -// @Router /api/login [post] -func (j *jwt) LoginHandler(c echo.Context) error { - var ok bool - var subject string - var err error - - j.lock.RLock() - for _, validator := range j.validators { - ok, subject, err = validator.Validate(c) - if ok { - break - } - } - j.lock.RUnlock() - - if ok { - if err != nil { - time.Sleep(5 * time.Second) - return api.Err(http.StatusUnauthorized, "Invalid authorization credentials", "%s", err) - } - } else { - time.Sleep(5 * time.Second) - return api.Err(http.StatusBadRequest, "Missing authorization credentials") - } - - at, rt, err := j.createToken(subject) - if err != nil { - return api.Err(http.StatusInternalServerError, "Failed to create JWT", "%s", err) - } - - return c.JSON(http.StatusOK, api.JWT{ - AccessToken: at, - RefreshToken: rt, - }) -} - -// RefreshHandler returns a new refresh token -// @Summary Retrieve a new access token -// @Description Retrieve a new access token by providing the refresh token -// @ID jwt-refresh -// @Produce json -// @Success 200 {object} api.JWTRefresh -// @Failure 500 {object} api.Error -// @Security ApiRefreshKeyAuth -// @Router /api/login/refresh [get] -func (j *jwt) RefreshHandler(c echo.Context) error { - token, ok := c.Get("user").(*jwtgo.Token) - if !ok { - return api.Err(http.StatusForbidden, "Invalid token") - } - - subject := token.Claims.(jwtgo.MapClaims)["sub"].(string) - - at, _, err := j.createToken(subject) - if err != nil { - return api.Err(http.StatusInternalServerError, "Failed to create JWT", "%s", err) - } - - return c.JSON(http.StatusOK, api.JWTRefresh{ - AccessToken: at, - }) -} - -// Already assigned claims: https://www.iana.org/assignments/jwt/jwt.xhtml - -func (j *jwt) 
createToken(username string) (string, string, error) { - now := time.Now() - accessExpires := now.Add(j.accessValidFor) - refreshExpires := now.Add(j.refreshValidFor) - - // Create access token - accessToken := jwtgo.NewWithClaims(jwtgo.SigningMethodHS256, jwtgo.MapClaims{ - "iss": j.realm, - "sub": username, - "usefor": "access", - "iat": now.Unix(), - "exp": accessExpires.Unix(), - "exi": uint64(accessExpires.Sub(now).Seconds()), - "jti": uuid.New().String(), - }) - - // Generate encoded access token - at, err := accessToken.SignedString(j.secret) - if err != nil { - return "", "", err - } - - // Create refresh token - refreshToken := jwtgo.NewWithClaims(jwtgo.SigningMethodHS256, jwtgo.MapClaims{ - "iss": j.realm, - "sub": username, - "usefor": "refresh", - "iat": now.Unix(), - "exp": refreshExpires.Unix(), - "exi": uint64(refreshExpires.Sub(now).Seconds()), - "jti": uuid.New().String(), - }) - - // Generate encoded refresh token - rt, err := refreshToken.SignedString(j.secret) - if err != nil { - return "", "", err - } - - return at, rt, nil -} diff --git a/http/jwt/validator.go b/http/jwt/validator.go deleted file mode 100644 index 309e1d7c..00000000 --- a/http/jwt/validator.go +++ /dev/null @@ -1,214 +0,0 @@ -package jwt - -import ( - "fmt" - "strings" - - "github.com/datarhei/core/v16/http/api" - "github.com/datarhei/core/v16/http/handler/util" - "github.com/datarhei/core/v16/http/jwt/jwks" - - jwtgo "github.com/golang-jwt/jwt/v4" - "github.com/labstack/echo/v4" -) - -type Validator interface { - String() string - - // Validate returns true if it identified itself as validator for - // that request. False if it doesn't handle this request. The string - // is the username. An error is only returned if it identified itself - // as validator but there was an error during validation. 
- Validate(c echo.Context) (bool, string, error) - Cancel() -} - -type localValidator struct { - username string - password string -} - -func NewLocalValidator(username, password string) (Validator, error) { - v := &localValidator{ - username: username, - password: password, - } - - return v, nil -} - -func (v *localValidator) String() string { - return "localjwt" -} - -func (v *localValidator) Validate(c echo.Context) (bool, string, error) { - var login api.Login - - if err := util.ShouldBindJSON(c, &login); err != nil { - return false, "", nil - } - - if login.Username != v.username || login.Password != v.password { - return true, "", fmt.Errorf("invalid username or password") - } - - return true, v.username, nil -} - -func (v *localValidator) Cancel() {} - -type auth0Validator struct { - domain string - issuer string - audience string - clientID string - users []string - certs jwks.JWKS -} - -func NewAuth0Validator(domain, audience, clientID string, users []string) (Validator, error) { - v := &auth0Validator{ - domain: domain, - issuer: "https://" + domain + "/", - audience: audience, - clientID: clientID, - users: users, - } - - url := v.issuer + ".well-known/jwks.json" - certs, err := jwks.NewFromURL(url, jwks.Config{}) - if err != nil { - return nil, err - } - - v.certs = certs - - return v, nil -} - -func (v auth0Validator) String() string { - return fmt.Sprintf("auth0 domain=%s audience=%s clientid=%s", v.domain, v.audience, v.clientID) -} - -func (v *auth0Validator) Validate(c echo.Context) (bool, string, error) { - // Look for an Auth header - values := c.Request().Header.Values("Authorization") - prefix := "Bearer " - - auth := "" - for _, value := range values { - if !strings.HasPrefix(value, prefix) { - continue - } - - auth = value[len(prefix):] - - break - } - - if len(auth) == 0 { - return false, "", nil - } - - p := &jwtgo.Parser{} - token, _, err := p.ParseUnverified(auth, jwtgo.MapClaims{}) - if err != nil { - return false, "", nil - } - - var 
issuer string - if claims, ok := token.Claims.(jwtgo.MapClaims); ok { - if iss, ok := claims["iss"]; ok { - issuer = iss.(string) - } - } - - if issuer != v.issuer { - return false, "", nil - } - - token, err = jwtgo.Parse(auth, v.keyFunc) - if err != nil { - return true, "", err - } - - if !token.Valid { - return true, "", fmt.Errorf("invalid token") - } - - var subject string - if claims, ok := token.Claims.(jwtgo.MapClaims); ok { - if sub, ok := claims["sub"]; ok { - subject = sub.(string) - } - } - - return true, subject, nil -} - -func (v *auth0Validator) keyFunc(token *jwtgo.Token) (interface{}, error) { - // Verify 'aud' claim - checkAud := token.Claims.(jwtgo.MapClaims).VerifyAudience(v.audience, false) - if !checkAud { - return nil, fmt.Errorf("invalid audience") - } - - // Verify 'iss' claim - checkIss := token.Claims.(jwtgo.MapClaims).VerifyIssuer(v.issuer, false) - if !checkIss { - return nil, fmt.Errorf("invalid issuer") - } - - // Verify 'sub' claim - if _, ok := token.Claims.(jwtgo.MapClaims)["sub"]; !ok { - return nil, fmt.Errorf("sub claim is required") - } - - sub := token.Claims.(jwtgo.MapClaims)["sub"].(string) - found := false - for _, u := range v.users { - if sub == u { - found = true - break - } - } - - if !found { - return nil, fmt.Errorf("user not allowed") - } - - // find the key - if _, ok := token.Header["kid"]; !ok { - return nil, fmt.Errorf("kid not found") - } - - kid := token.Header["kid"].(string) - - key, err := v.certs.Key(kid) - if err != nil { - return nil, fmt.Errorf("no cert for kid found: %w", err) - } - - // find algorithm - if _, ok := token.Header["alg"]; !ok { - return nil, fmt.Errorf("kid not found") - } - - alg := token.Header["alg"].(string) - - if key.Alg() != alg { - return nil, fmt.Errorf("signing method doesn't match") - } - - // get the public key - publicKey, err := key.PublicKey() - if err != nil { - return nil, fmt.Errorf("invalid public key: %w", err) - } - - return publicKey, nil -} - -func (v 
*auth0Validator) Cancel() { - v.certs.Cancel() -} diff --git a/http/middleware/iam/iam.go b/http/middleware/iam/iam.go new file mode 100644 index 00000000..8a1b8d75 --- /dev/null +++ b/http/middleware/iam/iam.go @@ -0,0 +1,466 @@ +// Package iam implements an identity and access management middleware +// +// Four pieces of information are required in order to decide to grant access. +// - identity +// - domain +// - resource +// - action +// +// The identity of the requester can be obtained by different means: +// - JWT +// - Username and password in the body as JSON +// - Auth0 access token +// - Basic auth +// +// The path prefix /api/login is treated specially in order to accommodate +// different ways of identification (UserPass, Auth0). All other /api paths +// only allow JWT as authentication method. +// +// If the identity can't be detected, the identity of "$anon" is given, representing +// an anonymous user. If the Skipper function returns true for the request and the +// API is accessed, the username will be the one of the IAM superuser. +// +// The domain is provided as query parameter "group" for all API requests or the +// first path element after a mountpoint for filesystem requests. +// +// If the domain can't be detected, the domain "$none" will be used. +// +// The resource is the path of the request. For API requests it's prepended with +// the "api:" prefix. For all other requests it's prepended with the "fs:" prefix. +// +// The action is the request's HTTP method (e.g. GET, POST, ...).
+package iam + +import ( + "encoding/base64" + "errors" + "fmt" + "net/http" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/datarhei/core/v16/http/api" + "github.com/datarhei/core/v16/http/handler/util" + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/log" + + jwtgo "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +type Config struct { + // Skipper defines a function to skip middleware. + Skipper middleware.Skipper + Mounts []string + IAM iam.IAM + WaitAfterFailedLogin bool + Logger log.Logger +} + +var DefaultConfig = Config{ + Skipper: middleware.DefaultSkipper, + Mounts: []string{}, + IAM: nil, + WaitAfterFailedLogin: false, + Logger: nil, +} + +var realm = "datarhei-core" + +type iammiddleware struct { + iam iam.IAM + mounts []string + logger log.Logger +} + +func New() echo.MiddlewareFunc { + return NewWithConfig(DefaultConfig) +} + +func NewWithConfig(config Config) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultConfig.Skipper + } + + if len(config.Mounts) == 0 { + config.Mounts = append(config.Mounts, "/") + } + + if config.Logger == nil { + config.Logger = log.New("") + } + + mw := iammiddleware{ + iam: config.IAM, + mounts: config.Mounts, + logger: config.Logger, + } + + // Sort the mounts from longest to shortest + sort.Slice(mw.mounts, func(i, j int) bool { + return len(mw.mounts[i]) > len(mw.mounts[j]) + }) + + mw.logger.Debug().WithField("mounts", mw.mounts).Log("") + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.IAM == nil { + return api.Err(http.StatusForbidden, "Forbidden", "IAM is not provided") + } + + isAPISuperuser := false + if config.Skipper(c) { + isAPISuperuser = true + } + + var identity iam.IdentityVerifier = nil + var err error + + username := "$anon" + resource := c.Request().URL.Path + var domain string + + if resource == "/api" || 
strings.HasPrefix(resource, "/api/") { + if resource == "/api/login" { + identity, err = mw.findIdentityFromUserpass(c) + if err != nil { + if config.WaitAfterFailedLogin { + time.Sleep(5 * time.Second) + } + return api.Err(http.StatusForbidden, "Forbidden", "%s", err) + } + + if identity == nil { + identity, err = mw.findIdentityFromAuth0(c) + if err != nil { + if config.WaitAfterFailedLogin { + time.Sleep(5 * time.Second) + } + return api.Err(http.StatusForbidden, "Forbidden", "%s", err) + } + } + } else { + identity, err = mw.findIdentityFromJWT(c) + if err != nil { + return api.Err(http.StatusForbidden, "Forbidden", "%s", err) + } + + if identity != nil { + if resource == "/api/login/refresh" { + usefor, _ := c.Get("usefor").(string) + if usefor != "refresh" { + if config.WaitAfterFailedLogin { + time.Sleep(5 * time.Second) + } + return api.Err(http.StatusForbidden, "Forbidden", "invalid token") + } + } else { + usefor, _ := c.Get("usefor").(string) + if usefor != "access" { + if config.WaitAfterFailedLogin { + time.Sleep(5 * time.Second) + } + return api.Err(http.StatusForbidden, "Forbidden", "invalid token") + } + } + } + } + + if isAPISuperuser { + username = config.IAM.GetDefaultVerifier().Name() + } + + domain = c.QueryParam("group") + resource = "api:" + resource + } else { + identity, err = mw.findIdentityFromBasicAuth(c) + if err != nil { + if err == ErrAuthRequired { + c.Response().Header().Set(echo.HeaderWWWAuthenticate, "Basic realm="+realm) + return api.Err(http.StatusUnauthorized, "Unauthorized", "%s", err) + } else { + if config.WaitAfterFailedLogin { + time.Sleep(5 * time.Second) + } + + if err == ErrBadRequest { + return api.Err(http.StatusBadRequest, "Bad request", "%s", err) + } else if err == ErrUnauthorized { + c.Response().Header().Set(echo.HeaderWWWAuthenticate, "Basic realm="+realm) + return api.Err(http.StatusUnauthorized, "Unauthorized", "%s", err) + } else { + return api.Err(http.StatusForbidden, "Forbidden", "%s", err) + } + } + } + + 
domain = mw.findDomainFromFilesystem(resource) + resource = "fs:" + resource + } + + superuser := false + + if identity != nil { + username = identity.Name() + superuser = identity.IsSuperuser() + } + + c.Set("user", username) + c.Set("superuser", superuser) + + if len(domain) == 0 { + domain = "$none" + } + + action := c.Request().Method + + if !config.IAM.Enforce(username, domain, resource, action) { + return api.Err(http.StatusForbidden, "Forbidden", "access denied") + } + + return next(c) + } + } +} + +var ErrAuthRequired = errors.New("unauthorized") +var ErrUnauthorized = errors.New("unauthorized") +var ErrBadRequest = errors.New("bad request") + +func (m *iammiddleware) findIdentityFromBasicAuth(c echo.Context) (iam.IdentityVerifier, error) { + basic := "basic" + auth := c.Request().Header.Get(echo.HeaderAuthorization) + l := len(basic) + + if len(auth) == 0 { + path := c.Request().URL.Path + domain := m.findDomainFromFilesystem(path) + if len(domain) == 0 { + domain = "$none" + } + + if !m.iam.Enforce("$anon", domain, "fs:"+path, c.Request().Method) { + return nil, ErrAuthRequired + } + + return nil, nil + } + + var username string + var password string + + if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input + b, err := base64.StdEncoding.DecodeString(auth[l+1:]) + if err != nil { + return nil, ErrBadRequest + } + + cred := string(b) + for i := 0; i < len(cred); i++ { + if cred[i] == ':' { + username, password = cred[:i], cred[i+1:] + break + } + } + } + + identity, err := m.iam.GetVerifier(username) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, ErrUnauthorized + } + + if ok, err := identity.VerifyServiceBasicAuth(password); !ok { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + 
"method": c.Request().Method, + }).WithError(err).Log("wrong password") + return nil, ErrUnauthorized + } + + return identity, nil +} + +func (m *iammiddleware) findIdentityFromJWT(c echo.Context) (iam.IdentityVerifier, error) { + // Look for an Auth header + values := c.Request().Header.Values("Authorization") + prefix := "Bearer " + + auth := "" + for _, value := range values { + if !strings.HasPrefix(value, prefix) { + continue + } + + auth = value[len(prefix):] + + break + } + + if len(auth) == 0 { + return nil, nil + } + + p := &jwtgo.Parser{} + token, _, err := p.ParseUnverified(auth, jwtgo.MapClaims{}) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, err + } + + var subject string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["sub"]; ok { + subject = sub.(string) + } + } + + var usefor string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["usefor"]; ok { + usefor = sub.(string) + } + } + + identity, err := m.iam.GetVerifier(subject) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid token") + } + + if ok, err := identity.VerifyJWT(auth); !ok { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid token") + } + + c.Set("usefor", usefor) + + return identity, nil +} + +func (m *iammiddleware) findIdentityFromUserpass(c echo.Context) (iam.IdentityVerifier, error) { + var login api.Login + + if err := util.ShouldBindJSON(c, &login); err != nil { + return nil, nil + } + + identity, err := m.iam.GetVerifier(login.Username) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + 
"path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid username or password") + } + + if ok, err := identity.VerifyAPIPassword(login.Password); !ok { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid username or password") + } + + return identity, nil +} + +func (m *iammiddleware) findIdentityFromAuth0(c echo.Context) (iam.IdentityVerifier, error) { + // Look for an Auth header + values := c.Request().Header.Values("Authorization") + prefix := "Bearer " + + auth := "" + for _, value := range values { + if !strings.HasPrefix(value, prefix) { + continue + } + + auth = value[len(prefix):] + + break + } + + if len(auth) == 0 { + return nil, nil + } + + p := &jwtgo.Parser{} + token, _, err := p.ParseUnverified(auth, jwtgo.MapClaims{}) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, nil + } + + var subject string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["sub"]; ok { + subject = sub.(string) + } + } + + identity, err := m.iam.GetVerfierFromAuth0(subject) + if err != nil { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid token") + } + + if ok, err := identity.VerifyAPIAuth0(auth); !ok { + m.logger.Debug().WithFields(log.Fields{ + "path": c.Request().URL.Path, + "method": c.Request().Method, + }).WithError(err).Log("identity not found") + return nil, fmt.Errorf("invalid token") + } + + return identity, nil +} + +func (m *iammiddleware) findDomainFromFilesystem(path string) string { + path = filepath.Clean(path) + + // Longest prefix search. 
The slice is assumed to be sorted accordingly. + // Assume path is /memfs/foobar/file.txt + // The longest prefix that matches is /memfs/ + // Remove it from the path and split it into components: foobar file.txt + // Check if foobar is a known domain. If yes, return it. If not, return empty domain. + for _, mount := range m.mounts { + prefix := filepath.Clean(mount) + if prefix != "/" { + prefix += "/" + } + + if strings.HasPrefix(path, prefix) { + elements := strings.Split(strings.TrimPrefix(path, prefix), "/") + if m.iam.HasDomain(elements[0]) { + return elements[0] + } + } + } + + return "" +} diff --git a/http/middleware/iam/iam_test.go b/http/middleware/iam/iam_test.go new file mode 100644 index 00000000..096c3803 --- /dev/null +++ b/http/middleware/iam/iam_test.go @@ -0,0 +1,366 @@ +package iam + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/datarhei/core/v16/http/api" + apihandler "github.com/datarhei/core/v16/http/handler/api" + "github.com/datarhei/core/v16/http/validator" + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/io/fs" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" +) + +var basic string = "Basic" + +func getIAM() (iam.IAM, error) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + if err != nil { + return nil, err + } + + i, err := iam.NewIAM(iam.Config{ + FS: dummyfs, + Superuser: iam.User{ + Name: "admin", + }, + JWTRealm: "datarhei-core", + JWTSecret: "1234567890", + Logger: nil, + }) + if err != nil { + return nil, err + } + + i.CreateIdentity(iam.User{ + Name: "foobar", + Auth: iam.UserAuth{ + API: iam.UserAuthAPI{ + Password: "secret", + }, + Services: iam.UserAuthServices{ + Basic: []string{"secret"}, + }, + }, + }) + + return i, nil + } + +func TestNoIAM(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, 
res) + h := New()(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err := h(c) + require.Error(t, err) + + he := err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) +} + +func TestBasicAuth(t *testing.T) { + iam, err := getIAM() + require.NoError(t, err) + + iam.AddPolicy("foobar", "$none", "fs:/**", []string{"ANY"}) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + h := NewWithConfig(Config{ + IAM: iam, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // No credentials + err = h(c) + require.Error(t, err) + + he := err.(api.Error) + require.Equal(t, http.StatusUnauthorized, he.Code) + require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate)) + + // Valid credentials + auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret")) + req.Header.Set(echo.HeaderAuthorization, auth) + require.NoError(t, h(c)) + + // Case-insensitive header scheme + auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret")) + req.Header.Set(echo.HeaderAuthorization, auth) + require.NoError(t, h(c)) + + // Invalid credentials + auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:invalid-password")) + req.Header.Set(echo.HeaderAuthorization, auth) + he = h(c).(api.Error) + require.Equal(t, http.StatusUnauthorized, he.Code) + require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate)) + + // Invalid base64 string + auth = basic + " invalidString" + req.Header.Set(echo.HeaderAuthorization, auth) + he = h(c).(api.Error) + require.Equal(t, http.StatusBadRequest, he.Code) + + // Missing Authorization header + req.Header.Del(echo.HeaderAuthorization) + he = h(c).(api.Error) + require.Equal(t, http.StatusUnauthorized, he.Code) + + // Invalid Authorization header + auth = 
base64.StdEncoding.EncodeToString([]byte("invalid")) + req.Header.Set(echo.HeaderAuthorization, auth) + he = h(c).(api.Error) + require.Equal(t, http.StatusUnauthorized, he.Code) +} + +func TestFindDomainFromFilesystem(t *testing.T) { + iam, err := getIAM() + require.NoError(t, err) + + iam.AddPolicy("$anon", "$none", "fs:/**", []string{"ANY"}) + iam.AddPolicy("foobar", "group", "fs:/group/**", []string{"ANY"}) + iam.AddPolicy("foobar", "anothergroup", "fs:/memfs/anothergroup/**", []string{"ANY"}) + + mw := &iammiddleware{ + iam: iam, + mounts: []string{"/", "/memfs"}, + } + + domain := mw.findDomainFromFilesystem("/") + require.Equal(t, "", domain) + + domain = mw.findDomainFromFilesystem("/group/bla") + require.Equal(t, "group", domain) + + domain = mw.findDomainFromFilesystem("/anothergroup/bla") + require.Equal(t, "anothergroup", domain) + + domain = mw.findDomainFromFilesystem("/memfs/anothergroup/bla") + require.Equal(t, "anothergroup", domain) +} + +func TestBasicAuthDomain(t *testing.T) { + iam, err := getIAM() + require.NoError(t, err) + + iam.AddPolicy("$anon", "$none", "fs:/**", []string{"ANY"}) + iam.AddPolicy("foobar", "group", "fs:/group/**", []string{"ANY"}) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + h := NewWithConfig(Config{ + IAM: iam, + Mounts: []string{"/"}, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // No credentials + require.NoError(t, h(c)) + + req = httptest.NewRequest(http.MethodGet, "/group/bla", nil) + c = e.NewContext(req, res) + + err = h(c) + require.Error(t, err) + + he := err.(api.Error) + require.Equal(t, http.StatusUnauthorized, he.Code) + require.Equal(t, basic+` realm=datarhei-core`, res.Header().Get(echo.HeaderWWWAuthenticate)) + + // Valid credentials + auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("foobar:secret")) + req.Header.Set(echo.HeaderAuthorization, auth) + 
require.NoError(t, h(c)) + + // Allow anonymous group read access + iam.AddPolicy("$anon", "group", "fs:/group/**", []string{"GET"}) + + req.Header.Del(echo.HeaderAuthorization) + require.NoError(t, h(c)) +} + +func TestAPILoginAndRefresh(t *testing.T) { + iam, err := getIAM() + require.NoError(t, err) + + iam.AddPolicy("foobar", "$none", "api:/**", []string{"ANY"}) + + jwthandler := apihandler.NewJWT(iam) + + e := echo.New() + e.Validator = validator.New() + res := httptest.NewRecorder() + h := NewWithConfig(Config{ + IAM: iam, + Mounts: []string{"/"}, + })(func(c echo.Context) error { + if c.Request().Method == http.MethodPost { + if c.Request().URL.Path == "/api/login" { + return jwthandler.Login(c) + } + } + + if c.Request().Method == http.MethodGet { + if c.Request().URL.Path == "/api/login/refresh" { + return jwthandler.Refresh(c) + } + } + + return c.String(http.StatusOK, "test") + }) + + req := httptest.NewRequest(http.MethodPost, "/api/login", nil) + c := e.NewContext(req, res) + + // No credentials + err = h(c) + require.Error(t, err) + + he := err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) + + // Wrong password + login := api.Login{ + Username: "foobar", + Password: "nosecret", + } + + data, err := json.Marshal(login) + require.NoError(t, err) + + req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + c = e.NewContext(req, res) + + err = h(c) + require.Error(t, err) + + he = err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) + + // Wrong username + login = api.Login{ + Username: "foobaz", + Password: "secret", + } + + data, err = json.Marshal(login) + require.NoError(t, err) + + req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + c = e.NewContext(req, res) + + err = h(c) + require.Error(t, err) + + he = err.(api.Error) + 
require.Equal(t, http.StatusForbidden, he.Code) + + // Correct credentials + login = api.Login{ + Username: "foobar", + Password: "secret", + } + + data, err = json.Marshal(login) + require.NoError(t, err) + + req = httptest.NewRequest(http.MethodPost, "/api/login", bytes.NewReader(data)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + c = e.NewContext(req, res) + res.Body.Reset() + + err = h(c) + require.NoError(t, err) + + data, err = io.ReadAll(res.Body) + require.NoError(t, err) + + jwt := api.JWT{} + err = json.Unmarshal(data, &jwt) + require.NoError(t, err) + + // No JWT + req = httptest.NewRequest(http.MethodGet, "/api/some/endpoint", nil) + c = e.NewContext(req, res) + + err = h(c) + require.Error(t, err) + + he = err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) + + // With invalid JWT + req.Header.Set(echo.HeaderAuthorization, "Bearer invalid") + err = h(c) + require.Error(t, err) + + // With refresh JWT + req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.RefreshToken) + err = h(c) + require.Error(t, err) + + // With access JWT + req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.AccessToken) + err = h(c) + require.NoError(t, err) + + // Refresh JWT + req = httptest.NewRequest(http.MethodGet, "/api/login/refresh", nil) + c = e.NewContext(req, res) + + err = h(c) + require.Error(t, err) + + he = err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) + + req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.AccessToken) + err = h(c) + require.Error(t, err) + + he = err.(api.Error) + require.Equal(t, http.StatusForbidden, he.Code) + + req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwt.RefreshToken) + res.Body.Reset() + err = h(c) + require.NoError(t, err) + + data, err = io.ReadAll(res.Body) + require.NoError(t, err) + + jwtrefresh := api.JWTRefresh{} + err = json.Unmarshal(data, &jwtrefresh) + require.NoError(t, err) + + req = httptest.NewRequest(http.MethodGet, "/api/some/endpoint", nil) + c = 
e.NewContext(req, res) + + // With new access JWT + req.Header.Set(echo.HeaderAuthorization, "Bearer "+jwtrefresh.AccessToken) + err = h(c) + require.NoError(t, err) +} diff --git a/http/mock/mock.go b/http/mock/mock.go index 621204a7..0ea2d1cb 100644 --- a/http/mock/mock.go +++ b/http/mock/mock.go @@ -16,10 +16,11 @@ import ( "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/errorhandler" "github.com/datarhei/core/v16/http/validator" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/internal/testhelper" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/restream" - "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/invopop/jsonschema" "github.com/labstack/echo/v4" @@ -38,7 +39,7 @@ func DummyRestreamer(pathPrefix string) (restream.Restreamer, error) { return nil, fmt.Errorf("failed to create memory filesystem: %w", err) } - store, err := store.NewJSON(store.JSONConfig{ + store, err := jsonstore.New(jsonstore.Config{ Filesystem: memfs, }) if err != nil { @@ -52,9 +53,23 @@ func DummyRestreamer(pathPrefix string) (restream.Restreamer, error) { return nil, err } + iam, err := iam.NewIAM(iam.Config{ + FS: memfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + if err != nil { + return nil, err + } + rs, err := restream.New(restream.Config{ Store: store, FFmpeg: ffmpeg, + IAM: iam, }) if err != nil { return nil, err diff --git a/http/server.go b/http/server.go index ae9c5ae3..943e553b 100644 --- a/http/server.go +++ b/http/server.go @@ -41,10 +41,10 @@ import ( "github.com/datarhei/core/v16/http/graph/resolver" "github.com/datarhei/core/v16/http/handler" api "github.com/datarhei/core/v16/http/handler/api" - "github.com/datarhei/core/v16/http/jwt" httplog "github.com/datarhei/core/v16/http/log" "github.com/datarhei/core/v16/http/router" "github.com/datarhei/core/v16/http/validator" + 
"github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/monitor" "github.com/datarhei/core/v16/net" @@ -58,6 +58,7 @@ import ( mwcors "github.com/datarhei/core/v16/http/middleware/cors" mwgzip "github.com/datarhei/core/v16/http/middleware/gzip" mwhlsrewrite "github.com/datarhei/core/v16/http/middleware/hlsrewrite" + mwiam "github.com/datarhei/core/v16/http/middleware/iam" mwiplimit "github.com/datarhei/core/v16/http/middleware/iplimit" mwlog "github.com/datarhei/core/v16/http/middleware/log" mwmime "github.com/datarhei/core/v16/http/middleware/mime" @@ -88,13 +89,14 @@ type Config struct { Cors CorsConfig RTMP rtmp.Server SRT srt.Server - JWT jwt.JWT Config cfgstore.Store Cache cache.Cacher Sessions session.RegistryReader Router router.Router ReadOnly bool Cluster cluster.Cluster + IAM iam.IAM + IAMSkipper func(ip string) bool } type CorsConfig struct { @@ -114,13 +116,12 @@ type server struct { profiling *handler.ProfilingHandler ping *handler.PingHandler graph *api.GraphHandler - jwt jwt.JWT + jwt *api.JWTHandler } v3handler struct { log *api.LogHandler restream *api.RestreamHandler - playout *api.PlayoutHandler rtmp *api.RTMPHandler srt *api.SRTHandler config *api.ConfigHandler @@ -128,17 +129,17 @@ type server struct { widget *api.WidgetHandler resources *api.MetricsHandler cluster *api.ClusterHandler + iam *api.IAMHandler } middleware struct { iplimit echo.MiddlewareFunc log echo.MiddlewareFunc - accessJWT echo.MiddlewareFunc - refreshJWT echo.MiddlewareFunc cors echo.MiddlewareFunc cache echo.MiddlewareFunc session echo.MiddlewareFunc hlsrewrite echo.MiddlewareFunc + iam echo.MiddlewareFunc } gzip struct { @@ -219,21 +220,34 @@ func NewServer(config Config) (Server, error) { } if config.Logger == nil { - s.logger = log.New("HTTP") + s.logger = log.New("") } - if config.JWT == nil { - s.handler.about = api.NewAbout( - config.Restream, - []string{}, - ) - } else { - s.handler.about = api.NewAbout( - config.Restream, - 
config.JWT.Validators(), - ) + mounts := []string{} + + for _, fs := range s.filesystems { + mounts = append(mounts, fs.FS.Mountpoint) } + s.middleware.iam = mwiam.NewWithConfig(mwiam.Config{ + Skipper: func(c echo.Context) bool { + return config.IAMSkipper(c.RealIP()) + }, + IAM: config.IAM, + Mounts: mounts, + WaitAfterFailedLogin: true, + Logger: s.logger.WithComponent("IAM"), + }) + + s.handler.about = api.NewAbout( + config.Restream, + func() []string { return config.IAM.Validators() }, + ) + + s.handler.jwt = api.NewJWT(config.IAM) + + s.v3handler.iam = api.NewIAM(config.IAM) + s.v3handler.log = api.NewLog( config.LogBuffer, ) @@ -241,10 +255,7 @@ func NewServer(config Config) (Server, error) { if config.Restream != nil { s.v3handler.restream = api.NewRestream( config.Restream, - ) - - s.v3handler.playout = api.NewPlayout( - config.Restream, + config.IAM, ) } @@ -284,12 +295,6 @@ func NewServer(config Config) (Server, error) { ) } - if config.JWT != nil { - s.handler.jwt = config.JWT - s.middleware.accessJWT = config.JWT.AccessMiddleware() - s.middleware.refreshJWT = config.JWT.RefreshMiddleware() - } - if config.Sessions == nil { config.Sessions, _ = session.New(session.Config{}) } @@ -331,6 +336,7 @@ func NewServer(config Config) (Server, error) { Restream: config.Restream, Monitor: config.Metrics, LogBuffer: config.LogBuffer, + IAM: config.IAM, }, "/api/graph/query") s.gzip.mimetypes = []string{ @@ -367,6 +373,8 @@ func NewServer(config Config) (Server, error) { s.router.Use(s.middleware.cors) } + s.router.Use(s.middleware.iam) + // Add static routes if path, target := config.Router.StaticRoute(); len(target) != 0 { group := s.router.Group(path) @@ -429,14 +437,9 @@ func (s *server) setRoutes() { api.Use(s.middleware.iplimit) } - if s.middleware.accessJWT != nil { - // Enable JWT auth - api.Use(s.middleware.accessJWT) - - // The login endpoint should not be blocked by auth - s.router.POST("/api/login", s.handler.jwt.LoginHandler) - 
s.router.GET("/api/login/refresh", s.handler.jwt.RefreshHandler, s.middleware.refreshJWT) - } + // The login endpoint should not be blocked by auth + s.router.POST("/api/login", s.handler.jwt.Login) + s.router.GET("/api/login/refresh", s.handler.jwt.Refresh) api.GET("", s.handler.about.About) @@ -488,23 +491,9 @@ func (s *server) setRoutes() { fs.HEAD("", filesystem.handler.GetFile) if filesystem.AllowWrite { - if filesystem.EnableAuth { - authmw := middleware.BasicAuth(func(username, password string, c echo.Context) (bool, error) { - if username == filesystem.Username && password == filesystem.Password { - return true, nil - } - - return false, nil - }) - - fs.POST("", filesystem.handler.PutFile, authmw) - fs.PUT("", filesystem.handler.PutFile, authmw) - fs.DELETE("", filesystem.handler.DeleteFile, authmw) - } else { - fs.POST("", filesystem.handler.PutFile) - fs.PUT("", filesystem.handler.PutFile) - fs.DELETE("", filesystem.handler.DeleteFile) - } + fs.POST("", filesystem.handler.PutFile) + fs.PUT("", filesystem.handler.PutFile) + fs.DELETE("", filesystem.handler.DeleteFile) } } @@ -543,10 +532,6 @@ func (s *server) setRoutes() { // APIv3 router group v3 := api.Group("/v3") - if s.handler.jwt != nil { - v3.Use(s.middleware.accessJWT) - } - v3.Use(gzipMiddleware) s.setRoutesV3(v3) @@ -558,6 +543,15 @@ func (s *server) setRoutesV3(v3 *echo.Group) { s.router.GET("/api/v3/widget/process/:id", s.v3handler.widget.Get) } + // v3 IAM + if s.v3handler.iam != nil { + v3.POST("/iam/user", s.v3handler.iam.AddUser) + v3.GET("/iam/user/:name", s.v3handler.iam.GetUser) + v3.PUT("/iam/user/:name", s.v3handler.iam.UpdateUser) + v3.PUT("/iam/user/:name/policy", s.v3handler.iam.UpdateUserPolicies) + v3.DELETE("/iam/user/:name", s.v3handler.iam.RemoveUser) + } + // v3 Restreamer if s.v3handler.restream != nil { v3.GET("/skills", s.v3handler.restream.Skills) @@ -587,18 +581,16 @@ func (s *server) setRoutesV3(v3 *echo.Group) { } // v3 Playout - if s.v3handler.playout != nil { - 
v3.GET("/process/:id/playout/:inputid/status", s.v3handler.playout.Status) - v3.GET("/process/:id/playout/:inputid/reopen", s.v3handler.playout.ReopenInput) - v3.GET("/process/:id/playout/:inputid/keyframe/*", s.v3handler.playout.Keyframe) - v3.GET("/process/:id/playout/:inputid/errorframe/encode", s.v3handler.playout.EncodeErrorframe) + v3.GET("/process/:id/playout/:inputid/status", s.v3handler.restream.PlayoutStatus) + v3.GET("/process/:id/playout/:inputid/reopen", s.v3handler.restream.PlayoutReopenInput) + v3.GET("/process/:id/playout/:inputid/keyframe/*", s.v3handler.restream.PlayoutKeyframe) + v3.GET("/process/:id/playout/:inputid/errorframe/encode", s.v3handler.restream.PlayoutEncodeErrorframe) - if !s.readOnly { - v3.PUT("/process/:id/playout/:inputid/errorframe/*", s.v3handler.playout.SetErrorframe) - v3.POST("/process/:id/playout/:inputid/errorframe/*", s.v3handler.playout.SetErrorframe) + if !s.readOnly { + v3.PUT("/process/:id/playout/:inputid/errorframe/*", s.v3handler.restream.PlayoutSetErrorframe) + v3.POST("/process/:id/playout/:inputid/errorframe/*", s.v3handler.restream.PlayoutSetErrorframe) - v3.PUT("/process/:id/playout/:inputid/stream", s.v3handler.playout.SetStream) - } + v3.PUT("/process/:id/playout/:inputid/stream", s.v3handler.restream.PlayoutSetStream) } } diff --git a/iam/access.go b/iam/access.go new file mode 100644 index 00000000..9f754504 --- /dev/null +++ b/iam/access.go @@ -0,0 +1,152 @@ +package iam + +import ( + "fmt" + "strings" + + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/log" + + "github.com/casbin/casbin/v2" + "github.com/casbin/casbin/v2/model" +) + +type Policy struct { + Name string + Domain string + Resource string + Actions []string +} + +type AccessEnforcer interface { + Enforce(name, domain, resource, action string) (bool, string) + + HasDomain(name string) bool + ListDomains() []string +} + +type AccessManager interface { + AccessEnforcer + + HasPolicy(name, domain, resource string, actions 
[]string) bool + AddPolicy(name, domain, resource string, actions []string) bool + RemovePolicy(name, domain, resource string, actions []string) bool + ListPolicies(name, domain, resource string, actions []string) []Policy +} + +type access struct { + fs fs.Filesystem + logger log.Logger + + adapter *adapter + enforcer *casbin.Enforcer +} + +type AccessConfig struct { + FS fs.Filesystem + Logger log.Logger +} + +func NewAccessManager(config AccessConfig) (AccessManager, error) { + am := &access{ + fs: config.FS, + logger: config.Logger, + } + + if am.fs == nil { + return nil, fmt.Errorf("a filesystem has to be provided") + } + + if am.logger == nil { + am.logger = log.New("") + } + + m := model.NewModel() + m.AddDef("r", "r", "sub, dom, obj, act") + m.AddDef("p", "p", "sub, dom, obj, act") + m.AddDef("g", "g", "_, _, _") + m.AddDef("e", "e", "some(where (p.eft == allow))") + m.AddDef("m", "m", `g(r.sub, p.sub, r.dom) && r.dom == p.dom && ResourceMatch(r.obj, r.dom, p.obj) && ActionMatch(r.act, p.act) || r.sub == "$superuser"`) + + a, err := newAdapter(am.fs, "./policy.json", am.logger) + if err != nil { + return nil, err + } + + e, err := casbin.NewEnforcer(m, a) + if err != nil { + return nil, err + } + + e.AddFunction("ResourceMatch", resourceMatchFunc) + e.AddFunction("ActionMatch", actionMatchFunc) + + am.enforcer = e + am.adapter = a + + return am, nil +} + +func (am *access) HasPolicy(name, domain, resource string, actions []string) bool { + policy := []string{name, domain, resource, strings.Join(actions, "|")} + + return am.enforcer.HasPolicy(policy) +} + +func (am *access) AddPolicy(name, domain, resource string, actions []string) bool { + policy := []string{name, domain, resource, strings.Join(actions, "|")} + + if am.enforcer.HasPolicy(policy) { + return true + } + + ok, _ := am.enforcer.AddPolicy(policy) + + return ok +} + +func (am *access) RemovePolicy(name, domain, resource string, actions []string) bool { + policies := 
am.enforcer.GetFilteredPolicy(0, name, domain, resource, strings.Join(actions, "|")) + am.enforcer.RemovePolicies(policies) + + return true +} + +func (am *access) ListPolicies(name, domain, resource string, actions []string) []Policy { + policies := []Policy{} + + ps := am.enforcer.GetFilteredPolicy(0, name, domain, resource, strings.Join(actions, "|")) + + for _, p := range ps { + policies = append(policies, Policy{ + Name: p[0], + Domain: p[1], + Resource: p[2], + Actions: strings.Split(p[3], "|"), + }) + } + + return policies +} + +func (am *access) HasDomain(name string) bool { + groups := am.adapter.getAllDomains() + + for _, g := range groups { + if g == name { + return true + } + } + + return false +} + +func (am *access) ListDomains() []string { + return am.adapter.getAllDomains() +} + +func (am *access) Enforce(name, domain, resource, action string) (bool, string) { + ok, rule, _ := am.enforcer.EnforceEx(name, domain, resource, action) + + return ok, strings.Join(rule, ", ") +} diff --git a/iam/access_test.go b/iam/access_test.go new file mode 100644 index 00000000..01983091 --- /dev/null +++ b/iam/access_test.go @@ -0,0 +1,85 @@ +package iam + +import ( + "testing" + + "github.com/datarhei/core/v16/io/fs" + "github.com/stretchr/testify/require" +) + +func TestAccessManager(t *testing.T) { + memfs, err := fs.NewMemFilesystemFromDir("./fixtures", fs.MemConfig{}) + require.NoError(t, err) + + am, err := NewAccessManager(AccessConfig{ + FS: memfs, + Logger: nil, + }) + require.NoError(t, err) + + policies := am.ListPolicies("", "", "", nil) + require.ElementsMatch(t, []Policy{ + { + Name: "ingo", + Domain: "$none", + Resource: "rtmp:/bla-*", + Actions: []string{"play", "publish"}, + }, + { + Name: "ingo", + Domain: "igelcamp", + Resource: "rtmp:/igelcamp/**", + Actions: []string{"publish"}, + }, + }, policies) + + am.AddPolicy("foobar", "group", "bla:/", []string{"write"}) + + policies = am.ListPolicies("", "", "", nil) + require.ElementsMatch(t, []Policy{ + 
{ + Name: "ingo", + Domain: "$none", + Resource: "rtmp:/bla-*", + Actions: []string{"play", "publish"}, + }, + { + Name: "ingo", + Domain: "igelcamp", + Resource: "rtmp:/igelcamp/**", + Actions: []string{"publish"}, + }, + { + Name: "foobar", + Domain: "group", + Resource: "bla:/", + Actions: []string{"write"}, + }, + }, policies) + + require.True(t, am.HasDomain("igelcamp")) + require.True(t, am.HasDomain("group")) + require.False(t, am.HasDomain("$none")) + + am.RemovePolicy("ingo", "", "", nil) + + policies = am.ListPolicies("", "", "", nil) + require.ElementsMatch(t, []Policy{ + { + Name: "foobar", + Domain: "group", + Resource: "bla:/", + Actions: []string{"write"}, + }, + }, policies) + + require.False(t, am.HasDomain("igelcamp")) + require.True(t, am.HasDomain("group")) + require.False(t, am.HasDomain("$none")) + + ok, _ := am.Enforce("foobar", "group", "bla:/", "read") + require.False(t, ok) + + ok, _ = am.Enforce("foobar", "group", "bla:/", "write") + require.True(t, ok) +} diff --git a/iam/adapter.go b/iam/adapter.go new file mode 100644 index 00000000..955ed4c2 --- /dev/null +++ b/iam/adapter.go @@ -0,0 +1,556 @@ +package iam + +import ( + "encoding/json" + "fmt" + "os" + "sort" + "strings" + "sync" + + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/log" + + "github.com/casbin/casbin/v2/model" +) + +// Adapter is the file adapter for Casbin. +// It can load policy from file or save policy to file. 
+type adapter struct { + fs fs.Filesystem + filePath string + logger log.Logger + domains []Domain + lock sync.Mutex +} + +func newAdapter(fs fs.Filesystem, filePath string, logger log.Logger) (*adapter, error) { + a := &adapter{ + fs: fs, + filePath: filePath, + logger: logger, + } + + if a.fs == nil { + return nil, fmt.Errorf("a filesystem has to be provided") + } + + if len(a.filePath) == 0 { + return nil, fmt.Errorf("invalid file path, file path cannot be empty") + } + + if a.logger == nil { + a.logger = log.New("") + } + + return a, nil +} + +// Adapter +func (a *adapter) LoadPolicy(model model.Model) error { + a.lock.Lock() + defer a.lock.Unlock() + + return a.loadPolicyFile(model) +} + +func (a *adapter) loadPolicyFile(model model.Model) error { + if _, err := a.fs.Stat(a.filePath); os.IsNotExist(err) { + a.domains = []Domain{} + return nil + } + + data, err := a.fs.ReadFile(a.filePath) + if err != nil { + return err + } + + domains := []Domain{} + + err = json.Unmarshal(data, &domains) + if err != nil { + return err + } + + rule := [5]string{} + for _, domain := range domains { + rule[0] = "p" + rule[2] = domain.Name + for name, roles := range domain.Roles { + rule[1] = "role:" + name + for _, role := range roles { + rule[3] = role.Resource + rule[4] = formatActions(role.Actions) + + if err := a.importPolicy(model, rule[0:5]); err != nil { + return err + } + } + } + + for _, policy := range domain.Policies { + rule[1] = policy.Username + rule[3] = policy.Resource + rule[4] = formatActions(policy.Actions) + + if err := a.importPolicy(model, rule[0:5]); err != nil { + return err + } + } + + rule[0] = "g" + rule[3] = domain.Name + + for _, ug := range domain.UserRoles { + rule[1] = ug.Username + rule[2] = "role:" + ug.Role + + if err := a.importPolicy(model, rule[0:4]); err != nil { + return err + } + } + } + + a.domains = domains + + return nil +} + +func (a *adapter) importPolicy(model model.Model, rule []string) error { + copiedRule := make([]string, 
len(rule)) + copy(copiedRule, rule) + + a.logger.Debug().WithFields(log.Fields{ + "subject": copiedRule[1], + "domain": copiedRule[2], + "resource": copiedRule[3], + "actions": copiedRule[4], + }).Log("Imported policy") + + ok, err := model.HasPolicyEx(copiedRule[0], copiedRule[0], copiedRule[1:]) + if err != nil { + return err + } + if ok { + return nil // skip duplicated policy + } + + model.AddPolicy(copiedRule[0], copiedRule[0], copiedRule[1:]) + + return nil +} + +// Adapter +func (a *adapter) SavePolicy(model model.Model) error { + a.lock.Lock() + defer a.lock.Unlock() + + return a.savePolicyFile() +} + +func (a *adapter) savePolicyFile() error { + jsondata, err := json.MarshalIndent(a.domains, "", " ") + if err != nil { + return err + } + + _, _, err = a.fs.WriteFileSafe(a.filePath, jsondata) + + return err +} + +// Adapter (auto-save) +func (a *adapter) AddPolicy(sec, ptype string, rule []string) error { + a.lock.Lock() + defer a.lock.Unlock() + + err := a.addPolicy(ptype, rule) + if err != nil { + return err + } + + return a.savePolicyFile() +} + +// BatchAdapter (auto-save) +func (a *adapter) AddPolicies(sec string, ptype string, rules [][]string) error { + a.lock.Lock() + defer a.lock.Unlock() + + for _, rule := range rules { + err := a.addPolicy(ptype, rule) + if err != nil { + return err + } + } + + return a.savePolicyFile() +} + +func (a *adapter) addPolicy(ptype string, rule []string) error { + ok, err := a.hasPolicy(ptype, rule) + if err != nil { + return err + } + + if ok { + // the policy is already there, nothing to add + return nil + } + + username := "" + role := "" + domain := "" + resource := "" + actions := "" + + if ptype == "p" { + username = rule[0] + domain = rule[1] + resource = rule[2] + actions = formatActions(rule[3]) + + a.logger.Debug().WithFields(log.Fields{ + "subject": username, + "domain": domain, + "resource": resource, + "actions": actions, + }).Log("Adding policy") + } else if ptype == "g" { + username = rule[0] + role = 
rule[1] + domain = rule[2] + + a.logger.Debug().WithFields(log.Fields{ + "subject": username, + "role": role, + "domain": domain, + }).Log("Adding role mapping") + } else { + return fmt.Errorf("unknown ptype: %s", ptype) + } + + var dom *Domain = nil + for i := range a.domains { + if a.domains[i].Name == domain { + dom = &a.domains[i] + break + } + } + + if dom == nil { + g := Domain{ + Name: domain, + Roles: map[string][]Role{}, + UserRoles: []MapUserRole{}, + Policies: []DomainPolicy{}, + } + + a.domains = append(a.domains, g) + dom = &a.domains[len(a.domains)-1] + } + + if ptype == "p" { + if strings.HasPrefix(username, "role:") { + if dom.Roles == nil { + dom.Roles = make(map[string][]Role) + } + + role := strings.TrimPrefix(username, "role:") + dom.Roles[role] = append(dom.Roles[role], Role{ + Resource: resource, + Actions: actions, + }) + } else { + dom.Policies = append(dom.Policies, DomainPolicy{ + Username: username, + Role: Role{ + Resource: resource, + Actions: actions, + }, + }) + } + } else { + dom.UserRoles = append(dom.UserRoles, MapUserRole{ + Username: username, + Role: strings.TrimPrefix(role, "role:"), + }) + } + + return nil +} + +func (a *adapter) hasPolicy(ptype string, rule []string) (bool, error) { + var username string + var role string + var domain string + var resource string + var actions string + + if ptype == "p" { + if len(rule) < 4 { + return false, fmt.Errorf("invalid rule length. must be 'user/role, domain, resource, actions'") + } + + username = rule[0] + domain = rule[1] + resource = rule[2] + actions = formatActions(rule[3]) + } else if ptype == "g" { + if len(rule) < 3 { + return false, fmt.Errorf("invalid rule length. 
must be 'user, role, domain'") + } + + username = rule[0] + role = rule[1] + domain = rule[2] + } else { + return false, fmt.Errorf("unknown ptype: %s", ptype) + } + + var dom *Domain = nil + for i := range a.domains { + if a.domains[i].Name == domain { + dom = &a.domains[i] + break + } + } + + if dom == nil { + // if we can't find any domain with that name, then the policy doesn't exist + return false, nil + } + + if ptype == "p" { + isRole := false + if strings.HasPrefix(username, "role:") { + isRole = true + username = strings.TrimPrefix(username, "role:") + } + + if isRole { + roles, ok := dom.Roles[username] + if !ok { + // unknown role, policy doesn't exist + return false, nil + } + + for _, role := range roles { + if role.Resource == resource && formatActions(role.Actions) == actions { + return true, nil + } + } + } else { + for _, p := range dom.Policies { + if p.Username == username && p.Resource == resource && formatActions(p.Actions) == actions { + return true, nil + } + } + } + } else { + role = strings.TrimPrefix(role, "role:") + for _, user := range dom.UserRoles { + if user.Username == username && user.Role == role { + return true, nil + } + } + } + + return false, nil +} + +// Adapter (auto-save) +func (a *adapter) RemovePolicy(sec string, ptype string, rule []string) error { + a.lock.Lock() + defer a.lock.Unlock() + + err := a.removePolicy(ptype, rule) + if err != nil { + return err + } + + return a.savePolicyFile() +} + +// BatchAdapter (auto-save) +func (a *adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { + a.lock.Lock() + defer a.lock.Unlock() + + for _, rule := range rules { + err := a.removePolicy(ptype, rule) + if err != nil { + return err + } + } + + return a.savePolicyFile() +} + +func (a *adapter) removePolicy(ptype string, rule []string) error { + ok, err := a.hasPolicy(ptype, rule) + if err != nil { + return err + } + + if !ok { + // the policy is not there, nothing to remove + return nil + } + + username := "" 
+ role := "" + domain := "" + resource := "" + actions := "" + + if ptype == "p" { + username = rule[0] + domain = rule[1] + resource = rule[2] + actions = formatActions(rule[3]) + + a.logger.Debug().WithFields(log.Fields{ + "subject": username, + "domain": domain, + "resource": resource, + "actions": actions, + }).Log("Removing policy") + } else if ptype == "g" { + username = rule[0] + role = rule[1] + domain = rule[2] + + a.logger.Debug().WithFields(log.Fields{ + "subject": username, + "role": role, + "domain": domain, + }).Log("Removing role mapping") + } else { + return fmt.Errorf("unknown ptype: %s", ptype) + } + + var dom *Domain = nil + for i := range a.domains { + if a.domains[i].Name == domain { + dom = &a.domains[i] + break + } + } + + if ptype == "p" { + isRole := false + if strings.HasPrefix(username, "role:") { + isRole = true + username = strings.TrimPrefix(username, "role:") + } + + if isRole { + roles := dom.Roles[username] + + newRoles := []Role{} + + for _, role := range roles { + if role.Resource == resource && formatActions(role.Actions) == actions { + continue + } + + newRoles = append(newRoles, role) + } + + dom.Roles[username] = newRoles + } else { + policies := []DomainPolicy{} + + for _, p := range dom.Policies { + if p.Username == username && p.Resource == resource && formatActions(p.Actions) == actions { + continue + } + + policies = append(policies, p) + } + + dom.Policies = policies + } + } else { + role = strings.TrimPrefix(role, "role:") + + users := []MapUserRole{} + + for _, user := range dom.UserRoles { + if user.Username == username && user.Role == role { + continue + } + + users = append(users, user) + } + + dom.UserRoles = users + } + + // Remove the group if there are no rules and policies + if len(dom.Roles) == 0 && len(dom.UserRoles) == 0 && len(dom.Policies) == 0 { + groups := []Domain{} + + for _, g := range a.domains { + if g.Name == dom.Name { + continue + } + + groups = append(groups, g) + } + + a.domains = groups + } + 
+ return nil +} + +// Adapter +func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return fmt.Errorf("not implemented") +} + +func (a *adapter) getAllDomains() []string { + names := []string{} + + for _, domain := range a.domains { + if domain.Name[0] == '$' { + continue + } + + names = append(names, domain.Name) + } + + return names +} + +type Domain struct { + Name string `json:"name"` + Roles map[string][]Role `json:"roles"` + UserRoles []MapUserRole `json:"userroles"` + Policies []DomainPolicy `json:"policies"` +} + +type Role struct { + Resource string `json:"resource"` + Actions string `json:"actions"` +} + +type MapUserRole struct { + Username string `json:"username"` + Role string `json:"role"` +} + +type DomainPolicy struct { + Username string `json:"username"` + Role +} + +func formatActions(actions string) string { + a := strings.Split(actions, "|") + + sort.Strings(a) + + return strings.Join(a, "|") +} diff --git a/iam/adapter_test.go b/iam/adapter_test.go new file mode 100644 index 00000000..41c613e7 --- /dev/null +++ b/iam/adapter_test.go @@ -0,0 +1,87 @@ +package iam + +import ( + "encoding/json" + "testing" + + "github.com/datarhei/core/v16/io/fs" + "github.com/stretchr/testify/require" +) + +func TestAddPolicy(t *testing.T) { + memfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + a, err := newAdapter(memfs, "/policy.json", nil) + require.NoError(t, err) + + err = a.AddPolicy("p", "p", []string{"foobar", "group", "resource", "action"}) + require.NoError(t, err) + + require.Equal(t, 1, len(a.domains)) + + data, err := memfs.ReadFile("/policy.json") + require.NoError(t, err) + + g := []Domain{} + err = json.Unmarshal(data, &g) + require.NoError(t, err) + + require.Equal(t, "group", g[0].Name) + require.Equal(t, 1, len(g[0].Policies)) + require.Equal(t, DomainPolicy{ + Username: "foobar", + Role: Role{ + Resource: "resource", + Actions: "action", + }, + }, 
g[0].Policies[0]) +} + +func TestFormatActions(t *testing.T) { + data := [][]string{ + {"a|b|c", "a|b|c"}, + {"b|c|a", "a|b|c"}, + } + + for _, d := range data { + require.Equal(t, d[1], formatActions(d[0]), d[0]) + } +} + +func TestRemovePolicy(t *testing.T) { + memfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + a, err := newAdapter(memfs, "/policy.json", nil) + require.NoError(t, err) + + err = a.AddPolicies("p", "p", [][]string{ + {"foobar1", "group", "resource1", "action1"}, + {"foobar2", "group", "resource2", "action2"}, + }) + require.NoError(t, err) + + require.Equal(t, 1, len(a.domains)) + require.Equal(t, 2, len(a.domains[0].Policies)) + + err = a.RemovePolicy("p", "p", []string{"foobar1", "group", "resource1", "action1"}) + require.NoError(t, err) + + require.Equal(t, 1, len(a.domains)) + require.Equal(t, 1, len(a.domains[0].Policies)) + + err = a.RemovePolicy("p", "p", []string{"foobar2", "group", "resource2", "action2"}) + require.NoError(t, err) + + require.Equal(t, 0, len(a.domains)) + + data, err := memfs.ReadFile("/policy.json") + require.NoError(t, err) + + g := []Domain{} + err = json.Unmarshal(data, &g) + require.NoError(t, err) + + require.Equal(t, 0, len(g)) +} diff --git a/iam/fixtures/policy.json b/iam/fixtures/policy.json new file mode 100644 index 00000000..306467ef --- /dev/null +++ b/iam/fixtures/policy.json @@ -0,0 +1,26 @@ +[ + { + "name": "$none", + "roles": {}, + "userroles": [], + "policies": [ + { + "username": "ingo", + "resource": "rtmp:/bla-*", + "actions": "play|publish" + } + ] + }, + { + "name": "igelcamp", + "roles": null, + "userroles": null, + "policies": [ + { + "username": "ingo", + "resource": "rtmp:/igelcamp/**", + "actions": "publish" + } + ] + } +] \ No newline at end of file diff --git a/iam/functions.go b/iam/functions.go new file mode 100644 index 00000000..41dce32d --- /dev/null +++ b/iam/functions.go @@ -0,0 +1,86 @@ +package iam + +import ( + "strings" + + "github.com/gobwas/glob" +) 
+ +func resourceMatch(request, domain, policy string) bool { + reqPrefix, reqResource := getPrefix(request) + polPrefix, polResource := getPrefix(policy) + + if reqPrefix != polPrefix { + return false + } + + var match bool + var err error + + if reqPrefix == "api" || reqPrefix == "fs" || reqPrefix == "rtmp" || reqPrefix == "srt" { + match, err = globMatch(polResource, reqResource, rune('/')) + if err != nil { + return false + } + } else { + match, err = globMatch(polResource, reqResource) + if err != nil { + return false + } + } + + return match +} + +func resourceMatchFunc(args ...interface{}) (interface{}, error) { + request := args[0].(string) + domain := args[1].(string) + policy := args[2].(string) + + return (bool)(resourceMatch(request, domain, policy)), nil +} + +func actionMatch(request string, policy string) bool { + request = strings.ToUpper(request) + actions := strings.Split(strings.ToUpper(policy), "|") + if len(actions) == 0 { + return false + } + + if len(actions) == 1 && actions[0] == "ANY" { + return true + } + + for _, a := range actions { + if request == a { + return true + } + } + + return false +} + +func actionMatchFunc(args ...interface{}) (interface{}, error) { + request := args[0].(string) + policy := args[1].(string) + + return (bool)(actionMatch(request, policy)), nil +} + +func getPrefix(s string) (string, string) { + prefix, resource, found := strings.Cut(s, ":") + if !found { + return "", s + } + + return prefix, resource +} + +func globMatch(pattern, name string, separators ...rune) (bool, error) { + g, err := glob.Compile(pattern, separators...) 
+ if err != nil { + return false, err + } + + return g.Match(name), nil +} diff --git a/iam/iam.go b/iam/iam.go new file mode 100644 index 00000000..e611f05b --- /dev/null +++ b/iam/iam.go @@ -0,0 +1,225 @@ +package iam + +import ( + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/log" +) + +type Enforcer interface { + Enforce(name, domain, resource, action string) bool +} + +type IAM interface { + Enforcer + + HasDomain(domain string) bool + ListDomains() []string + + HasPolicy(name, domain, resource string, actions []string) bool + AddPolicy(name, domain, resource string, actions []string) bool + RemovePolicy(name, domain, resource string, actions []string) bool + + ListPolicies(name, domain, resource string, actions []string) []Policy + + Validators() []string + + CreateIdentity(u User) error + GetIdentity(name string) (User, error) + UpdateIdentity(name string, u User) error + DeleteIdentity(name string) error + ListIdentities() []User + SaveIdentities() error + + GetVerifier(name string) (IdentityVerifier, error) + GetVerfierFromAuth0(name string) (IdentityVerifier, error) + GetDefaultVerifier() IdentityVerifier + + CreateJWT(name string) (string, string, error) + + Close() +} + +type iam struct { + im IdentityManager + am AccessManager + + logger log.Logger +} + +type Config struct { + FS fs.Filesystem + Superuser User + JWTRealm string + JWTSecret string + Logger log.Logger +} + +func NewIAM(config Config) (IAM, error) { + im, err := NewIdentityManager(IdentityConfig{ + FS: config.FS, + Superuser: config.Superuser, + JWTRealm: config.JWTRealm, + JWTSecret: config.JWTSecret, + Logger: config.Logger, + }) + if err != nil { + return nil, err + } + + am, err := NewAccessManager(AccessConfig{ + FS: config.FS, + Logger: config.Logger, + }) + if err != nil { + return nil, err + } + + iam := &iam{ + im: im, + am: am, + logger: config.Logger, + } + + if iam.logger == nil { + iam.logger = log.New("") + } + + return iam, nil +} + +func (i *iam) 
Close() { + i.im.Close() + i.im = nil + + i.am = nil +} + +func (i *iam) Enforce(name, domain, resource, action string) bool { + if len(name) == 0 { + name = "$anon" + } + + if len(domain) == 0 { + domain = "$none" + } + + superuser := false + + if identity, err := i.im.GetVerifier(name); err == nil { + if identity.IsSuperuser() { + superuser = true + } + } + + l := i.logger.Debug().WithFields(log.Fields{ + "subject": name, + "domain": domain, + "resource": resource, + "action": action, + "superuser": superuser, + }) + + if superuser { + name = "$superuser" + } + + ok, rule := i.am.Enforce(name, domain, resource, action) + + if !ok { + l.Log("no match") + } else { + if name == "$superuser" { + rule = "" + } + + l.WithField("rule", rule).Log("match") + } + + return ok +} + +func (i *iam) CreateIdentity(u User) error { + return i.im.Create(u) +} + +func (i *iam) GetIdentity(name string) (User, error) { + return i.im.Get(name) +} + +func (i *iam) UpdateIdentity(name string, u User) error { + return i.im.Update(name, u) +} + +func (i *iam) DeleteIdentity(name string) error { + return i.im.Delete(name) +} + +func (i *iam) ListIdentities() []User { + return nil +} + +func (i *iam) SaveIdentities() error { + return i.im.Save() +} + +func (i *iam) GetVerifier(name string) (IdentityVerifier, error) { + return i.im.GetVerifier(name) +} + +func (i *iam) GetVerfierFromAuth0(name string) (IdentityVerifier, error) { + return i.im.GetVerifierFromAuth0(name) +} + +func (i *iam) GetDefaultVerifier() IdentityVerifier { + v, _ := i.im.GetDefaultVerifier() + + return v +} + +func (i *iam) CreateJWT(name string) (string, string, error) { + return i.im.CreateJWT(name) +} + +func (i *iam) HasDomain(domain string) bool { + return i.am.HasDomain(domain) +} + +func (i *iam) ListDomains() []string { + return i.am.ListDomains() +} + +func (i *iam) Validators() []string { + return i.im.Validators() +} + +func (i *iam) HasPolicy(name, domain, resource string, actions []string) bool { + if 
len(name) == 0 { + name = "$anon" + } + + if len(domain) == 0 { + domain = "$none" + } + + return i.am.HasPolicy(name, domain, resource, actions) +} + +func (i *iam) AddPolicy(name, domain, resource string, actions []string) bool { + if len(name) == 0 { + name = "$anon" + } + + if len(domain) == 0 { + domain = "$none" + } + + return i.am.AddPolicy(name, domain, resource, actions) +} + +func (i *iam) RemovePolicy(name, domain, resource string, actions []string) bool { + return i.am.RemovePolicy(name, domain, resource, actions) +} + +func (i *iam) ListPolicies(name, domain, resource string, actions []string) []Policy { + return i.am.ListPolicies(name, domain, resource, actions) +} diff --git a/iam/identity.go b/iam/identity.go new file mode 100644 index 00000000..c732517a --- /dev/null +++ b/iam/identity.go @@ -0,0 +1,929 @@ +package iam + +import ( + "encoding/json" + "fmt" + "os" + "regexp" + "sync" + "time" + + "github.com/datarhei/core/v16/iam/jwks" + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/log" + "github.com/google/uuid" + + jwtgo "github.com/golang-jwt/jwt/v4" +) + +// Auth0 +// there needs to be a mapping from the Auth.User to Name +// the same Auth0.User can't have multiple identities +// the whole jwks will be part of this package + +type User struct { + Name string `json:"name"` + Superuser bool `json:"superuser"` + Auth UserAuth `json:"auth"` +} + +type UserAuth struct { + API UserAuthAPI `json:"api"` + Services UserAuthServices `json:"services"` +} + +type UserAuthAPI struct { + Password string `json:"password"` + Auth0 UserAuthAPIAuth0 `json:"auth0"` +} + +type UserAuthAPIAuth0 struct { + User string `json:"user"` + Tenant Auth0Tenant `json:"tenant"` +} + +type UserAuthServices struct { + Basic []string `json:"basic"` + Token []string `json:"token"` +} + +func (u *User) validate() error { + if len(u.Name) == 0 { + return fmt.Errorf("the name is required") + } + + chars := `A-Za-z0-9:_-` + + re := regexp.MustCompile(`[^` + 
chars + `]`) + if re.MatchString(u.Name) { + return fmt.Errorf("the name can only contain [%s]", chars) + } + + return nil +} + +func (u *User) marshalIdentity() *identity { + i := &identity{ + user: *u, + } + + return i +} + +func (u *User) clone() User { + user := *u + + user.Auth.Services.Token = make([]string, len(u.Auth.Services.Token)) + copy(user.Auth.Services.Token, u.Auth.Services.Token) + + return user +} + +type IdentityVerifier interface { + Name() string + + VerifyJWT(jwt string) (bool, error) + + VerifyAPIPassword(password string) (bool, error) + VerifyAPIAuth0(jwt string) (bool, error) + + VerifyServiceBasicAuth(password string) (bool, error) + VerifyServiceToken(token string) (bool, error) + + GetServiceBasicAuth() string + GetServiceToken() string + + IsSuperuser() bool +} + +type identity struct { + user User + + tenant *auth0Tenant + + jwtRealm string + jwtKeyFunc func(*jwtgo.Token) (interface{}, error) + + valid bool + + lock sync.RWMutex +} + +func (i *identity) Name() string { + return i.user.Name +} + +func (i *identity) VerifyAPIPassword(password string) (bool, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return false, fmt.Errorf("invalid identity") + } + + if len(i.user.Auth.API.Password) == 0 { + return false, fmt.Errorf("authentication method disabled") + } + + return i.user.Auth.API.Password == password, nil +} + +func (i *identity) VerifyAPIAuth0(jwt string) (bool, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return false, fmt.Errorf("invalid identity") + } + + if len(i.user.Auth.API.Auth0.User) == 0 { + return false, fmt.Errorf("authentication method disabled") + } + + p := &jwtgo.Parser{} + token, _, err := p.ParseUnverified(jwt, jwtgo.MapClaims{}) + if err != nil { + return false, err + } + + var subject string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["sub"]; ok { + subject = sub.(string) + } + } + + if subject != 
i.user.Auth.API.Auth0.User { + return false, fmt.Errorf("wrong subject") + } + + var issuer string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if iss, ok := claims["iss"]; ok { + issuer = iss.(string) + } + } + + if issuer != i.tenant.issuer { + return false, fmt.Errorf("wrong issuer") + } + + token, err = jwtgo.Parse(jwt, i.auth0KeyFunc) + if err != nil { + return false, err + } + + if !token.Valid { + return false, fmt.Errorf("invalid token") + } + + return true, nil +} + +func (i *identity) auth0KeyFunc(token *jwtgo.Token) (interface{}, error) { + // Verify 'aud' claim + checkAud := token.Claims.(jwtgo.MapClaims).VerifyAudience(i.tenant.audience, false) + if !checkAud { + return nil, fmt.Errorf("invalid audience") + } + + // Verify 'iss' claim + checkIss := token.Claims.(jwtgo.MapClaims).VerifyIssuer(i.tenant.issuer, false) + if !checkIss { + return nil, fmt.Errorf("invalid issuer") + } + + // Verify 'sub' claim + if _, ok := token.Claims.(jwtgo.MapClaims)["sub"]; !ok { + return nil, fmt.Errorf("sub claim is required") + } + + // find the key + if _, ok := token.Header["kid"]; !ok { + return nil, fmt.Errorf("kid not found") + } + + kid := token.Header["kid"].(string) + + key, err := i.tenant.certs.Key(kid) + if err != nil { + return nil, fmt.Errorf("no cert for kid found: %w", err) + } + + // find algorithm + if _, ok := token.Header["alg"]; !ok { + return nil, fmt.Errorf("kid not found") + } + + alg := token.Header["alg"].(string) + + if key.Alg() != alg { + return nil, fmt.Errorf("signing method doesn't match") + } + + // get the public key + publicKey, err := key.PublicKey() + if err != nil { + return nil, fmt.Errorf("invalid public key: %w", err) + } + + return publicKey, nil +} + +func (i *identity) VerifyJWT(jwt string) (bool, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return false, fmt.Errorf("invalid identity") + } + + p := &jwtgo.Parser{} + token, _, err := p.ParseUnverified(jwt, jwtgo.MapClaims{}) + if err 
!= nil { + return false, err + } + + var subject string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["sub"]; ok { + subject = sub.(string) + } + } + + if subject != i.user.Name { + return false, fmt.Errorf("wrong subject") + } + + var issuer string + if claims, ok := token.Claims.(jwtgo.MapClaims); ok { + if sub, ok := claims["iss"]; ok { + issuer = sub.(string) + } + } + + if issuer != i.jwtRealm { + return false, fmt.Errorf("wrong issuer") + } + + if token.Method.Alg() != "HS256" { + return false, fmt.Errorf("invalid hashing algorithm") + } + + token, err = jwtgo.Parse(jwt, i.jwtKeyFunc) + if err != nil { + return false, err + } + + if !token.Valid { + return false, fmt.Errorf("invalid token") + } + + return true, nil +} + +func (i *identity) VerifyServiceBasicAuth(password string) (bool, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return false, fmt.Errorf("invalid identity") + } + + for _, pw := range i.user.Auth.Services.Basic { + if len(pw) == 0 { + continue + } + + if pw == password { + return true, nil + } + } + + return false, nil +} + +func (i *identity) GetServiceBasicAuth() string { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return "" + } + + for _, password := range i.user.Auth.Services.Basic { + if len(password) == 0 { + continue + } + + return password + } + + return "" +} + +func (i *identity) VerifyServiceToken(token string) (bool, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return false, fmt.Errorf("invalid identity") + } + + for _, t := range i.user.Auth.Services.Token { + if t == token { + return true, nil + } + } + + return false, nil +} + +func (i *identity) GetServiceToken() string { + i.lock.RLock() + defer i.lock.RUnlock() + + if !i.isValid() { + return "" + } + + for _, token := range i.user.Auth.Services.Token { + if len(token) == 0 { + continue + } + + return i.Name() + ":" + token + } + + return "" +} + +func (i 
*identity) isValid() bool { + return i.valid +} + +func (i *identity) IsSuperuser() bool { + i.lock.RLock() + defer i.lock.RUnlock() + + return i.user.Superuser +} + +type IdentityManager interface { + Create(identity User) error + Update(name string, identity User) error + Delete(name string) error + + Get(name string) (User, error) + GetVerifier(name string) (IdentityVerifier, error) + GetVerifierFromAuth0(name string) (IdentityVerifier, error) + GetDefaultVerifier() (IdentityVerifier, error) + + Validators() []string + CreateJWT(name string) (string, string, error) + + Save() error + Autosave(bool) + Close() +} + +type identityManager struct { + root *identity + + identities map[string]*identity + tenants map[string]*auth0Tenant + + auth0UserIdentityMap map[string]string + + fs fs.Filesystem + filePath string + autosave bool + logger log.Logger + + jwtRealm string + jwtSecret []byte + + lock sync.RWMutex +} + +type IdentityConfig struct { + FS fs.Filesystem + Superuser User + JWTRealm string + JWTSecret string + Logger log.Logger +} + +func NewIdentityManager(config IdentityConfig) (IdentityManager, error) { + im := &identityManager{ + identities: map[string]*identity{}, + tenants: map[string]*auth0Tenant{}, + auth0UserIdentityMap: map[string]string{}, + fs: config.FS, + filePath: "./users.json", + jwtRealm: config.JWTRealm, + jwtSecret: []byte(config.JWTSecret), + logger: config.Logger, + } + + if im.logger == nil { + im.logger = log.New("") + } + + if im.fs == nil { + return nil, fmt.Errorf("no filesystem provided") + } + + err := im.load(im.filePath) + if err != nil { + return nil, err + } + + config.Superuser.Superuser = true + identity, err := im.create(config.Superuser) + if err != nil { + return nil, err + } + + im.root = identity + im.autosave = true + + return im, nil +} + +func (im *identityManager) Close() { + im.lock.Lock() + defer im.lock.Unlock() + + im.fs = nil + im.auth0UserIdentityMap = map[string]string{} + im.identities = 
map[string]*identity{} + im.root = nil + + for _, t := range im.tenants { + t.Cancel() + } + + im.tenants = map[string]*auth0Tenant{} +} + +func (im *identityManager) Create(u User) error { + if err := u.validate(); err != nil { + return err + } + + im.lock.Lock() + defer im.lock.Unlock() + + if im.root != nil && im.root.user.Name == u.Name { + return fmt.Errorf("identity already exists") + } + + _, ok := im.identities[u.Name] + if ok { + return fmt.Errorf("identity already exists") + } + + identity, err := im.create(u) + if err != nil { + return err + } + + im.identities[identity.user.Name] = identity + + if im.autosave { + im.save(im.filePath) + } + + return nil +} + +func (im *identityManager) create(u User) (*identity, error) { + u = u.clone() + identity := u.marshalIdentity() + + if len(identity.user.Auth.API.Auth0.User) != 0 { + if _, ok := im.auth0UserIdentityMap[identity.user.Auth.API.Auth0.User]; ok { + return nil, fmt.Errorf("the Auth0 user has already an identity") + } + + auth0Key := identity.user.Auth.API.Auth0.Tenant.key() + + if tenant, ok := im.tenants[auth0Key]; !ok { + tenant, err := newAuth0Tenant(identity.user.Auth.API.Auth0.Tenant) + if err != nil { + return nil, err + } + + im.tenants[auth0Key] = tenant + identity.tenant = tenant + } else { + tenant.AddClientID(identity.user.Auth.API.Auth0.Tenant.ClientID) + identity.tenant = tenant + } + + im.auth0UserIdentityMap[identity.user.Auth.API.Auth0.User] = u.Name + } + + identity.valid = true + + im.logger.Debug().WithField("name", identity.Name()).Log("Identity created") + + return identity, nil +} + +func (im *identityManager) Update(name string, u User) error { + if err := u.validate(); err != nil { + return err + } + + im.lock.Lock() + defer im.lock.Unlock() + + if im.root.user.Name == name { + return fmt.Errorf("this identity can't be updated") + } + + oldidentity, ok := im.identities[name] + if !ok { + return fmt.Errorf("not found") + } + + if name != u.Name { + _, err := 
im.getIdentity(u.Name) + if err == nil { + return fmt.Errorf("identity already exist") + } + } + + err := im.delete(name) + if err != nil { + return err + } + + identity, err := im.create(u) + if err != nil { + if identity, err := im.create(oldidentity.user); err != nil { + return err + } else { + im.identities[identity.user.Name] = identity + } + + return err + } + + im.identities[identity.user.Name] = identity + + im.logger.Debug().WithFields(log.Fields{ + "oldname": name, + "newname": identity.Name(), + }).Log("Identity updated") + + if im.autosave { + im.save(im.filePath) + } + + return nil +} + +func (im *identityManager) Delete(name string) error { + im.lock.Lock() + defer im.lock.Unlock() + + return im.delete(name) +} + +func (im *identityManager) delete(name string) error { + if im.root.user.Name == name { + return fmt.Errorf("this identity can't be removed") + } + + identity, ok := im.identities[name] + if !ok { + return fmt.Errorf("not found") + } + + delete(im.identities, name) + + identity.lock.Lock() + identity.valid = false + identity.lock.Unlock() + + if len(identity.user.Auth.API.Auth0.User) == 0 { + if im.autosave { + im.save(im.filePath) + } + + return nil + } + + delete(im.auth0UserIdentityMap, identity.user.Auth.API.Auth0.User) + + // find out if the tenant is still used somewhere else + found := false + for _, i := range im.identities { + if i.tenant == identity.tenant { + found = true + break + } + } + + if !found { + identity.tenant.Cancel() + delete(im.tenants, identity.user.Auth.API.Auth0.Tenant.key()) + + if im.autosave { + im.save(im.filePath) + } + + return nil + } + + // find out if the tenant's clientid is still used somewhere else + found = false + for _, i := range im.identities { + if len(i.user.Auth.API.Auth0.User) == 0 { + continue + } + + if i.user.Auth.API.Auth0.Tenant.ClientID == identity.user.Auth.API.Auth0.Tenant.ClientID { + found = true + break + } + } + + if !found { + 
identity.tenant.RemoveClientID(identity.user.Auth.API.Auth0.Tenant.ClientID) + } + + if im.autosave { + im.save(im.filePath) + } + + return nil +} + +func (im *identityManager) getIdentity(name string) (*identity, error) { + var identity *identity = nil + + if im.root.user.Name == name { + identity = im.root + } else { + identity = im.identities[name] + } + + if identity == nil { + return nil, fmt.Errorf("not found") + } + + identity.jwtRealm = im.jwtRealm + identity.jwtKeyFunc = func(*jwtgo.Token) (interface{}, error) { return im.jwtSecret, nil } + + return identity, nil +} + +func (im *identityManager) Get(name string) (User, error) { + im.lock.RLock() + defer im.lock.RUnlock() + + identity, err := im.getIdentity(name) + if err != nil { + return User{}, err + } + + user := identity.user.clone() + + return user, nil +} + +func (im *identityManager) GetVerifier(name string) (IdentityVerifier, error) { + im.lock.RLock() + defer im.lock.RUnlock() + + return im.getIdentity(name) +} + +func (im *identityManager) GetVerifierFromAuth0(name string) (IdentityVerifier, error) { + im.lock.RLock() + defer im.lock.RUnlock() + + name, ok := im.auth0UserIdentityMap[name] + if !ok { + return nil, fmt.Errorf("not found") + } + + return im.getIdentity(name) +} + +func (im *identityManager) GetDefaultVerifier() (IdentityVerifier, error) { + return im.root, nil +} + +func (im *identityManager) load(filePath string) error { + if _, err := im.fs.Stat(filePath); os.IsNotExist(err) { + return nil + } + + data, err := im.fs.ReadFile(filePath) + if err != nil { + return err + } + + users := []User{} + + err = json.Unmarshal(data, &users) + if err != nil { + return err + } + + for _, u := range users { + err = im.Create(u) + if err != nil { + return err + } + } + + return nil +} + +func (im *identityManager) Save() error { + im.lock.RLock() + defer im.lock.RUnlock() + + return im.save(im.filePath) +} + +func (im *identityManager) save(filePath string) error { + if filePath == "" { + return 
fmt.Errorf("invalid file path, file path cannot be empty") + } + + users := []User{} + + for _, u := range im.identities { + users = append(users, u.user) + } + + jsondata, err := json.MarshalIndent(users, "", " ") + if err != nil { + return err + } + + _, _, err = im.fs.WriteFileSafe(filePath, jsondata) + + im.logger.Debug().WithField("path", filePath).Log("Identity file save") + + return err +} + +func (im *identityManager) Autosave(auto bool) { + im.lock.Lock() + defer im.lock.Unlock() + + im.autosave = auto +} + +func (im *identityManager) Validators() []string { + validators := []string{"localjwt"} + + im.lock.RLock() + defer im.lock.RUnlock() + + for _, t := range im.tenants { + for _, clientid := range t.clientIDs { + validators = append(validators, fmt.Sprintf("auth0 domain=%s audience=%s clientid=%s", t.domain, t.audience, clientid)) + } + } + + return validators +} + +func (im *identityManager) CreateJWT(name string) (string, string, error) { + im.lock.RLock() + defer im.lock.RUnlock() + + identity, err := im.getIdentity(name) + if err != nil { + return "", "", err + } + + now := time.Now() + accessExpires := now.Add(time.Minute * 10) + refreshExpires := now.Add(time.Hour * 24) + + // Create access token + accessToken := jwtgo.NewWithClaims(jwtgo.SigningMethodHS256, jwtgo.MapClaims{ + "iss": im.jwtRealm, + "sub": identity.Name(), + "usefor": "access", + "iat": now.Unix(), + "exp": accessExpires.Unix(), + "exi": uint64(accessExpires.Sub(now).Seconds()), + "jti": uuid.New().String(), + }) + + // Generate encoded access token + at, err := accessToken.SignedString(im.jwtSecret) + if err != nil { + return "", "", err + } + + // Create refresh token + refreshToken := jwtgo.NewWithClaims(jwtgo.SigningMethodHS256, jwtgo.MapClaims{ + "iss": im.jwtRealm, + "sub": identity.Name(), + "usefor": "refresh", + "iat": now.Unix(), + "exp": refreshExpires.Unix(), + "exi": uint64(refreshExpires.Sub(now).Seconds()), + "jti": uuid.New().String(), + }) + + // Generate encoded 
refresh token + rt, err := refreshToken.SignedString(im.jwtSecret) + if err != nil { + return "", "", err + } + + return at, rt, nil +} + +type Auth0Tenant struct { + Domain string `json:"domain"` + Audience string `json:"audience"` + ClientID string `json:"client_id"` +} + +func (t *Auth0Tenant) key() string { + return t.Domain + t.Audience +} + +type auth0Tenant struct { + domain string + issuer string + audience string + clientIDs []string + certs jwks.JWKS + + lock sync.Mutex +} + +func newAuth0Tenant(tenant Auth0Tenant) (*auth0Tenant, error) { + t := &auth0Tenant{ + domain: tenant.Domain, + issuer: "https://" + tenant.Domain + "/", + audience: tenant.Audience, + clientIDs: []string{tenant.ClientID}, + certs: nil, + } + + url := t.issuer + ".well-known/jwks.json" + certs, err := jwks.NewFromURL(url, jwks.Config{}) + if err != nil { + return nil, err + } + + t.certs = certs + + return t, nil +} + +func (a *auth0Tenant) Cancel() { + a.certs.Cancel() +} + +func (a *auth0Tenant) AddClientID(clientid string) { + a.lock.Lock() + defer a.lock.Unlock() + + found := false + for _, id := range a.clientIDs { + if id == clientid { + found = true + break + } + } + + if found { + return + } + + a.clientIDs = append(a.clientIDs, clientid) +} + +func (a *auth0Tenant) RemoveClientID(clientid string) { + a.lock.Lock() + defer a.lock.Unlock() + + clientids := []string{} + + for _, id := range a.clientIDs { + if id == clientid { + continue + } + + clientids = append(clientids, id) + } + + a.clientIDs = clientids +} diff --git a/iam/identity_test.go b/iam/identity_test.go new file mode 100644 index 00000000..93c6d0cf --- /dev/null +++ b/iam/identity_test.go @@ -0,0 +1,769 @@ +package iam + +import ( + "testing" + + "github.com/datarhei/core/v16/io/fs" + "github.com/stretchr/testify/require" +) + +func TestUserName(t *testing.T) { + user := User{} + + err := user.validate() + require.Error(t, err) + + user.Name = "foobar_5" + err = user.validate() + require.NoError(t, err) + + 
user.Name = "$foob:ar" + err = user.validate() + require.Error(t, err) +} + +func TestIdentity(t *testing.T) { + user := User{ + Name: "foobar", + } + + identity := user.marshalIdentity() + + require.Equal(t, "foobar", identity.Name()) + + require.False(t, identity.isValid()) + identity.valid = true + require.True(t, identity.isValid()) + + require.False(t, identity.IsSuperuser()) + identity.user.Superuser = true + require.True(t, identity.IsSuperuser()) + + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + id, err := im.GetVerifier("unknown") + require.Error(t, err) + require.Nil(t, id) +} + +func TestDefaultIdentity(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + identity, err := im.GetDefaultVerifier() + require.NoError(t, err) + require.NotNil(t, identity) + require.Equal(t, "foobar", identity.Name()) +} + +func TestIdentityAPIAuth(t *testing.T) { + user := User{ + Name: "foobar", + } + + identity := user.marshalIdentity() + + ok, err := identity.VerifyAPIPassword("secret") + require.False(t, ok) + require.Error(t, err) + + identity.user.Auth.API.Password = "secret" + + ok, err = identity.VerifyAPIPassword("secret") + require.False(t, ok) + require.Error(t, err) + + identity.valid = true + + ok, err = identity.VerifyAPIPassword("secret") + require.True(t, ok) + require.NoError(t, err) + + identity.user.Auth.API.Password = "" + + ok, err = identity.VerifyAPIPassword("secret") + require.False(t, ok) + require.Error(t, err) + + 
identity.user.Auth.API.Password = "terces" + + ok, err = identity.VerifyAPIPassword("secret") + require.False(t, ok) + require.NoError(t, err) +} + +func TestIdentityServiceBasicAuth(t *testing.T) { + user := User{ + Name: "foobar", + } + + identity := user.marshalIdentity() + + ok, err := identity.VerifyServiceBasicAuth("secret") + require.False(t, ok) + require.Error(t, err) + + identity.user.Auth.Services.Basic = append(identity.user.Auth.Services.Basic, "secret") + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.False(t, ok) + require.Error(t, err) + + identity.valid = true + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.True(t, ok) + require.NoError(t, err) + + identity.user.Auth.Services.Basic[0] = "" + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.False(t, ok) + require.NoError(t, err) + + identity.user.Auth.Services.Basic[0] = "terces" + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.False(t, ok) + require.NoError(t, err) + + password := identity.GetServiceBasicAuth() + require.Equal(t, "terces", password) +} + +func TestIdentityServiceTokenAuth(t *testing.T) { + user := User{ + Name: "foobar", + } + + identity := user.marshalIdentity() + + ok, err := identity.VerifyServiceToken("secret") + require.False(t, ok) + require.Error(t, err) + + identity.user.Auth.Services.Token = []string{"secret"} + + ok, err = identity.VerifyServiceToken("secret") + require.False(t, ok) + require.Error(t, err) + + identity.valid = true + + ok, err = identity.VerifyServiceToken("secret") + require.True(t, ok) + require.NoError(t, err) + + identity.user.Auth.Services.Token = []string{"terces"} + + ok, err = identity.VerifyServiceToken("secret") + require.False(t, ok) + require.NoError(t, err) + + token := identity.GetServiceToken() + require.Equal(t, "foobar:terces", token) +} + +func TestJWT(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := 
NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + access, refresh, err := im.CreateJWT("foobaz") + require.Error(t, err) + require.Equal(t, "", access) + require.Equal(t, "", refresh) + + access, refresh, err = im.CreateJWT("foobar") + require.NoError(t, err) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + ok, err := identity.VerifyJWT("something") + require.Error(t, err) + require.False(t, ok) + + ok, err = identity.VerifyJWT(access) + require.NoError(t, err) + require.True(t, ok) + + ok, err = identity.VerifyJWT(refresh) + require.NoError(t, err) + require.True(t, ok) + + err = im.Create(User{Name: "foobaz"}) + require.NoError(t, err) + + access, refresh, err = im.CreateJWT("foobaz") + require.NoError(t, err) + + ok, err = identity.VerifyJWT(access) + require.Error(t, err) + require.False(t, ok) + + ok, err = identity.VerifyJWT(refresh) + require.Error(t, err) + require.False(t, ok) +} + +func TestCreateUser(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Create(User{Name: "foobar"}) + require.Error(t, err) + + err = im.Create(User{Name: "foobaz"}) + require.NoError(t, err) + + err = im.Create(User{Name: "foobaz"}) + require.Error(t, err) +} + +func TestCreateUserAuth0(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, 
im) + + require.ElementsMatch(t, []string{"localjwt"}, im.Validators()) + + err = im.Create(User{ + Name: "foobaz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|123456", + Tenant: Auth0Tenant{ + Domain: "example.com", + Audience: "https://api.example.com/", + ClientID: "123456", + }, + }, + }, + }, + }) + require.Error(t, err) + + err = im.Create(User{ + Name: "foobaz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|123456", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + ClientID: "123456", + }, + }, + }, + }, + }) + require.NoError(t, err) + + identity, err := im.GetVerifierFromAuth0("foobaz") + require.Error(t, err) + require.Nil(t, identity) + + identity, err = im.GetVerifierFromAuth0("auth0|123456") + require.NoError(t, err) + require.NotNil(t, identity) + + manager, ok := im.(*identityManager) + require.True(t, ok) + require.NotNil(t, manager) + + require.Equal(t, 1, len(manager.tenants)) + require.Equal(t, map[string]string{"auth0|123456": "foobaz"}, manager.auth0UserIdentityMap) + + require.ElementsMatch(t, []string{ + "localjwt", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=123456", + }, im.Validators()) + + err = im.Create(User{ + Name: "fooboz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|123456", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + ClientID: "123456", + }, + }, + }, + }, + }) + require.Error(t, err) + + err = im.Create(User{ + Name: "fooboz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|987654", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + 
ClientID: "987654", + }, + }, + }, + }, + }) + require.NoError(t, err) + + require.Equal(t, 1, len(manager.tenants)) + require.Equal(t, map[string]string{"auth0|123456": "foobaz", "auth0|987654": "fooboz"}, manager.auth0UserIdentityMap) + + require.ElementsMatch(t, []string{ + "localjwt", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=123456", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=987654", + }, im.Validators()) + + im.Close() +} + +func TestLoadAndSave(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Save() + require.NoError(t, err) + + _, err = dummyfs.Stat("./users.json") + require.NoError(t, err) + + data, err := dummyfs.ReadFile("./users.json") + require.NoError(t, err) + require.Equal(t, []byte("[]"), data) + + err = im.Create(User{Name: "foobaz"}) + require.NoError(t, err) + + identity, err := im.GetVerifier("foobaz") + require.NoError(t, err) + require.NotNil(t, identity) + + err = im.Save() + require.NoError(t, err) + + im, err = NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + identity, err = im.GetVerifier("foobaz") + require.NoError(t, err) + require.NotNil(t, identity) +} + +func TestUpdateUser(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) 
+ + err = im.Create(User{Name: "fooboz"}) + require.NoError(t, err) + + err = im.Update("unknown", User{Name: "fooboz"}) + require.Error(t, err) + + err = im.Update("foobar", User{Name: "foobar"}) + require.Error(t, err) + + err = im.Update("foobar", User{Name: "fooboz"}) + require.Error(t, err) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + require.Equal(t, "foobar", identity.Name()) + + err = im.Update("foobar", User{Name: "foobaz"}) + require.Error(t, err) + require.Equal(t, "foobar", identity.Name()) + + identity, err = im.GetVerifier("foobaz") + require.Error(t, err) + require.Nil(t, identity) + + identity, err = im.GetVerifier("fooboz") + require.NoError(t, err) + require.NotNil(t, identity) + require.Equal(t, "fooboz", identity.Name()) + + err = im.Update("fooboz", User{Name: "foobaz"}) + require.NoError(t, err) +} + +func TestUpdateUserAuth0(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Create(User{ + Name: "foobaz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|123456", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + ClientID: "123456", + }, + }, + }, + }, + }) + require.NoError(t, err) + + identity, err := im.GetVerifierFromAuth0("auth0|123456") + require.NoError(t, err) + require.NotNil(t, identity) + + identity, err = im.GetVerifier("foobaz") + require.NoError(t, err) + require.NotNil(t, identity) + + user, err := im.Get("foobaz") + require.NoError(t, err) + + user.Name = "fooboz" + + err = im.Update("foobaz", user) + require.NoError(t, err) + + identity, err = 
im.GetVerifierFromAuth0("auth0|123456") + require.NoError(t, err) + require.NotNil(t, identity) + + identity, err = im.GetVerifier("fooboz") + require.NoError(t, err) + require.NotNil(t, identity) +} + +func TestRemoveUser(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Delete("fooboz") + require.Error(t, err) + + err = im.Delete("foobar") + require.Error(t, err) + + err = im.Create(User{ + Name: "foobaz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Password: "apisecret", + Auth0: UserAuthAPIAuth0{}, + }, + Services: UserAuthServices{ + Basic: []string{"secret"}, + Token: []string{"tokensecret"}, + }, + }, + }) + require.NoError(t, err) + + identity, err := im.GetVerifier("foobaz") + require.NoError(t, err) + require.NotNil(t, identity) + + ok, err := identity.VerifyAPIPassword("apisecret") + require.True(t, ok) + require.NoError(t, err) + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.True(t, ok) + require.NoError(t, err) + + ok, err = identity.VerifyServiceToken("tokensecret") + require.True(t, ok) + require.NoError(t, err) + + access, refresh, err := im.CreateJWT("foobaz") + require.NoError(t, err) + + ok, err = identity.VerifyJWT(access) + require.True(t, ok) + require.NoError(t, err) + + ok, err = identity.VerifyJWT(refresh) + require.True(t, ok) + require.NoError(t, err) + + err = im.Delete("foobaz") + require.NoError(t, err) + + ok, err = identity.VerifyAPIPassword("apisecret") + require.False(t, ok) + require.Error(t, err) + + ok, err = identity.VerifyServiceBasicAuth("secret") + require.False(t, ok) + require.Error(t, err) + + ok, err = identity.VerifyServiceToken("tokensecret") + require.False(t, ok) + require.Error(t, err) + + ok, err = 
identity.VerifyJWT(access) + require.False(t, ok) + require.Error(t, err) + + ok, err = identity.VerifyJWT(refresh) + require.False(t, ok) + require.Error(t, err) + + identity, err = im.GetVerifier("foobaz") + require.Error(t, err) + require.Nil(t, identity) + + access, refresh, err = im.CreateJWT("foobaz") + require.Error(t, err) + require.Empty(t, access) + require.Empty(t, refresh) +} + +func TestRemoveUserAuth0(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Create(User{ + Name: "foobaz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|123456", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + ClientID: "123456", + }, + }, + }, + }, + }) + require.NoError(t, err) + + err = im.Create(User{ + Name: "fooboz", + Superuser: false, + Auth: UserAuth{ + API: UserAuthAPI{ + Auth0: UserAuthAPIAuth0{ + User: "auth0|987654", + Tenant: Auth0Tenant{ + Domain: "datarhei-demo.eu.auth0.com", + Audience: "https://datarhei-demo.eu.auth0.com/api/v2/", + ClientID: "987654", + }, + }, + }, + }, + }) + require.NoError(t, err) + + manager, ok := im.(*identityManager) + require.True(t, ok) + require.NotNil(t, manager) + + require.Equal(t, 1, len(manager.tenants)) + require.Equal(t, map[string]string{"auth0|123456": "foobaz", "auth0|987654": "fooboz"}, manager.auth0UserIdentityMap) + + require.ElementsMatch(t, []string{ + "localjwt", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=123456", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=987654", + }, im.Validators()) + + err = 
im.Delete("foobaz") + require.NoError(t, err) + + require.Equal(t, 1, len(manager.tenants)) + require.Equal(t, map[string]string{"auth0|987654": "fooboz"}, manager.auth0UserIdentityMap) + + require.ElementsMatch(t, []string{ + "localjwt", + "auth0 domain=datarhei-demo.eu.auth0.com audience=https://datarhei-demo.eu.auth0.com/api/v2/ clientid=987654", + }, im.Validators()) + + err = im.Delete("fooboz") + require.NoError(t, err) + + require.Equal(t, 0, len(manager.tenants)) + require.ElementsMatch(t, []string{ + "localjwt", + }, im.Validators()) +} + +func TestAutosave(t *testing.T) { + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + im, err := NewIdentityManager(IdentityConfig{ + FS: dummyfs, + Superuser: User{Name: "foobar"}, + JWTRealm: "test-realm", + JWTSecret: "abc123", + Logger: nil, + }) + require.NoError(t, err) + require.NotNil(t, im) + + err = im.Save() + require.NoError(t, err) + + _, err = dummyfs.Stat("./users.json") + require.NoError(t, err) + + data, err := dummyfs.ReadFile("./users.json") + require.NoError(t, err) + require.Equal(t, []byte("[]"), data) + + im.Autosave(true) + + err = im.Create(User{Name: "foobaz"}) + require.NoError(t, err) + + data, err = dummyfs.ReadFile("./users.json") + require.NoError(t, err) + require.NotEqual(t, []byte("[]"), data) + + user, err := im.Get("foobaz") + require.NoError(t, err) + + user.Name = "fooboz" + + err = im.Update("foobaz", user) + require.NoError(t, err) + + data, err = dummyfs.ReadFile("./users.json") + require.NoError(t, err) + require.NotEqual(t, []byte("[]"), data) + + err = im.Delete("fooboz") + require.NoError(t, err) + + data, err = dummyfs.ReadFile("./users.json") + require.NoError(t, err) + require.Equal(t, []byte("[]"), data) +} diff --git a/http/jwt/jwks/config.go b/iam/jwks/config.go similarity index 100% rename from http/jwt/jwks/config.go rename to iam/jwks/config.go diff --git a/http/jwt/jwks/doc.go b/iam/jwks/doc.go similarity index 100% rename from 
http/jwt/jwks/doc.go rename to iam/jwks/doc.go diff --git a/http/jwt/jwks/ecdsa.go b/iam/jwks/ecdsa.go similarity index 100% rename from http/jwt/jwks/ecdsa.go rename to iam/jwks/ecdsa.go diff --git a/http/jwt/jwks/jwks.go b/iam/jwks/jwks.go similarity index 100% rename from http/jwt/jwks/jwks.go rename to iam/jwks/jwks.go diff --git a/http/jwt/jwks/rsa.go b/iam/jwks/rsa.go similarity index 100% rename from http/jwt/jwks/rsa.go rename to iam/jwks/rsa.go diff --git a/io/fs/fixtures/a.txt b/io/fs/fixtures/a.txt new file mode 100644 index 00000000..ea3377dc --- /dev/null +++ b/io/fs/fixtures/a.txt @@ -0,0 +1 @@ +qwertz \ No newline at end of file diff --git a/io/fs/fixtures/b.txt b/io/fs/fixtures/b.txt new file mode 100644 index 00000000..ea3377dc --- /dev/null +++ b/io/fs/fixtures/b.txt @@ -0,0 +1 @@ +qwertz \ No newline at end of file diff --git a/io/fs/mem.go b/io/fs/mem.go index a75eb932..3f9681e7 100644 --- a/io/fs/mem.go +++ b/io/fs/mem.go @@ -156,6 +156,8 @@ func NewMemFilesystemFromDir(dir string, config MemConfig) (Filesystem, error) { return nil, err } + dir = filepath.Clean(dir) + err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { return nil @@ -181,7 +183,7 @@ func NewMemFilesystemFromDir(dir string, config MemConfig) (Filesystem, error) { defer file.Close() - _, _, err = mem.WriteFileReader(path, file) + _, _, err = mem.WriteFileReader(strings.TrimPrefix(path, dir), file) if err != nil { return fmt.Errorf("can't copy %s", path) } diff --git a/io/fs/mem_test.go b/io/fs/mem_test.go index d28a0d92..d8b3fa1f 100644 --- a/io/fs/mem_test.go +++ b/io/fs/mem_test.go @@ -7,24 +7,16 @@ import ( ) func TestMemFromDir(t *testing.T) { - mem, err := NewMemFilesystemFromDir(".", MemConfig{}) + mem, err := NewMemFilesystemFromDir("./fixtures", MemConfig{}) require.NoError(t, err) names := []string{} - for _, f := range mem.List("/", "/*.go") { + for _, f := range mem.List("/", "") { names = append(names, f.Name()) } 
require.ElementsMatch(t, []string{ - "/disk.go", - "/fs_test.go", - "/fs.go", - "/mem_test.go", - "/mem.go", - "/readonly_test.go", - "/readonly.go", - "/s3.go", - "/sized_test.go", - "/sized.go", + "/a.txt", + "/b.txt", }, names) } diff --git a/monitor/restream.go b/monitor/restream.go index c83e17ff..41d84745 100644 --- a/monitor/restream.go +++ b/monitor/restream.go @@ -57,7 +57,7 @@ func (c *restreamCollector) Collect() metric.Metrics { "starting": 0, } - ids := c.r.GetProcessIDs("", "") + ids := c.r.GetProcessIDs("", "", "", "") for _, id := range ids { state, _ := c.r.GetProcessState(id) @@ -72,31 +72,31 @@ func (c *restreamCollector) Collect() metric.Metrics { states[state.State]++ - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Frame), id, state.State, state.Order, "frame")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.FPS), id, state.State, state.Order, "fps")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Speed), id, state.State, state.Order, "speed")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Quantizer, id, state.State, state.Order, "q")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Size), id, state.State, state.Order, "size")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Time, id, state.State, state.Order, "time")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Drop), id, state.State, state.Order, "drop")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Dup), id, state.State, state.Order, "dup")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Packet), id, state.State, state.Order, "packet")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Bitrate, id, state.State, state.Order, "bitrate")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, state.CPU, id, state.State, 
state.Order, "cpu")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Memory), id, state.State, state.Order, "memory")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Duration, id, state.State, state.Order, "uptime")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Frame), id.String(), state.State, state.Order, "frame")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.FPS), id.String(), state.State, state.Order, "fps")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Speed), id.String(), state.State, state.Order, "speed")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Quantizer, id.String(), state.State, state.Order, "q")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Size), id.String(), state.State, state.Order, "size")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Time, id.String(), state.State, state.Order, "time")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Drop), id.String(), state.State, state.Order, "drop")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Dup), id.String(), state.State, state.Order, "dup")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Progress.Packet), id.String(), state.State, state.Order, "packet")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Progress.Bitrate, id.String(), state.State, state.Order, "bitrate")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, state.CPU, id.String(), state.State, state.Order, "cpu")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(state.Memory), id.String(), state.State, state.Order, "memory")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, state.Duration, id.String(), state.State, state.Order, "uptime")) if proc.Config != nil { - metrics.Add(metric.NewValue(c.restreamProcessDescr, 
proc.Config.LimitCPU, id, state.State, state.Order, "cpu_limit")) - metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(proc.Config.LimitMemory), id, state.State, state.Order, "memory_limit")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, proc.Config.LimitCPU, id.String(), state.State, state.Order, "cpu_limit")) + metrics.Add(metric.NewValue(c.restreamProcessDescr, float64(proc.Config.LimitMemory), id.String(), state.State, state.Order, "memory_limit")) } - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Failed), id, "failed")) - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Finished), id, "finished")) - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Finishing), id, "finishing")) - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Killed), id, "killed")) - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Running), id, "running")) - metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Starting), id, "starting")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Failed), id.String(), "failed")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Finished), id.String(), "finished")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Finishing), id.String(), "finishing")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Killed), id.String(), "killed")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Running), id.String(), "running")) + metrics.Add(metric.NewValue(c.restreamProcessStatesDescr, float64(state.States.Starting), id.String(), "starting")) for i := range state.Progress.Input { io := &state.Progress.Input[i] @@ -104,32 +104,32 @@ func (c *restreamCollector) Collect() metric.Metrics { index := 
strconv.FormatUint(io.Index, 10) stream := strconv.FormatUint(io.Stream, 10) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Frame), id, "input", io.ID, io.Address, index, stream, io.Type, "frame")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.FPS), id, "input", io.ID, io.Address, index, stream, io.Type, "fps")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Packet), id, "input", io.ID, io.Address, index, stream, io.Type, "packet")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.PPS), id, "input", io.ID, io.Address, index, stream, io.Type, "pps")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Size), id, "input", io.ID, io.Address, index, stream, io.Type, "size")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Bitrate), id, "input", io.ID, io.Address, index, stream, io.Type, "bitrate")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Frame), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "frame")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.FPS), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "fps")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Packet), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "packet")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.PPS), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "pps")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Size), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "size")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Bitrate), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "bitrate")) if io.AVstream != nil { a := io.AVstream - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Queue), id, "input", io.ID, io.Address, index, stream, io.Type, 
"avstream_queue")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Dup), id, "input", io.ID, io.Address, index, stream, io.Type, "avstream_dup")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Drop), id, "input", io.ID, io.Address, index, stream, io.Type, "avstream_drop")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Enc), id, "input", io.ID, io.Address, index, stream, io.Type, "avstream_enc")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Queue), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_queue")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Dup), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_dup")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Drop), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_drop")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(a.Enc), id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_enc")) value = 0 if a.Looping { value = 1 } - metrics.Add(metric.NewValue(c.restreamProcessIODescr, value, id, "input", io.ID, io.Address, index, stream, io.Type, "avstream_looping")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, value, id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_looping")) value = 0 if a.Duplicating { value = 1 } - metrics.Add(metric.NewValue(c.restreamProcessIODescr, value, id, "input", io.ID, io.Address, index, stream, io.Type, "avstream_duplicating")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, value, id.String(), "input", io.ID, io.Address, index, stream, io.Type, "avstream_duplicating")) } } @@ -139,13 +139,13 @@ func (c *restreamCollector) Collect() metric.Metrics { index := strconv.FormatUint(io.Index, 10) stream := strconv.FormatUint(io.Stream, 10) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Frame), id, "output", 
io.ID, io.Address, index, stream, io.Type, "frame")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.FPS), id, "output", io.ID, io.Address, index, stream, io.Type, "fps")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Packet), id, "output", io.ID, io.Address, index, stream, io.Type, "packet")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.PPS), id, "output", io.ID, io.Address, index, stream, io.Type, "pps")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Size), id, "output", io.ID, io.Address, index, stream, io.Type, "size")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Bitrate), id, "output", io.ID, io.Address, index, stream, io.Type, "bitrate")) - metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Quantizer), id, "output", io.ID, io.Address, index, stream, io.Type, "q")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Frame), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "frame")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.FPS), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "fps")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Packet), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "packet")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.PPS), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "pps")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Size), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "size")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Bitrate), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "bitrate")) + metrics.Add(metric.NewValue(c.restreamProcessIODescr, float64(io.Quantizer), id.String(), "output", io.ID, io.Address, index, stream, io.Type, "q")) } } diff --git a/restream/app/process.go 
b/restream/app/process.go index ab78d0f5..04336d46 100644 --- a/restream/app/process.go +++ b/restream/app/process.go @@ -5,17 +5,17 @@ import ( ) type ConfigIOCleanup struct { - Pattern string `json:"pattern"` - MaxFiles uint `json:"max_files"` - MaxFileAge uint `json:"max_file_age_seconds"` - PurgeOnDelete bool `json:"purge_on_delete"` + Pattern string + MaxFiles uint + MaxFileAge uint + PurgeOnDelete bool } type ConfigIO struct { - ID string `json:"id"` - Address string `json:"address"` - Options []string `json:"options"` - Cleanup []ConfigIOCleanup `json:"cleanup"` + ID string + Address string + Options []string + Cleanup []ConfigIOCleanup } func (io ConfigIO) Clone() ConfigIO { @@ -34,25 +34,29 @@ func (io ConfigIO) Clone() ConfigIO { } type Config struct { - ID string `json:"id"` - Reference string `json:"reference"` - FFVersion string `json:"ffversion"` - Input []ConfigIO `json:"input"` - Output []ConfigIO `json:"output"` - Options []string `json:"options"` - Reconnect bool `json:"reconnect"` - ReconnectDelay uint64 `json:"reconnect_delay_seconds"` // seconds - Autostart bool `json:"autostart"` - StaleTimeout uint64 `json:"stale_timeout_seconds"` // seconds - LimitCPU float64 `json:"limit_cpu_usage"` // percent 0-100 - LimitMemory uint64 `json:"limit_memory_bytes"` // bytes - LimitWaitFor uint64 `json:"limit_waitfor_seconds"` // seconds + ID string + Reference string + Owner string + Domain string + FFVersion string + Input []ConfigIO + Output []ConfigIO + Options []string + Reconnect bool + ReconnectDelay uint64 // seconds + Autostart bool + StaleTimeout uint64 // seconds + LimitCPU float64 // percent + LimitMemory uint64 // bytes + LimitWaitFor uint64 // seconds } func (config *Config) Clone() *Config { clone := &Config{ ID: config.ID, Reference: config.Reference, + Owner: config.Owner, + Domain: config.Domain, FFVersion: config.FFVersion, Reconnect: config.Reconnect, ReconnectDelay: config.ReconnectDelay, @@ -102,17 +106,21 @@ func (config *Config) 
CreateCommand() []string { } type Process struct { - ID string `json:"id"` - Reference string `json:"reference"` - Config *Config `json:"config"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - Order string `json:"order"` + ID string + Owner string + Domain string + Reference string + Config *Config + CreatedAt int64 + UpdatedAt int64 + Order string } func (process *Process) Clone() *Process { clone := &Process{ ID: process.ID, + Owner: process.Owner, + Domain: process.Domain, Reference: process.Reference, Config: process.Config.Clone(), CreatedAt: process.CreatedAt, diff --git a/restream/restream.go b/restream/restream.go index 7b252d2a..dca3274e 100644 --- a/restream/restream.go +++ b/restream/restream.go @@ -15,6 +15,7 @@ import ( "github.com/datarhei/core/v16/ffmpeg/parse" "github.com/datarhei/core/v16/ffmpeg/skills" "github.com/datarhei/core/v16/glob" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/net" @@ -23,38 +24,42 @@ import ( "github.com/datarhei/core/v16/restream/app" rfs "github.com/datarhei/core/v16/restream/fs" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" "github.com/datarhei/core/v16/restream/store" + jsonstore "github.com/datarhei/core/v16/restream/store/json" "github.com/Masterminds/semver/v3" ) // The Restreamer interface type Restreamer interface { - ID() string // ID of this instance - Name() string // Arbitrary name of this instance - CreatedAt() time.Time // Time of when this instance has been created - Start() // Start all processes that have a "start" order - Stop() // Stop all running process but keep their "start" order - AddProcess(config *app.Config) error // Add a new process - GetProcessIDs(idpattern, refpattern string) []string // Get a list of process IDs based on patterns for ID and reference - DeleteProcess(id string) error // Delete a process - 
UpdateProcess(id string, config *app.Config) error // Update a process - StartProcess(id string) error // Start a process - StopProcess(id string) error // Stop a process - RestartProcess(id string) error // Restart a process - ReloadProcess(id string) error // Reload a process - GetProcess(id string) (*app.Process, error) // Get a process - GetProcessState(id string) (*app.State, error) // Get the state of a process - GetProcessLog(id string) (*app.Log, error) // Get the logs of a process - GetPlayout(id, inputid string) (string, error) // Get the URL of the playout API for a process - Probe(id string) app.Probe // Probe a process - ProbeWithTimeout(id string, timeout time.Duration) app.Probe // Probe a process with specific timeout - Skills() skills.Skills // Get the ffmpeg skills - ReloadSkills() error // Reload the ffmpeg skills - SetProcessMetadata(id, key string, data interface{}) error // Set metatdata to a process - GetProcessMetadata(id, key string) (interface{}, error) // Get previously set metadata from a process - SetMetadata(key string, data interface{}) error // Set general metadata - GetMetadata(key string) (interface{}, error) // Get previously set general metadata + ID() string // ID of this instance + Name() string // Arbitrary name of this instance + CreatedAt() time.Time // Time of when this instance has been created + Start() // Start all processes that have a "start" order + Stop() // Stop all running process but keep their "start" order + + Skills() skills.Skills // Get the ffmpeg skills + ReloadSkills() error // Reload the ffmpeg skills + SetMetadata(key string, data interface{}) error // Set general metadata + GetMetadata(key string) (interface{}, error) // Get previously set general metadata + + AddProcess(config *app.Config) error // Add a new process + GetProcessIDs(idpattern, refpattern, ownerpattern, domainpattern string) []TaskID // Get a list of process IDs based on patterns for ID and reference + DeleteProcess(id TaskID) error // 
Delete a process + UpdateProcess(id TaskID, config *app.Config) error // Update a process + StartProcess(id TaskID) error // Start a process + StopProcess(id TaskID) error // Stop a process + RestartProcess(id TaskID) error // Restart a process + ReloadProcess(id TaskID) error // Reload a process + GetProcess(id TaskID) (*app.Process, error) // Get a process + GetProcessState(id TaskID) (*app.State, error) // Get the state of a process + GetProcessLog(id TaskID) (*app.Log, error) // Get the logs of a process + GetPlayout(id TaskID, inputid string) (string, error) // Get the URL of the playout API for a process + Probe(id TaskID) app.Probe // Probe a process + ProbeWithTimeout(id TaskID, timeout time.Duration) app.Probe // Probe a process with specific timeout + SetProcessMetadata(id TaskID, key string, data interface{}) error // Set metatdata to a process + GetProcessMetadata(id TaskID, key string) (interface{}, error) // Get previously set metadata from a process } // Config is the required configuration for a new restreamer instance. 
@@ -64,14 +69,18 @@ type Config struct { Store store.Store Filesystems []fs.Filesystem Replace replace.Replacer + Rewrite rewrite.Rewriter FFmpeg ffmpeg.FFmpeg MaxProcesses int64 Logger log.Logger + IAM iam.IAM } type task struct { valid bool id string // ID of the task/process + owner string + domain string reference string process *app.Process config *app.Config @@ -84,6 +93,34 @@ type task struct { metadata map[string]interface{} } +func (t *task) ID() TaskID { + return TaskID{ + ID: t.id, + Domain: t.domain, + } +} + +func (t *task) String() string { + return t.ID().String() +} + +type TaskID struct { + ID string + Domain string +} + +func (t TaskID) String() string { + return t.ID + "@" + t.Domain +} + +func (t TaskID) Equals(b TaskID) bool { + if t.ID == b.ID && t.Domain == b.Domain { + return true + } + + return false +} + type restream struct { id string name string @@ -98,14 +135,17 @@ type restream struct { stopObserver context.CancelFunc } replace replace.Replacer - tasks map[string]*task + rewrite rewrite.Rewriter + tasks map[TaskID]*task // domain:processid + metadata map[string]interface{} // global metadata logger log.Logger - metadata map[string]interface{} lock sync.RWMutex startOnce sync.Once stopOnce sync.Once + + iam iam.IAM } // New returns a new instance that implements the Restreamer interface @@ -116,16 +156,22 @@ func New(config Config) (Restreamer, error) { createdAt: time.Now(), store: config.Store, replace: config.Replace, + rewrite: config.Rewrite, logger: config.Logger, + iam: config.IAM, } if r.logger == nil { r.logger = log.New("") } + if r.iam == nil { + return nil, fmt.Errorf("missing IAM") + } + if r.store == nil { dummyfs, _ := fs.NewMemFilesystem(fs.MemConfig{}) - s, err := store.NewJSON(store.JSONConfig{ + s, err := jsonstore.New(jsonstore.Config{ Filesystem: dummyfs, }) if err != nil { @@ -204,8 +250,7 @@ func (r *restream) Stop() { r.lock.Lock() defer r.lock.Unlock() - // Stop the currently running processes without - // 
altering their order such that on a subsequent + // Stop the currently running processes without altering their order such that on a subsequent // Start() they will get restarted. for id, t := range r.tasks { if t.ffmpeg != nil { @@ -272,7 +317,7 @@ func (r *restream) load() error { return err } - tasks := make(map[string]*task) + tasks := make(map[TaskID]*task) skills := r.ffmpeg.Skills() ffversion := skills.FFmpeg.Version @@ -281,32 +326,33 @@ func (r *restream) load() error { ffversion = fmt.Sprintf("%d.%d.0", v.Major(), v.Minor()) } - for id, process := range data.Process { - if len(process.Config.FFVersion) == 0 { - process.Config.FFVersion = "^" + ffversion + for _, domain := range data.Process { + for _, p := range domain { + if len(p.Process.Config.FFVersion) == 0 { + p.Process.Config.FFVersion = "^" + ffversion + } + + t := &task{ + id: p.Process.ID, + owner: p.Process.Owner, + domain: p.Process.Domain, + reference: p.Process.Reference, + process: p.Process, + config: p.Process.Config.Clone(), + logger: r.logger.WithFields(log.Fields{ + "id": p.Process.ID, + "owner": p.Process.Owner, + "domain": p.Process.Domain, + }), + } + + t.metadata = p.Metadata + + // Replace all placeholders in the config + resolvePlaceholders(t.config, r.replace) + + tasks[t.ID()] = t } - - t := &task{ - id: id, - reference: process.Reference, - process: process, - config: process.Config.Clone(), - logger: r.logger.WithField("id", id), - } - - // Replace all placeholders in the config - resolvePlaceholders(t.config, r.replace) - - tasks[id] = t - } - - for id, userdata := range data.Metadata.Process { - t, ok := tasks[id] - if !ok { - continue - } - - t.metadata = userdata } // Now that all tasks are defined and all placeholders are @@ -317,39 +363,38 @@ func (r *restream) load() error { if c, err := semver.NewConstraint(t.config.FFVersion); err == nil { if v, err := semver.NewVersion(skills.FFmpeg.Version); err == nil { if !c.Check(v) { - r.logger.Warn().WithFields(log.Fields{ - 
"id": t.id, + t.logger.Warn().WithFields(log.Fields{ "constraint": t.config.FFVersion, "version": skills.FFmpeg.Version, }).WithError(fmt.Errorf("available FFmpeg version doesn't fit constraint; you have to update this process to adjust the constraint")).Log("") } } else { - r.logger.Warn().WithField("id", t.id).WithError(err).Log("") + t.logger.Warn().WithError(err).Log("") } } else { - r.logger.Warn().WithField("id", t.id).WithError(err).Log("") + t.logger.Warn().WithError(err).Log("") } err := r.resolveAddresses(tasks, t.config) if err != nil { - r.logger.Warn().WithField("id", t.id).WithError(err).Log("Ignoring") + t.logger.Warn().WithError(err).Log("Ignoring") continue } t.usesDisk, err = r.validateConfig(t.config) if err != nil { - r.logger.Warn().WithField("id", t.id).WithError(err).Log("Ignoring") + t.logger.Warn().WithError(err).Log("Ignoring") continue } err = r.setPlayoutPorts(t) if err != nil { - r.logger.Warn().WithField("id", t.id).WithError(err).Log("Ignoring") + t.logger.Warn().WithError(err).Log("Ignoring") continue } t.command = t.config.CreateCommand() - t.parser = r.ffmpeg.NewProcessParser(t.logger, t.id, t.reference) + t.parser = r.ffmpeg.NewProcessParser(t.logger, t.String(), t.reference) ffmpeg, err := r.ffmpeg.New(ffmpeg.ProcessConfig{ Reconnect: t.config.Reconnect, @@ -371,20 +416,30 @@ func (r *restream) load() error { } r.tasks = tasks - r.metadata = data.Metadata.System + r.metadata = data.Metadata return nil } func (r *restream) save() { - data := store.NewStoreData() + data := store.NewData() - for id, t := range r.tasks { - data.Process[id] = t.process - data.Metadata.System = r.metadata - data.Metadata.Process[id] = t.metadata + for tid, t := range r.tasks { + domain := data.Process[tid.Domain] + if domain == nil { + domain = map[string]store.Process{} + } + + domain[tid.ID] = store.Process{ + Process: t.process.Clone(), + Metadata: t.metadata, + } + + data.Process[tid.Domain] = domain } + data.Metadata = r.metadata + 
r.store.Store(data) } @@ -401,7 +456,9 @@ func (r *restream) CreatedAt() time.Time { } var ErrUnknownProcess = errors.New("unknown process") +var ErrUnknownProcessGroup = errors.New("unknown process group") var ErrProcessExists = errors.New("process already exists") +var ErrForbidden = errors.New("forbidden") func (r *restream) AddProcess(config *app.Config) error { r.lock.RLock() @@ -415,20 +472,22 @@ func (r *restream) AddProcess(config *app.Config) error { r.lock.Lock() defer r.lock.Unlock() - _, ok := r.tasks[t.id] + tid := t.ID() + + _, ok := r.tasks[tid] if ok { return ErrProcessExists } - r.tasks[t.id] = t + r.tasks[tid] = t // set filesystem cleanup rules - r.setCleanup(t.id, t.config) + r.setCleanup(tid, t.config) if t.process.Order == "start" { - err := r.startProcess(t.id) + err := r.startProcess(tid) if err != nil { - delete(r.tasks, t.id) + delete(r.tasks, tid) return err } } @@ -453,6 +512,7 @@ func (r *restream) createTask(config *app.Config) (*task, error) { process := &app.Process{ ID: config.ID, + Domain: config.Domain, Reference: config.Reference, Config: config.Clone(), Order: "stop", @@ -467,10 +527,14 @@ func (r *restream) createTask(config *app.Config) (*task, error) { t := &task{ id: config.ID, + domain: config.Domain, reference: process.Reference, process: process, config: process.Config.Clone(), - logger: r.logger.WithField("id", process.ID), + logger: r.logger.WithFields(log.Fields{ + "id": process.ID, + "group": process.Domain, + }), } resolvePlaceholders(t.config, r.replace) @@ -491,7 +555,7 @@ func (r *restream) createTask(config *app.Config) (*task, error) { } t.command = t.config.CreateCommand() - t.parser = r.ffmpeg.NewProcessParser(t.logger, t.id, t.reference) + t.parser = r.ffmpeg.NewProcessParser(t.logger, t.String(), t.reference) ffmpeg, err := r.ffmpeg.New(ffmpeg.ProcessConfig{ Reconnect: t.config.Reconnect, @@ -514,7 +578,7 @@ func (r *restream) createTask(config *app.Config) (*task, error) { return t, nil } -func (r 
*restream) setCleanup(id string, config *app.Config) { +func (r *restream) setCleanup(id TaskID, config *app.Config) { rePrefix := regexp.MustCompile(`^([a-z]+):`) for _, output := range config.Output { @@ -545,7 +609,7 @@ func (r *restream) setCleanup(id string, config *app.Config) { PurgeOnDelete: c.PurgeOnDelete, } - fs.SetCleanup(id, []rfs.Pattern{ + fs.SetCleanup(id.String(), []rfs.Pattern{ pattern, }) @@ -555,9 +619,9 @@ func (r *restream) setCleanup(id string, config *app.Config) { } } -func (r *restream) unsetCleanup(id string) { +func (r *restream) unsetCleanup(id TaskID) { for _, fs := range r.fs.list { - fs.UnsetCleanup(id) + fs.UnsetCleanup(id.String()) } } @@ -811,7 +875,7 @@ func (r *restream) validateOutputAddress(address, basedir string) (string, bool, return "file:" + address, true, nil } -func (r *restream) resolveAddresses(tasks map[string]*task, config *app.Config) error { +func (r *restream) resolveAddresses(tasks map[TaskID]*task, config *app.Config) error { for i, input := range config.Input { // Resolve any references address, err := r.resolveAddress(tasks, config.ID, input.Address) @@ -827,56 +891,159 @@ func (r *restream) resolveAddresses(tasks map[string]*task, config *app.Config) return nil } -func (r *restream) resolveAddress(tasks map[string]*task, id, address string) (string, error) { - re := regexp.MustCompile(`^#(.+):output=(.+)`) - - if len(address) == 0 { - return address, fmt.Errorf("empty address") +func (r *restream) resolveAddress(tasks map[TaskID]*task, id, address string) (string, error) { + matches, err := parseAddressReference(address) + if err != nil { + return address, err } - if address[0] != '#' { + // Address is not a reference + if _, ok := matches["address"]; ok { return address, nil } - matches := re.FindStringSubmatch(address) - if matches == nil { - return address, fmt.Errorf("invalid format (%s)", address) + if matches["id"] == id { + return address, fmt.Errorf("self-reference is not allowed (%s)", address) } - 
if matches[1] == id { - return address, fmt.Errorf("self-reference not possible (%s)", address) - } + var t *task = nil - task, ok := tasks[matches[1]] - if !ok { - return address, fmt.Errorf("unknown process '%s' (%s)", matches[1], address) - } - - for _, x := range task.config.Output { - if x.ID == matches[2] { - return x.Address, nil + for _, tsk := range tasks { + if tsk.id == matches["id"] && tsk.domain == matches["group"] { + t = tsk + break } } - return address, fmt.Errorf("the process '%s' has no outputs with the ID '%s' (%s)", matches[1], matches[2], address) + if t == nil { + return address, fmt.Errorf("unknown process '%s' in group '%s' (%s)", matches["id"], matches["group"], address) + } + + identity, _ := r.iam.GetVerifier(t.config.Owner) + + teeOptions := regexp.MustCompile(`^\[[^\]]*\]`) + + for _, x := range t.config.Output { + if x.ID != matches["output"] { + continue + } + + // Check for non-tee output + if !strings.Contains(x.Address, "|") && !strings.HasPrefix(x.Address, "[") { + return r.rewrite.RewriteAddress(x.Address, identity, rewrite.READ), nil + } + + // Split tee output in its individual addresses + + addresses := strings.Split(x.Address, "|") + if len(addresses) == 0 { + return x.Address, nil + } + + // Remove tee options + for i, a := range addresses { + addresses[i] = teeOptions.ReplaceAllString(a, "") + } + + if len(matches["source"]) == 0 { + return r.rewrite.RewriteAddress(addresses[0], identity, rewrite.READ), nil + } + + for _, a := range addresses { + u, err := url.Parse(a) + if err != nil { + // Ignore invalid addresses + continue + } + + if matches["source"] == "hls" { + if (u.Scheme == "http" || u.Scheme == "https") && strings.HasSuffix(u.RawPath, ".m3u8") { + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } else if matches["source"] == "rtmp" { + if u.Scheme == "rtmp" { + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } else if matches["source"] == "srt" { + if u.Scheme == "srt" 
{ + return r.rewrite.RewriteAddress(a, identity, rewrite.READ), nil + } + } + } + + // If none of the sources matched, return the first address + return r.rewrite.RewriteAddress(addresses[0], identity, rewrite.READ), nil + } + + return address, fmt.Errorf("the process '%s' in group '%s' has no outputs with the ID '%s' (%s)", matches["id"], matches["group"], matches["output"], address) } -func (r *restream) UpdateProcess(id string, config *app.Config) error { +func parseAddressReference(address string) (map[string]string, error) { + if len(address) == 0 { + return nil, fmt.Errorf("empty address") + } + + if address[0] != '#' { + return map[string]string{ + "address": address, + }, nil + } + + re := regexp.MustCompile(`:(output|group|source)=(.+)`) + + results := map[string]string{} + + idEnd := -1 + value := address + key := "" + + for { + matches := re.FindStringSubmatchIndex(value) + if matches == nil { + break + } + + if idEnd < 0 { + idEnd = matches[2] - 1 + } + + if len(key) != 0 { + results[key] = value[:matches[2]-1] + } + + key = value[matches[2]:matches[3]] + value = value[matches[4]:matches[5]] + + results[key] = value + } + + if idEnd < 0 { + return nil, fmt.Errorf("invalid format (%s)", address) + } + + results["id"] = address[1:idEnd] + + return results, nil +} + +func (r *restream) UpdateProcess(id TaskID, config *app.Config) error { r.lock.Lock() defer r.lock.Unlock() - t, err := r.createTask(config) - if err != nil { - return err - } - task, ok := r.tasks[id] if !ok { return ErrUnknownProcess } - if id != t.id { - _, ok := r.tasks[t.id] + t, err := r.createTask(config) + if err != nil { + return err + } + + tid := t.ID() + + if !tid.Equals(id) { + _, ok := r.tasks[tid] if ok { return ErrProcessExists } @@ -885,11 +1052,11 @@ func (r *restream) UpdateProcess(id string, config *app.Config) error { t.process.Order = task.process.Order if err := r.stopProcess(id); err != nil { - return err + return fmt.Errorf("stop process: %w", err) } if err := 
r.deleteProcess(id); err != nil { - return err + return fmt.Errorf("delete process: %w", err) } // This would require a major version jump @@ -897,13 +1064,13 @@ func (r *restream) UpdateProcess(id string, config *app.Config) error { t.process.UpdatedAt = time.Now().Unix() task.parser.TransferReportHistory(t.parser) - r.tasks[t.id] = t + r.tasks[tid] = t // set filesystem cleanup rules - r.setCleanup(t.id, t.config) + r.setCleanup(tid, t.config) if t.process.Order == "start" { - r.startProcess(t.id) + r.startProcess(tid) } r.save() @@ -911,73 +1078,79 @@ func (r *restream) UpdateProcess(id string, config *app.Config) error { return nil } -func (r *restream) GetProcessIDs(idpattern, refpattern string) []string { +func (r *restream) GetProcessIDs(idpattern, refpattern, ownerpattern, domainpattern string) []TaskID { r.lock.RLock() defer r.lock.RUnlock() - if len(idpattern) == 0 && len(refpattern) == 0 { - ids := make([]string, len(r.tasks)) - i := 0 + ids := []TaskID{} - for id := range r.tasks { - ids[i] = id - i++ - } - - return ids - } - - idmap := map[string]int{} - count := 0 - - if len(idpattern) != 0 { - for id := range r.tasks { - match, err := glob.Match(idpattern, id) + for _, t := range r.tasks { + count := 0 + matches := 0 + if len(idpattern) != 0 { + count++ + match, err := glob.Match(idpattern, t.id) if err != nil { return nil } - if !match { - continue + if match { + matches++ } - - idmap[id]++ } - count++ - } - - if len(refpattern) != 0 { - for _, t := range r.tasks { + if len(refpattern) != 0 { + count++ match, err := glob.Match(refpattern, t.reference) if err != nil { return nil } - if !match { - continue + if match { + matches++ } - - idmap[t.id]++ } - count++ - } + if len(ownerpattern) != 0 { + count++ + match, err := glob.Match(ownerpattern, t.owner) + if err != nil { + return nil + } - ids := []string{} + if match { + matches++ + } + } - for id, n := range idmap { - if n != count { + if len(domainpattern) != 0 { + count++ + match, err := 
glob.Match(domainpattern, t.domain) + if err != nil { + return nil + } + + if match { + matches++ + } + } + + if count != matches { continue } - ids = append(ids, id) + tid := TaskID{ + ID: t.id, + Domain: t.domain, + } + + ids = append(ids, tid) } return ids } -func (r *restream) GetProcess(id string) (*app.Process, error) { +func (r *restream) GetProcess(id TaskID) (*app.Process, error) { r.lock.RLock() defer r.lock.RUnlock() @@ -991,7 +1164,7 @@ func (r *restream) GetProcess(id string) (*app.Process, error) { return process, nil } -func (r *restream) DeleteProcess(id string) error { +func (r *restream) DeleteProcess(id TaskID) error { r.lock.Lock() defer r.lock.Unlock() @@ -1005,25 +1178,25 @@ func (r *restream) DeleteProcess(id string) error { return nil } -func (r *restream) deleteProcess(id string) error { - task, ok := r.tasks[id] +func (r *restream) deleteProcess(tid TaskID) error { + task, ok := r.tasks[tid] if !ok { return ErrUnknownProcess } if task.process.Order != "stop" { - return fmt.Errorf("the process with the ID '%s' is still running", id) + return fmt.Errorf("the process with the ID '%s' is still running", tid) } r.unsetPlayoutPorts(task) - r.unsetCleanup(id) + r.unsetCleanup(tid) - delete(r.tasks, id) + delete(r.tasks, tid) return nil } -func (r *restream) StartProcess(id string) error { +func (r *restream) StartProcess(id TaskID) error { r.lock.Lock() defer r.lock.Unlock() @@ -1037,8 +1210,8 @@ func (r *restream) StartProcess(id string) error { return nil } -func (r *restream) startProcess(id string) error { - task, ok := r.tasks[id] +func (r *restream) startProcess(tid TaskID) error { + task, ok := r.tasks[tid] if !ok { return ErrUnknownProcess } @@ -1066,7 +1239,7 @@ func (r *restream) startProcess(id string) error { return nil } -func (r *restream) StopProcess(id string) error { +func (r *restream) StopProcess(id TaskID) error { r.lock.Lock() defer r.lock.Unlock() @@ -1080,8 +1253,8 @@ func (r *restream) StopProcess(id string) error { return 
nil } -func (r *restream) stopProcess(id string) error { - task, ok := r.tasks[id] +func (r *restream) stopProcess(tid TaskID) error { + task, ok := r.tasks[tid] if !ok { return ErrUnknownProcess } @@ -1105,15 +1278,15 @@ func (r *restream) stopProcess(id string) error { return nil } -func (r *restream) RestartProcess(id string) error { +func (r *restream) RestartProcess(id TaskID) error { r.lock.RLock() defer r.lock.RUnlock() return r.restartProcess(id) } -func (r *restream) restartProcess(id string) error { - task, ok := r.tasks[id] +func (r *restream) restartProcess(tid TaskID) error { + task, ok := r.tasks[tid] if !ok { return ErrUnknownProcess } @@ -1131,7 +1304,7 @@ func (r *restream) restartProcess(id string) error { return nil } -func (r *restream) ReloadProcess(id string) error { +func (r *restream) ReloadProcess(id TaskID) error { r.lock.Lock() defer r.lock.Unlock() @@ -1145,8 +1318,8 @@ func (r *restream) ReloadProcess(id string) error { return nil } -func (r *restream) reloadProcess(id string) error { - t, ok := r.tasks[id] +func (r *restream) reloadProcess(tid TaskID) error { + t, ok := r.tasks[tid] if !ok { return ErrUnknownProcess } @@ -1177,10 +1350,10 @@ func (r *restream) reloadProcess(id string) error { order := "stop" if t.process.Order == "start" { order = "start" - r.stopProcess(id) + r.stopProcess(tid) } - t.parser = r.ffmpeg.NewProcessParser(t.logger, t.id, t.reference) + t.parser = r.ffmpeg.NewProcessParser(t.logger, t.String(), t.reference) ffmpeg, err := r.ffmpeg.New(ffmpeg.ProcessConfig{ Reconnect: t.config.Reconnect, @@ -1201,13 +1374,13 @@ func (r *restream) reloadProcess(id string) error { t.valid = true if order == "start" { - r.startProcess(id) + r.startProcess(tid) } return nil } -func (r *restream) GetProcessState(id string) (*app.State, error) { +func (r *restream) GetProcessState(id TaskID) (*app.State, error) { state := &app.State{} r.lock.RLock() @@ -1270,21 +1443,21 @@ func (r *restream) GetProcessState(id string) 
(*app.State, error) { return state, nil } -func (r *restream) GetProcessLog(id string) (*app.Log, error) { +func (r *restream) GetProcessLog(id TaskID) (*app.Log, error) { + log := &app.Log{} + r.lock.RLock() defer r.lock.RUnlock() task, ok := r.tasks[id] if !ok { - return &app.Log{}, ErrUnknownProcess + return log, ErrUnknownProcess } if !task.valid { - return &app.Log{}, nil + return log, nil } - log := &app.Log{} - current := task.parser.Report() log.CreatedAt = current.CreatedAt @@ -1319,15 +1492,15 @@ func (r *restream) GetProcessLog(id string) (*app.Log, error) { return log, nil } -func (r *restream) Probe(id string) app.Probe { +func (r *restream) Probe(id TaskID) app.Probe { return r.ProbeWithTimeout(id, 20*time.Second) } -func (r *restream) ProbeWithTimeout(id string, timeout time.Duration) app.Probe { - r.lock.RLock() - +func (r *restream) ProbeWithTimeout(id TaskID, timeout time.Duration) app.Probe { appprobe := app.Probe{} + r.lock.RLock() + task, ok := r.tasks[id] if !ok { appprobe.Log = append(appprobe.Log, fmt.Sprintf("Unknown process ID (%s)", id)) @@ -1392,7 +1565,7 @@ func (r *restream) ReloadSkills() error { return r.ffmpeg.ReloadSkills() } -func (r *restream) GetPlayout(id, inputid string) (string, error) { +func (r *restream) GetPlayout(id TaskID, inputid string) (string, error) { r.lock.RLock() defer r.lock.RUnlock() @@ -1415,14 +1588,14 @@ func (r *restream) GetPlayout(id, inputid string) (string, error) { var ErrMetadataKeyNotFound = errors.New("unknown key") -func (r *restream) SetProcessMetadata(id, key string, data interface{}) error { - r.lock.Lock() - defer r.lock.Unlock() - +func (r *restream) SetProcessMetadata(id TaskID, key string, data interface{}) error { if len(key) == 0 { return fmt.Errorf("a key for storing the data has to be provided") } + r.lock.Lock() + defer r.lock.Unlock() + task, ok := r.tasks[id] if !ok { return ErrUnknownProcess @@ -1447,7 +1620,7 @@ func (r *restream) SetProcessMetadata(id, key string, data 
interface{}) error { return nil } -func (r *restream) GetProcessMetadata(id, key string) (interface{}, error) { +func (r *restream) GetProcessMetadata(id TaskID, key string) (interface{}, error) { r.lock.RLock() defer r.lock.RUnlock() @@ -1516,7 +1689,9 @@ func (r *restream) GetMetadata(key string) (interface{}, error) { func resolvePlaceholders(config *app.Config, r replace.Replacer) { vars := map[string]string{ "processid": config.ID, + "owner": config.Owner, "reference": config.Reference, + "group": config.Domain, } for i, option := range config.Options { diff --git a/restream/restream_test.go b/restream/restream_test.go index 64a536a5..6344ca81 100644 --- a/restream/restream_test.go +++ b/restream/restream_test.go @@ -6,10 +6,13 @@ import ( "time" "github.com/datarhei/core/v16/ffmpeg" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/internal/testhelper" + "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/net" "github.com/datarhei/core/v16/restream/app" "github.com/datarhei/core/v16/restream/replace" + "github.com/datarhei/core/v16/restream/rewrite" "github.com/stretchr/testify/require" ) @@ -31,9 +34,36 @@ func getDummyRestreamer(portrange net.Portranger, validatorIn, validatorOut ffmp return nil, err } + dummyfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + if err != nil { + return nil, err + } + + iam, err := iam.NewIAM(iam.Config{ + FS: dummyfs, + Superuser: iam.User{ + Name: "foobar", + }, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + if err != nil { + return nil, err + } + + iam.AddPolicy("$anon", "$none", "process:*", []string{"CREATE", "GET", "DELETE", "UPDATE", "COMMAND", "PROBE", "METADATA", "PLAYOUT"}) + + rewriter, err := rewrite.New(rewrite.Config{}) + if err != nil { + return nil, err + } + rs, err := New(Config{ FFmpeg: ffmpeg, Replace: replacer, + Rewrite: rewriter, + IAM: iam, }) if err != nil { return nil, err @@ -86,16 +116,18 @@ func TestAddProcess(t *testing.T) { process := getDummyProcess() 
require.NotNil(t, process) - _, err = rs.GetProcess(process.ID) - require.NotEqual(t, nil, err, "Unset process found (%s)", process.ID) + tid := TaskID{ID: process.ID} + + _, err = rs.GetProcess(tid) + require.Equal(t, ErrUnknownProcess, err) err = rs.AddProcess(process) require.Equal(t, nil, err, "Failed to add process (%s)", err) - _, err = rs.GetProcess(process.ID) + _, err = rs.GetProcess(tid) require.Equal(t, nil, err, "Set process not found (%s)", process.ID) - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "stop", state.Order, "Process should be stopped") } @@ -106,12 +138,14 @@ func TestAutostartProcess(t *testing.T) { process := getDummyProcess() process.Autostart = true + tid := TaskID{ID: process.ID} + rs.AddProcess(process) - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - rs.StopProcess(process.ID) + rs.StopProcess(tid) } func TestAddInvalidProcess(t *testing.T) { @@ -187,14 +221,15 @@ func TestRemoveProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} err = rs.AddProcess(process) require.Equal(t, nil, err, "Failed to add process (%s)", err) - err = rs.DeleteProcess(process.ID) + err = rs.DeleteProcess(tid) require.Equal(t, nil, err, "Set process not found (%s)", process.ID) - _, err = rs.GetProcess(process.ID) + _, err = rs.GetProcess(tid) require.NotEqual(t, nil, err, "Unset process found (%s)", process.ID) } @@ -205,10 +240,12 @@ func TestUpdateProcess(t *testing.T) { process1 := getDummyProcess() require.NotNil(t, process1) process1.ID = "process1" + tid1 := TaskID{ID: process1.ID} process2 := getDummyProcess() require.NotNil(t, process2) process2.ID = "process2" + tid2 := TaskID{ID: process2.ID} err = rs.AddProcess(process1) require.Equal(t, nil, err) @@ -216,7 +253,7 @@ func TestUpdateProcess(t *testing.T) { err = rs.AddProcess(process2) 
require.Equal(t, nil, err) - process, err := rs.GetProcess(process2.ID) + process, err := rs.GetProcess(tid2) require.NoError(t, err) createdAt := process.CreatedAt @@ -227,18 +264,20 @@ func TestUpdateProcess(t *testing.T) { process3 := getDummyProcess() require.NotNil(t, process3) process3.ID = "process2" + tid3 := TaskID{ID: process3.ID} - err = rs.UpdateProcess("process1", process3) + err = rs.UpdateProcess(tid1, process3) require.Error(t, err) process3.ID = "process3" - err = rs.UpdateProcess("process1", process3) + tid3.ID = process3.ID + err = rs.UpdateProcess(tid1, process3) require.NoError(t, err) - _, err = rs.GetProcess(process1.ID) + _, err = rs.GetProcess(tid1) require.Error(t, err) - process, err = rs.GetProcess(process3.ID) + process, err = rs.GetProcess(tid3) require.NoError(t, err) require.NotEqual(t, createdAt, process.CreatedAt) // this should be equal, but will require a major version jump @@ -252,51 +291,64 @@ func TestGetProcess(t *testing.T) { process1 := getDummyProcess() process1.ID = "foo_aaa_1" process1.Reference = "foo_aaa_1" + tid1 := TaskID{ID: process1.ID} process2 := getDummyProcess() process2.ID = "bar_bbb_2" process2.Reference = "bar_bbb_2" + tid2 := TaskID{ID: process2.ID} process3 := getDummyProcess() process3.ID = "foo_ccc_3" process3.Reference = "foo_ccc_3" + tid3 := TaskID{ID: process3.ID} process4 := getDummyProcess() process4.ID = "bar_ddd_4" process4.Reference = "bar_ddd_4" + tid4 := TaskID{ID: process4.ID} rs.AddProcess(process1) rs.AddProcess(process2) rs.AddProcess(process3) rs.AddProcess(process4) - _, err = rs.GetProcess(process1.ID) + _, err = rs.GetProcess(tid1) require.Equal(t, nil, err) - list := rs.GetProcessIDs("", "") + _, err = rs.GetProcess(tid2) + require.Equal(t, nil, err) + + _, err = rs.GetProcess(tid3) + require.Equal(t, nil, err) + + _, err = rs.GetProcess(tid4) + require.Equal(t, nil, err) + + list := rs.GetProcessIDs("", "", "", "") require.Len(t, list, 4) - require.ElementsMatch(t, 
[]string{"foo_aaa_1", "bar_bbb_2", "foo_ccc_3", "bar_ddd_4"}, list) + require.ElementsMatch(t, []TaskID{{ID: "foo_aaa_1"}, {ID: "bar_bbb_2"}, {ID: "foo_ccc_3"}, {ID: "bar_ddd_4"}}, list) - list = rs.GetProcessIDs("foo_*", "") + list = rs.GetProcessIDs("foo_*", "", "", "") require.Len(t, list, 2) - require.ElementsMatch(t, []string{"foo_aaa_1", "foo_ccc_3"}, list) + require.ElementsMatch(t, []TaskID{{ID: "foo_aaa_1"}, {ID: "foo_ccc_3"}}, list) - list = rs.GetProcessIDs("bar_*", "") + list = rs.GetProcessIDs("bar_*", "", "", "") require.Len(t, list, 2) - require.ElementsMatch(t, []string{"bar_bbb_2", "bar_ddd_4"}, list) + require.ElementsMatch(t, []TaskID{{ID: "bar_bbb_2"}, {ID: "bar_ddd_4"}}, list) - list = rs.GetProcessIDs("*_bbb_*", "") + list = rs.GetProcessIDs("*_bbb_*", "", "", "") require.Len(t, list, 1) - require.ElementsMatch(t, []string{"bar_bbb_2"}, list) + require.ElementsMatch(t, []TaskID{{ID: "bar_bbb_2"}}, list) - list = rs.GetProcessIDs("", "foo_*") + list = rs.GetProcessIDs("", "foo_*", "", "") require.Len(t, list, 2) - require.ElementsMatch(t, []string{"foo_aaa_1", "foo_ccc_3"}, list) + require.ElementsMatch(t, []TaskID{{ID: "foo_aaa_1"}, {ID: "foo_ccc_3"}}, list) - list = rs.GetProcessIDs("", "bar_*") + list = rs.GetProcessIDs("", "bar_*", "", "") require.Len(t, list, 2) - require.ElementsMatch(t, []string{"bar_bbb_2", "bar_ddd_4"}, list) + require.ElementsMatch(t, []TaskID{{ID: "bar_bbb_2"}, {ID: "bar_ddd_4"}}, list) - list = rs.GetProcessIDs("", "*_bbb_*") + list = rs.GetProcessIDs("", "*_bbb_*", "", "") require.Len(t, list, 1) - require.ElementsMatch(t, []string{"bar_bbb_2"}, list) + require.ElementsMatch(t, []TaskID{{ID: "bar_bbb_2"}}, list) } func TestStartProcess(t *testing.T) { @@ -304,25 +356,26 @@ func TestStartProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - err = rs.StartProcess("foobar") + err = rs.StartProcess(TaskID{ID: "foobar"}) require.NotEqual(t, 
nil, err, "shouldn't be able to start non-existing process") - err = rs.StartProcess(process.ID) + err = rs.StartProcess(tid) require.Equal(t, nil, err, "should be able to start existing process") - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - err = rs.StartProcess(process.ID) + err = rs.StartProcess(tid) require.Equal(t, nil, err, "should be able to start already running process") - state, _ = rs.GetProcessState(process.ID) + state, _ = rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - rs.StopProcess(process.ID) + rs.StopProcess(tid) } func TestStopProcess(t *testing.T) { @@ -330,23 +383,24 @@ func TestStopProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - rs.StartProcess(process.ID) + rs.StartProcess(tid) - err = rs.StopProcess("foobar") + err = rs.StopProcess(TaskID{ID: "foobar"}) require.NotEqual(t, nil, err, "shouldn't be able to stop non-existing process") - err = rs.StopProcess(process.ID) + err = rs.StopProcess(tid) require.Equal(t, nil, err, "should be able to stop existing running process") - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "stop", state.Order, "Process should be stopped") - err = rs.StopProcess(process.ID) + err = rs.StopProcess(tid) require.Equal(t, nil, err, "should be able to stop already stopped process") - state, _ = rs.GetProcessState(process.ID) + state, _ = rs.GetProcessState(tid) require.Equal(t, "stop", state.Order, "Process should be stopped") } @@ -355,24 +409,25 @@ func TestRestartProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - err = rs.RestartProcess("foobar") + err = rs.RestartProcess(TaskID{ID: "foobar"}) require.NotEqual(t, nil, err, "shouldn't be able to 
restart non-existing process") - err = rs.RestartProcess(process.ID) + err = rs.RestartProcess(tid) require.Equal(t, nil, err, "should be able to restart existing stopped process") - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "stop", state.Order, "Process should be stopped") - rs.StartProcess(process.ID) + rs.StartProcess(tid) - state, _ = rs.GetProcessState(process.ID) + state, _ = rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - rs.StopProcess(process.ID) + rs.StopProcess(tid) } func TestReloadProcess(t *testing.T) { @@ -380,30 +435,31 @@ func TestReloadProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - err = rs.ReloadProcess("foobar") + err = rs.ReloadProcess(TaskID{ID: "foobar"}) require.NotEqual(t, nil, err, "shouldn't be able to reload non-existing process") - err = rs.ReloadProcess(process.ID) + err = rs.ReloadProcess(tid) require.Equal(t, nil, err, "should be able to reload existing stopped process") - state, _ := rs.GetProcessState(process.ID) + state, _ := rs.GetProcessState(tid) require.Equal(t, "stop", state.Order, "Process should be stopped") - rs.StartProcess(process.ID) + rs.StartProcess(tid) - state, _ = rs.GetProcessState(process.ID) + state, _ = rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - err = rs.ReloadProcess(process.ID) + err = rs.ReloadProcess(tid) require.Equal(t, nil, err, "should be able to reload existing process") - state, _ = rs.GetProcessState(process.ID) + state, _ = rs.GetProcessState(tid) require.Equal(t, "start", state.Order, "Process should be started") - rs.StopProcess(process.ID) + rs.StopProcess(tid) } func TestProbeProcess(t *testing.T) { @@ -411,10 +467,11 @@ func TestProbeProcess(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} 
rs.AddProcess(process) - probe := rs.ProbeWithTimeout(process.ID, 5*time.Second) + probe := rs.ProbeWithTimeout(tid, 5*time.Second) require.Equal(t, 3, len(probe.Streams)) } @@ -424,15 +481,19 @@ func TestProcessMetadata(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - data, _ := rs.GetProcessMetadata(process.ID, "foobar") + data, err := rs.GetProcessMetadata(tid, "foobar") + require.Equal(t, ErrMetadataKeyNotFound, err) require.Equal(t, nil, data, "nothing should be stored under the key") - rs.SetProcessMetadata(process.ID, "foobar", process) + err = rs.SetProcessMetadata(tid, "foobar", process) + require.NoError(t, err) - data, _ = rs.GetProcessMetadata(process.ID, "foobar") + data, err = rs.GetProcessMetadata(tid, "foobar") + require.NoError(t, err) require.NotEqual(t, nil, data, "there should be something stored under the key") p := data.(*app.Config) @@ -445,29 +506,30 @@ func TestLog(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} rs.AddProcess(process) - _, err = rs.GetProcessLog("foobar") + _, err = rs.GetProcessLog(TaskID{ID: "foobar"}) require.Error(t, err) - log, err := rs.GetProcessLog(process.ID) + log, err := rs.GetProcessLog(tid) require.NoError(t, err) require.Equal(t, 0, len(log.Prelude)) require.Equal(t, 0, len(log.Log)) - rs.StartProcess(process.ID) + rs.StartProcess(tid) time.Sleep(3 * time.Second) - log, _ = rs.GetProcessLog(process.ID) + log, _ = rs.GetProcessLog(tid) require.NotEqual(t, 0, len(log.Prelude)) require.NotEqual(t, 0, len(log.Log)) - rs.StopProcess(process.ID) + rs.StopProcess(tid) - log, _ = rs.GetProcessLog(process.ID) + log, _ = rs.GetProcessLog(tid) require.NotEqual(t, 0, len(log.Prelude)) require.NotEqual(t, 0, len(log.Log)) @@ -478,25 +540,26 @@ func TestLogTransfer(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} err = rs.AddProcess(process) 
require.NoError(t, err) - rs.StartProcess(process.ID) + rs.StartProcess(tid) time.Sleep(3 * time.Second) - rs.StopProcess(process.ID) + rs.StopProcess(tid) - rs.StartProcess(process.ID) - rs.StopProcess(process.ID) + rs.StartProcess(tid) + rs.StopProcess(tid) - log, _ := rs.GetProcessLog(process.ID) + log, _ := rs.GetProcessLog(tid) require.Equal(t, 1, len(log.History)) - err = rs.UpdateProcess(process.ID, process) + err = rs.UpdateProcess(tid, process) require.NoError(t, err) - log, _ = rs.GetProcessLog(process.ID) + log, _ = rs.GetProcessLog(tid) require.Equal(t, 1, len(log.History)) } @@ -506,18 +569,19 @@ func TestPlayoutNoRange(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} process.Input[0].Address = "playout:" + process.Input[0].Address rs.AddProcess(process) - _, err = rs.GetPlayout("foobar", process.Input[0].ID) - require.NotEqual(t, nil, err, "playout of non-existing process should error") + _, err = rs.GetPlayout(TaskID{ID: "foobar"}, process.Input[0].ID) + require.Equal(t, ErrUnknownProcess, err) - _, err = rs.GetPlayout(process.ID, "foobar") + _, err = rs.GetPlayout(tid, "foobar") require.NotEqual(t, nil, err, "playout of non-existing input should error") - addr, _ := rs.GetPlayout(process.ID, process.Input[0].ID) + addr, _ := rs.GetPlayout(tid, process.Input[0].ID) require.Equal(t, 0, len(addr), "the playout address should be empty if no port range is given") } @@ -529,22 +593,56 @@ func TestPlayoutRange(t *testing.T) { require.NoError(t, err) process := getDummyProcess() + tid := TaskID{ID: process.ID} process.Input[0].Address = "playout:" + process.Input[0].Address rs.AddProcess(process) - _, err = rs.GetPlayout("foobar", process.Input[0].ID) - require.NotEqual(t, nil, err, "playout of non-existing process should error") + _, err = rs.GetPlayout(TaskID{ID: "foobar"}, process.Input[0].ID) + require.Equal(t, ErrUnknownProcess, err) - _, err = rs.GetPlayout(process.ID, "foobar") + _, err = 
rs.GetPlayout(tid, "foobar") require.NotEqual(t, nil, err, "playout of non-existing input should error") - addr, _ := rs.GetPlayout(process.ID, process.Input[0].ID) + addr, _ := rs.GetPlayout(tid, process.Input[0].ID) require.NotEqual(t, 0, len(addr), "the playout address should not be empty if a port range is given") require.Equal(t, "127.0.0.1:3000", addr, "the playout address should be 127.0.0.1:3000") } +func TestParseAddressReference(t *testing.T) { + matches, err := parseAddressReference("foobar") + require.NoError(t, err) + require.Equal(t, "foobar", matches["address"]) + + _, err = parseAddressReference("#foobar") + require.Error(t, err) + + _, err = parseAddressReference("#foobar:nothing=foo") + require.Error(t, err) + + matches, err = parseAddressReference("#foobar:output=foo") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["output"]) + + matches, err = parseAddressReference("#foobar:group=foo") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["group"]) + + matches, err = parseAddressReference("#foobar:nothing=foo:output=bar") + require.NoError(t, err) + require.Equal(t, "foobar:nothing=foo", matches["id"]) + require.Equal(t, "bar", matches["output"]) + + matches, err = parseAddressReference("#foobar:output=foo:group=bar") + require.NoError(t, err) + require.Equal(t, "foobar", matches["id"]) + require.Equal(t, "foo", matches["output"]) + require.Equal(t, "bar", matches["group"]) +} + func TestAddressReference(t *testing.T) { rs, err := getDummyRestreamer(nil, nil, nil, nil) require.NoError(t, err) @@ -552,10 +650,9 @@ func TestAddressReference(t *testing.T) { process1 := getDummyProcess() process2 := getDummyProcess() - process2.ID = "process2" - rs.AddProcess(process1) + process2.ID = "process2" process2.Input[0].Address = "#process:foobar=out" err = rs.AddProcess(process2) @@ -577,6 +674,44 @@ func TestAddressReference(t *testing.T) { 
require.Equal(t, nil, err, "should resolve reference") } +func TestTeeAddressReference(t *testing.T) { + rs, err := getDummyRestreamer(nil, nil, nil, nil) + require.NoError(t, err) + + process1 := getDummyProcess() + process2 := getDummyProcess() + process3 := getDummyProcess() + process4 := getDummyProcess() + + process1.Output[0].Address = "[f=hls]http://example.com/live.m3u8|[f=flv]rtmp://example.com/live.stream?token=123" + process2.ID = "process2" + process3.ID = "process3" + process4.ID = "process4" + + rs.AddProcess(process1) + + process2.Input[0].Address = "#process:output=out" + + err = rs.AddProcess(process2) + require.Equal(t, nil, err, "should resolve reference") + + process3.Input[0].Address = "#process:output=out:source=hls" + + err = rs.AddProcess(process3) + require.Equal(t, nil, err, "should resolve reference") + + process4.Input[0].Address = "#process:output=out:source=rtmp" + + err = rs.AddProcess(process4) + require.Equal(t, nil, err, "should resolve reference") + + r := rs.(*restream) + + require.Equal(t, "http://example.com/live.m3u8", r.tasks[TaskID{ID: "process2"}].config.Input[0].Address) + require.Equal(t, "http://example.com/live.m3u8", r.tasks[TaskID{ID: "process3"}].config.Input[0].Address) + require.Equal(t, "rtmp://example.com/live.stream?token=123", r.tasks[TaskID{ID: "process4"}].config.Input[0].Address) +} + func TestConfigValidation(t *testing.T) { rsi, err := getDummyRestreamer(nil, nil, nil, nil) require.NoError(t, err) @@ -881,7 +1016,10 @@ func TestReplacer(t *testing.T) { StaleTimeout: 0, } - require.Equal(t, process, rs.tasks["314159265359"].config) + task, ok := rs.tasks[TaskID{ID: "314159265359"}] + require.True(t, ok) + + require.Equal(t, process, task.config) } func TestProcessLimit(t *testing.T) { @@ -898,7 +1036,7 @@ func TestProcessLimit(t *testing.T) { rs := rsi.(*restream) - task, ok := rs.tasks[process.ID] + task, ok := rs.tasks[TaskID{ID: process.ID}] require.True(t, ok) status := task.ffmpeg.Status() diff --git 
a/restream/rewrite/rewrite.go b/restream/rewrite/rewrite.go new file mode 100644 index 00000000..1112ce23 --- /dev/null +++ b/restream/rewrite/rewrite.go @@ -0,0 +1,156 @@ +// Package rewrite provides facilities for rewriting a local HLS, RTMP, and SRT address. +package rewrite + +import ( + "fmt" + "net/url" + + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/rtmp" + srturl "github.com/datarhei/core/v16/srt/url" +) + +type Access string + +var ( + READ Access = "read" + WRITE Access = "write" +) + +type Config struct { + HTTPBase string + RTMPBase string + SRTBase string +} + +// to a new identity, i.e. adjusting the credentials to the given identity. +type Rewriter interface { + RewriteAddress(address string, identity iam.IdentityVerifier, mode Access) string +} + +type rewrite struct { + httpBase string + rtmpBase string + srtBase string +} + +func New(config Config) (Rewriter, error) { + r := &rewrite{ + httpBase: config.HTTPBase, + rtmpBase: config.RTMPBase, + srtBase: config.SRTBase, + } + + return r, nil +} + +func (g *rewrite) RewriteAddress(address string, identity iam.IdentityVerifier, mode Access) string { + u, err := url.Parse(address) + if err != nil { + return address + } + + // Decide whether this is our local server + if !g.isLocal(u) { + return address + } + + if identity == nil { + return address + } + + if u.Scheme == "http" || u.Scheme == "https" { + return g.httpURL(u, mode, identity) + } else if u.Scheme == "rtmp" { + return g.rtmpURL(u, mode, identity) + } else if u.Scheme == "srt" { + return g.srtURL(u, mode, identity) + } + + return address +} + +func (g *rewrite) isLocal(u *url.URL) bool { + var base *url.URL + var err error + + if u.Scheme == "http" || u.Scheme == "https" { + base, err = url.Parse(g.httpBase) + } else if u.Scheme == "rtmp" { + base, err = url.Parse(g.rtmpBase) + } else if u.Scheme == "srt" { + base, err = url.Parse(g.srtBase) + } else { + err = fmt.Errorf("unsupported scheme") + } + + if err != nil { + 
return false + } + + hostname := u.Hostname() + port := u.Port() + + if base.Hostname() == "localhost" { + if hostname != "localhost" && hostname != "127.0.0.1" && hostname != "::1" { + return false + } + + hostname = "localhost" + } + + host := hostname + ":" + port + + return host == base.Host +} + +func (g *rewrite) httpURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + password := identity.GetServiceBasicAuth() + + if len(password) == 0 { + u.User = nil + } else { + u.User = url.UserPassword(identity.Name(), password) + } + + return u.String() +} + +func (g *rewrite) rtmpURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + token := identity.GetServiceToken() + + // Remove the existing token from the path + path, _ := rtmp.GetToken(u) + u.Path = path + + q := u.Query() + q.Set("token", token) + + u.RawQuery = q.Encode() + + return u.String() +} + +func (g *rewrite) srtURL(u *url.URL, mode Access, identity iam.IdentityVerifier) string { + token := identity.GetServiceToken() + + q := u.Query() + + streamInfo, err := srturl.ParseStreamId(q.Get("streamid")) + if err != nil { + return u.String() + } + + streamInfo.Token = token + + if mode == WRITE { + streamInfo.Mode = "publish" + } else { + streamInfo.Mode = "request" + } + + q.Set("streamid", streamInfo.String()) + u.RawQuery = q.Encode() + + return u.String() +} diff --git a/restream/rewrite/rewrite_test.go b/restream/rewrite/rewrite_test.go new file mode 100644 index 00000000..f1a1ae1b --- /dev/null +++ b/restream/rewrite/rewrite_test.go @@ -0,0 +1,158 @@ +package rewrite + +import ( + "net/url" + "testing" + + "github.com/datarhei/core/v16/iam" + "github.com/datarhei/core/v16/io/fs" + + "github.com/stretchr/testify/require" +) + +func getIdentityManager(enableBasic bool) iam.IdentityManager { + dummyfs, _ := fs.NewMemFilesystem(fs.MemConfig{}) + + superuser := iam.User{ + Name: "foobar", + Superuser: false, + Auth: iam.UserAuth{ + API: iam.UserAuthAPI{}, + Services: 
iam.UserAuthServices{ + Token: []string{"servicetoken"}, + }, + }, + } + + if enableBasic { + superuser.Auth.Services.Basic = []string{"basicauthpassword"} + } + + im, _ := iam.NewIdentityManager(iam.IdentityConfig{ + FS: dummyfs, + Superuser: superuser, + JWTRealm: "", + JWTSecret: "", + Logger: nil, + }) + + return im +} + +func TestRewriteHTTP(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + HTTPBase: "http://localhost:8080/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"http://example.com/live/stream.m3u8", "read", "http://example.com/live/stream.m3u8"}, + {"http://example.com/live/stream.m3u8", "write", "http://example.com/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "read", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "write", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteHTTPPassword(t *testing.T) { + im := getIdentityManager(true) + + rewrite, err := New(Config{ + HTTPBase: "http://localhost:8080/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"http://example.com/live/stream.m3u8", "read", 
"http://example.com/live/stream.m3u8"}, + {"http://example.com/live/stream.m3u8", "write", "http://example.com/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "read", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8181/live/stream.m3u8", "write", "http://localhost:8181/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "read", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "write", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://foobar:basicauthpassword@localhost:8080/live/stream.m3u8"}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteRTMP(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + RTMPBase: "rtmp://localhost:1935/live", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"rtmp://example.com/live/stream", "read", "rtmp://example.com/live/stream"}, + {"rtmp://example.com/live/stream", "write", "rtmp://example.com/live/stream"}, + {"rtmp://localhost:1936/live/stream/token", "read", "rtmp://localhost:1936/live/stream/token"}, + {"rtmp://localhost:1936/live/stream?token=token", "write", "rtmp://localhost:1936/live/stream?token=token"}, + {"rtmp://localhost:1935/live/stream?token=token", "read", "rtmp://localhost:1935/live/stream?token=" + url.QueryEscape("foobar:servicetoken")}, + {"rtmp://localhost:1935/live/stream/token", "write", "rtmp://localhost:1935/live/stream?token=" + url.QueryEscape("foobar:servicetoken")}, + } 
+ + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} + +func TestRewriteSRT(t *testing.T) { + im := getIdentityManager(false) + + rewrite, err := New(Config{ + SRTBase: "srt://localhost:6000/", + }) + require.NoError(t, err) + require.NotNil(t, rewrite) + + identity, err := im.GetVerifier("foobar") + require.NoError(t, err) + require.NotNil(t, identity) + + samples := [][3]string{ + {"srt://example.com/?streamid=stream", "read", "srt://example.com/?streamid=stream"}, + {"srt://example.com/?streamid=stream", "write", "srt://example.com/?streamid=stream"}, + {"srt://localhost:1936/?streamid=live/stream", "read", "srt://localhost:1936/?streamid=live/stream"}, + {"srt://localhost:1936/?streamid=live/stream", "write", "srt://localhost:1936/?streamid=live/stream"}, + {"srt://localhost:6000/?streamid=live/stream,mode:publish,token:token", "read", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=live/stream,mode:publish,token:token", "write", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,mode:publish,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=" + url.QueryEscape("#!:r=live/stream,m=publish,token=token"), "read", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,token:foobar:servicetoken")}, + {"srt://localhost:6000/?streamid=" + url.QueryEscape("#!:r=live/stream,m=publish,token=token"), "write", "srt://localhost:6000/?streamid=" + url.QueryEscape("live/stream,mode:publish,token:foobar:servicetoken")}, + } + + for _, e := range samples { + rewritten := rewrite.RewriteAddress(e[0], identity, Access(e[1])) + require.Equal(t, e[2], rewritten, "%s %s", e[0], e[1]) + } +} diff --git a/restream/store/data.go b/restream/store/data.go deleted file mode 100644 index a93dbc7d..00000000 --- a/restream/store/data.go +++ /dev/null @@ -1,49 
+0,0 @@ -package store - -import ( - "github.com/datarhei/core/v16/restream/app" -) - -type StoreData struct { - Version uint64 `json:"version"` - - Process map[string]*app.Process `json:"process"` - Metadata struct { - System map[string]interface{} `json:"system"` - Process map[string]map[string]interface{} `json:"process"` - } `json:"metadata"` -} - -func NewStoreData() StoreData { - c := StoreData{ - Version: 4, - } - - c.Process = make(map[string]*app.Process) - c.Metadata.System = make(map[string]interface{}) - c.Metadata.Process = make(map[string]map[string]interface{}) - - return c -} - -func (c *StoreData) IsEmpty() bool { - if len(c.Process) != 0 { - return false - } - - if len(c.Metadata.Process) != 0 { - return false - } - - if len(c.Metadata.System) != 0 { - return false - } - - return true -} - -func (c *StoreData) sanitize() { - if c.Process == nil { - c.Process = make(map[string]*app.Process) - } -} diff --git a/restream/store/json.go b/restream/store/json.go deleted file mode 100644 index 36e5720e..00000000 --- a/restream/store/json.go +++ /dev/null @@ -1,137 +0,0 @@ -package store - -import ( - gojson "encoding/json" - "fmt" - "os" - "sync" - - "github.com/datarhei/core/v16/encoding/json" - "github.com/datarhei/core/v16/io/fs" - "github.com/datarhei/core/v16/log" -) - -type JSONConfig struct { - Filesystem fs.Filesystem - Filepath string // Full path to the database file - Logger log.Logger -} - -type jsonStore struct { - fs fs.Filesystem - filepath string - logger log.Logger - - // Mutex to serialize access to the backend - lock sync.RWMutex -} - -var version uint64 = 4 - -func NewJSON(config JSONConfig) (Store, error) { - s := &jsonStore{ - fs: config.Filesystem, - filepath: config.Filepath, - logger: config.Logger, - } - - if len(s.filepath) == 0 { - s.filepath = "/db.json" - } - - if s.fs == nil { - return nil, fmt.Errorf("no valid filesystem provided") - } - - if s.logger == nil { - s.logger = log.New("") - } - - return s, nil -} - -func (s 
*jsonStore) Load() (StoreData, error) { - s.lock.Lock() - defer s.lock.Unlock() - - data, err := s.load(s.filepath, version) - if err != nil { - return NewStoreData(), err - } - - data.sanitize() - - return data, nil -} - -func (s *jsonStore) Store(data StoreData) error { - if data.Version != version { - return fmt.Errorf("invalid version (have: %d, want: %d)", data.Version, version) - } - - s.lock.RLock() - defer s.lock.RUnlock() - - err := s.store(s.filepath, data) - if err != nil { - return fmt.Errorf("failed to store data: %w", err) - } - - return nil -} - -func (s *jsonStore) store(filepath string, data StoreData) error { - jsondata, err := gojson.MarshalIndent(&data, "", " ") - if err != nil { - return err - } - - _, _, err = s.fs.WriteFileSafe(filepath, jsondata) - if err != nil { - return err - } - - s.logger.WithField("file", filepath).Debug().Log("Stored data") - - return nil -} - -type storeVersion struct { - Version uint64 `json:"version"` -} - -func (s *jsonStore) load(filepath string, version uint64) (StoreData, error) { - r := NewStoreData() - - _, err := s.fs.Stat(filepath) - if err != nil { - if os.IsNotExist(err) { - return r, nil - } - - return r, err - } - - jsondata, err := s.fs.ReadFile(filepath) - if err != nil { - return r, err - } - - var db storeVersion - - if err = gojson.Unmarshal(jsondata, &db); err != nil { - return r, json.FormatError(jsondata, err) - } - - if db.Version != version { - return r, fmt.Errorf("unsupported version of the DB file (want: %d, have: %d)", version, db.Version) - } - - if err = gojson.Unmarshal(jsondata, &r); err != nil { - return r, json.FormatError(jsondata, err) - } - - s.logger.WithField("file", filepath).Debug().Log("Read data") - - return r, nil -} diff --git a/restream/store/json/data.go b/restream/store/json/data.go new file mode 100644 index 00000000..f5588257 --- /dev/null +++ b/restream/store/json/data.go @@ -0,0 +1,248 @@ +package json + +import "github.com/datarhei/core/v16/restream/app" + +type 
ProcessConfigIOCleanup struct { + Pattern string `json:"pattern"` + MaxFiles uint `json:"max_files"` + MaxFileAge uint `json:"max_file_age_seconds"` + PurgeOnDelete bool `json:"purge_on_delete"` +} + +func (p *ProcessConfigIOCleanup) Marshal(a *app.ConfigIOCleanup) { + p.Pattern = a.Pattern + p.MaxFiles = a.MaxFiles + p.MaxFileAge = a.MaxFileAge + p.PurgeOnDelete = a.PurgeOnDelete +} + +func (p *ProcessConfigIOCleanup) Unmarshal() app.ConfigIOCleanup { + a := app.ConfigIOCleanup{ + Pattern: p.Pattern, + MaxFiles: p.MaxFiles, + MaxFileAge: p.MaxFileAge, + PurgeOnDelete: p.PurgeOnDelete, + } + + return a +} + +type ProcessConfigIO struct { + ID string `json:"id"` + Address string `json:"address"` + Options []string `json:"options"` + Cleanup []ProcessConfigIOCleanup `json:"cleanup"` +} + +func (p *ProcessConfigIO) Marshal(a *app.ConfigIO) { + p.ID = a.ID + p.Address = a.Address + + p.Options = make([]string, len(a.Options)) + copy(p.Options, a.Options) + + if len(a.Cleanup) != 0 { + p.Cleanup = make([]ProcessConfigIOCleanup, len(a.Cleanup)) + for x, cleanup := range a.Cleanup { + p.Cleanup[x].Marshal(&cleanup) + } + } else { + p.Cleanup = nil + } +} + +func (p *ProcessConfigIO) Unmarshal() app.ConfigIO { + a := app.ConfigIO{ + ID: p.ID, + Address: p.Address, + } + + a.Options = make([]string, len(p.Options)) + copy(a.Options, p.Options) + + if len(p.Cleanup) != 0 { + a.Cleanup = make([]app.ConfigIOCleanup, len(p.Cleanup)) + for x, cleanup := range p.Cleanup { + a.Cleanup[x] = cleanup.Unmarshal() + } + } + + return a +} + +type ProcessConfig struct { + ID string `json:"id"` + Reference string `json:"reference"` + Owner string `json:"owner"` + Domain string `json:"domain"` + FFVersion string `json:"ffversion"` + Input []ProcessConfigIO `json:"input"` + Output []ProcessConfigIO `json:"output"` + Options []string `json:"options"` + Reconnect bool `json:"reconnect"` + ReconnectDelay uint64 `json:"reconnect_delay_seconds"` // seconds + Autostart bool `json:"autostart"` + 
StaleTimeout uint64 `json:"stale_timeout_seconds"` // seconds + LimitCPU float64 `json:"limit_cpu_usage"` // percent + LimitMemory uint64 `json:"limit_memory_bytes"` // bytes + LimitWaitFor uint64 `json:"limit_waitfor_seconds"` // seconds +} + +func (p *ProcessConfig) Marshal(a *app.Config) { + p.ID = a.ID + p.Reference = a.Reference + p.Owner = a.Owner + p.Domain = a.Domain + p.FFVersion = a.FFVersion + p.Reconnect = a.Reconnect + p.ReconnectDelay = a.ReconnectDelay + p.Autostart = a.Autostart + p.StaleTimeout = a.StaleTimeout + p.LimitCPU = a.LimitCPU + p.LimitMemory = a.LimitMemory + p.LimitWaitFor = a.LimitWaitFor + + p.Options = make([]string, len(a.Options)) + copy(p.Options, a.Options) + + p.Input = make([]ProcessConfigIO, len(a.Input)) + for x, input := range a.Input { + p.Input[x].Marshal(&input) + } + + p.Output = make([]ProcessConfigIO, len(a.Output)) + for x, output := range a.Output { + p.Output[x].Marshal(&output) + } +} + +func (p *ProcessConfig) Unmarshal() *app.Config { + a := &app.Config{ + ID: p.ID, + Reference: p.Reference, + Owner: p.Owner, + Domain: p.Domain, + FFVersion: p.FFVersion, + Input: []app.ConfigIO{}, + Output: []app.ConfigIO{}, + Options: []string{}, + Reconnect: p.Reconnect, + ReconnectDelay: p.ReconnectDelay, + Autostart: p.Autostart, + StaleTimeout: p.StaleTimeout, + LimitCPU: p.LimitCPU, + LimitMemory: p.LimitMemory, + LimitWaitFor: p.LimitWaitFor, + } + + a.Options = make([]string, len(p.Options)) + copy(a.Options, p.Options) + + a.Input = make([]app.ConfigIO, len(p.Input)) + for x, input := range p.Input { + a.Input[x] = input.Unmarshal() + } + + a.Output = make([]app.ConfigIO, len(p.Output)) + for x, output := range p.Output { + a.Output[x] = output.Unmarshal() + } + + return a +} + +type Process struct { + ID string `json:"id"` + Owner string `json:"owner"` + Domain string `json:"domain"` + Reference string `json:"reference"` + Config ProcessConfig `json:"config"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 
`json:"updated_at"` + Order string `json:"order"` +} + +func MarshalProcess(a *app.Process) Process { + p := Process{ + ID: a.ID, + Owner: a.Owner, + Domain: a.Domain, + Reference: a.Reference, + Config: ProcessConfig{}, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, + Order: a.Order, + } + + p.Config.Marshal(a.Config) + + return p +} + +func UnmarshalProcess(p Process) *app.Process { + a := &app.Process{ + ID: p.ID, + Owner: p.Owner, + Domain: p.Domain, + Reference: p.Reference, + Config: &app.Config{}, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + Order: p.Order, + } + + a.Config = p.Config.Unmarshal() + + return a +} + +type Domain struct { + Process map[string]Process `json:"process"` + Metadata map[string]map[string]interface{} `json:"metadata"` +} + +type Data struct { + Version uint64 `json:"version"` + + Process map[string]Process `json:"process"` + Domain map[string]Domain `json:"domain"` + Metadata struct { + System map[string]interface{} `json:"system"` + Process map[string]map[string]interface{} `json:"process"` + } `json:"metadata"` +} + +var version uint64 = 4 + +func NewData() Data { + c := Data{ + Version: version, + } + + c.Process = make(map[string]Process) + c.Domain = make(map[string]Domain) + c.Metadata.System = make(map[string]interface{}) + c.Metadata.Process = make(map[string]map[string]interface{}) + + return c +} + +func (c *Data) IsEmpty() bool { + if len(c.Process) != 0 { + return false + } + + if len(c.Domain) != 0 { + return false + } + + if len(c.Metadata.Process) != 0 { + return false + } + + if len(c.Metadata.System) != 0 { + return false + } + + return true +} diff --git a/restream/store/fixtures/v3_empty.json b/restream/store/json/fixtures/v3_empty.json similarity index 100% rename from restream/store/fixtures/v3_empty.json rename to restream/store/json/fixtures/v3_empty.json diff --git a/restream/store/fixtures/v4_empty.json b/restream/store/json/fixtures/v4_empty.json similarity index 100% rename from 
restream/store/fixtures/v4_empty.json rename to restream/store/json/fixtures/v4_empty.json diff --git a/restream/store/fixtures/v4_invalid.json b/restream/store/json/fixtures/v4_invalid.json similarity index 100% rename from restream/store/fixtures/v4_invalid.json rename to restream/store/json/fixtures/v4_invalid.json diff --git a/restream/store/json/json.go b/restream/store/json/json.go new file mode 100644 index 00000000..2ca51b65 --- /dev/null +++ b/restream/store/json/json.go @@ -0,0 +1,214 @@ +package json + +import ( + gojson "encoding/json" + "fmt" + "os" + "sync" + + "github.com/datarhei/core/v16/encoding/json" + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/log" + "github.com/datarhei/core/v16/restream/store" +) + +type Config struct { + Filesystem fs.Filesystem + Filepath string // Full path to the database file + Logger log.Logger +} + +type jsonStore struct { + fs fs.Filesystem + filepath string + logger log.Logger + + // Mutex to serialize access to the disk + lock sync.RWMutex +} + +func New(config Config) (store.Store, error) { + s := &jsonStore{ + fs: config.Filesystem, + filepath: config.Filepath, + logger: config.Logger, + } + + if len(s.filepath) == 0 { + s.filepath = "/db.json" + } + + if s.fs == nil { + return nil, fmt.Errorf("no valid filesystem provided") + } + + if s.logger == nil { + s.logger = log.New("") + } + + return s, nil +} + +func (s *jsonStore) Load() (store.Data, error) { + s.lock.Lock() + defer s.lock.Unlock() + + data := store.NewData() + + d, err := s.load(s.filepath, version) + if err != nil { + return data, err + } + + for _, process := range d.Process { + if data.Process[""] == nil { + data.Process[""] = map[string]store.Process{} + } + + p := data.Process[""][process.ID] + + p.Process = UnmarshalProcess(process) + p.Metadata = map[string]interface{}{} + + data.Process[""][process.ID] = p + } + + for pid, m := range d.Metadata.Process { + if data.Process[""] == nil { + data.Process[""] = 
map[string]store.Process{} + } + + p := data.Process[""][pid] + p.Metadata = m + data.Process[""][pid] = p + } + + for k, v := range d.Metadata.System { + data.Metadata[k] = v + } + + for name, domain := range d.Domain { + if data.Process[name] == nil { + data.Process[name] = map[string]store.Process{} + } + + for pid, process := range domain.Process { + p := data.Process[name][pid] + + p.Process = UnmarshalProcess(process) + p.Metadata = map[string]interface{}{} + + data.Process[name][pid] = p + } + + for pid, m := range domain.Metadata { + p := data.Process[name][pid] + p.Metadata = m + data.Process[name][pid] = p + } + } + + return data, nil +} + +func (s *jsonStore) Store(data store.Data) error { + r := NewData() + + for k, v := range data.Metadata { + r.Metadata.System[k] = v + } + + for domain, d := range data.Process { + for pid, p := range d { + if len(domain) == 0 { + r.Process[pid] = MarshalProcess(p.Process) + r.Metadata.Process[pid] = p.Metadata + } else { + x := r.Domain[domain] + if x.Process == nil { + x.Process = map[string]Process{} + } + + x.Process[pid] = MarshalProcess(p.Process) + + if x.Metadata == nil { + x.Metadata = map[string]map[string]interface{}{} + } + + x.Metadata[pid] = p.Metadata + + r.Domain[domain] = x + } + } + } + + s.lock.RLock() + defer s.lock.RUnlock() + + err := s.store(s.filepath, r) + if err != nil { + return fmt.Errorf("failed to store data: %w", err) + } + + return nil +} + +func (s *jsonStore) store(filepath string, data Data) error { + if data.Version != version { + return fmt.Errorf("invalid version (have: %d, want: %d)", data.Version, version) + } + + jsondata, err := gojson.MarshalIndent(&data, "", " ") + if err != nil { + return err + } + + _, _, err = s.fs.WriteFileSafe(filepath, jsondata) + if err != nil { + return err + } + + s.logger.WithField("file", filepath).Debug().Log("Stored data") + + return nil +} + +type storeVersion struct { + Version uint64 `json:"version"` +} + +func (s *jsonStore) load(filepath 
string, version uint64) (Data, error) { + r := NewData() + + _, err := s.fs.Stat(filepath) + if err != nil { + if os.IsNotExist(err) { + return r, nil + } + + return r, err + } + + jsondata, err := s.fs.ReadFile(filepath) + if err != nil { + return r, err + } + + var db storeVersion + + if err = gojson.Unmarshal(jsondata, &db); err != nil { + return r, json.FormatError(jsondata, err) + } + + if db.Version == version { + if err = gojson.Unmarshal(jsondata, &r); err != nil { + return r, json.FormatError(jsondata, err) + } + } else { + return r, fmt.Errorf("unsupported version of the DB file (want: %d, have: %d)", version, db.Version) + } + + s.logger.WithField("file", filepath).Debug().Log("Read data") + + return r, nil +} diff --git a/restream/store/json/json_test.go b/restream/store/json/json_test.go new file mode 100644 index 00000000..13d4ac08 --- /dev/null +++ b/restream/store/json/json_test.go @@ -0,0 +1,184 @@ +package json + +import ( + "testing" + + "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/restream/app" + "github.com/datarhei/core/v16/restream/store" + "github.com/stretchr/testify/require" +) + +func getFS(t *testing.T) fs.Filesystem { + fs, err := fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{ + Root: ".", + }) + require.NoError(t, err) + + info, err := fs.Stat("./fixtures/v4_empty.json") + require.NoError(t, err) + require.Equal(t, "/fixtures/v4_empty.json", info.Name()) + + return fs +} + +func TestNew(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + }) + require.NoError(t, err) + require.NotEmpty(t, store) +} + +func TestStoreLoad(t *testing.T) { + memfs, err := fs.NewMemFilesystem(fs.MemConfig{}) + require.NoError(t, err) + + jsonstore, err := New(Config{ + Filesystem: memfs, + Filepath: "./db.json", + }) + require.NoError(t, err) + + data := store.NewData() + + data.Process[""] = make(map[string]store.Process) + p := store.Process{ + Process: &app.Process{ + ID: "foobar", + Owner: "me", + Domain: "", + 
Reference: "ref", + Config: &app.Config{ + ID: "foobar", + Reference: "ref", + Owner: "me", + Domain: "", + FFVersion: "5.1.3", + Input: []app.ConfigIO{}, + Output: []app.ConfigIO{}, + Options: []string{ + "42", + }, + Reconnect: true, + ReconnectDelay: 14, + Autostart: true, + StaleTimeout: 1, + LimitCPU: 2, + LimitMemory: 3, + LimitWaitFor: 4, + }, + CreatedAt: 0, + UpdatedAt: 0, + Order: "stop", + }, + Metadata: map[string]interface{}{ + "some": "data", + }, + } + data.Process[""]["foobar"] = p + + data.Process["domain"] = make(map[string]store.Process) + p = store.Process{ + Process: &app.Process{ + ID: "foobaz", + Owner: "you", + Domain: "domain", + Reference: "refref", + Config: &app.Config{ + ID: "foobaz", + Reference: "refref", + Owner: "you", + Domain: "domain", + FFVersion: "5.1.4", + Input: []app.ConfigIO{}, + Output: []app.ConfigIO{}, + Options: []string{ + "47", + }, + Reconnect: true, + ReconnectDelay: 24, + Autostart: true, + StaleTimeout: 21, + LimitCPU: 22, + LimitMemory: 23, + LimitWaitFor: 24, + }, + CreatedAt: 0, + UpdatedAt: 0, + Order: "stop", + }, + Metadata: map[string]interface{}{ + "some-more": "data", + }, + } + data.Process["domain"]["foobaz"] = p + + data.Metadata["foo"] = "bar" + + err = jsonstore.Store(data) + require.NoError(t, err) + + d, err := jsonstore.Load() + require.NoError(t, err) + + require.Equal(t, data, d) +} + +func TestLoad(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + Filepath: "./fixtures/v4_empty.json", + }) + require.NoError(t, err) + + _, err = store.Load() + require.NoError(t, err) +} + +func TestLoadFailed(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + Filepath: "./fixtures/v4_invalid.json", + }) + require.NoError(t, err) + + _, err = store.Load() + require.Error(t, err) +} + +func TestIsEmpty(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + Filepath: "./fixtures/v4_empty.json", + }) + require.NoError(t, err) + + data, err := store.Load() + 
require.NoError(t, err) + require.Equal(t, true, len(data.Process) == 0) +} + +func TestNotExists(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + Filepath: "./fixtures/v4_notexist.json", + }) + require.NoError(t, err) + + data, err := store.Load() + require.NoError(t, err) + require.Equal(t, true, len(data.Process) == 0) +} + +func TestInvalidVersion(t *testing.T) { + store, err := New(Config{ + Filesystem: getFS(t), + Filepath: "./fixtures/v3_empty.json", + }) + require.NoError(t, err) + + data, err := store.Load() + require.Error(t, err) + require.Equal(t, true, len(data.Process) == 0) +} diff --git a/restream/store/json_test.go b/restream/store/json_test.go deleted file mode 100644 index 8b2c4698..00000000 --- a/restream/store/json_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package store - -import ( - "testing" - - "github.com/datarhei/core/v16/io/fs" - "github.com/stretchr/testify/require" -) - -func getFS(t *testing.T) fs.Filesystem { - fs, err := fs.NewRootedDiskFilesystem(fs.RootedDiskConfig{ - Root: ".", - }) - require.NoError(t, err) - - info, err := fs.Stat("./fixtures/v4_empty.json") - require.NoError(t, err) - require.Equal(t, "/fixtures/v4_empty.json", info.Name()) - - return fs -} - -func TestNew(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - }) - require.NoError(t, err) - require.NotEmpty(t, store) -} - -func TestLoad(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - Filepath: "./fixtures/v4_empty.json", - }) - require.NoError(t, err) - - _, err = store.Load() - require.NoError(t, err) -} - -func TestLoadFailed(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - Filepath: "./fixtures/v4_invalid.json", - }) - require.NoError(t, err) - - _, err = store.Load() - require.Error(t, err) -} - -func TestIsEmpty(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - Filepath: "./fixtures/v4_empty.json", - }) - require.NoError(t, err) 
- - data, err := store.Load() - require.NoError(t, err) - require.Equal(t, true, data.IsEmpty()) -} - -func TestNotExists(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - Filepath: "./fixtures/v4_notexist.json", - }) - require.NoError(t, err) - - data, err := store.Load() - require.NoError(t, err) - require.Equal(t, true, data.IsEmpty()) -} - -func TestStore(t *testing.T) { - fs := getFS(t) - fs.Remove("./fixtures/v4_store.json") - - store, err := NewJSON(JSONConfig{ - Filesystem: fs, - Filepath: "./fixtures/v4_store.json", - }) - require.NoError(t, err) - - data, err := store.Load() - require.NoError(t, err) - require.Equal(t, true, data.IsEmpty()) - - data.Metadata.System["somedata"] = "foobar" - - store.Store(data) - - data2, err := store.Load() - require.NoError(t, err) - require.Equal(t, data, data2) - - fs.Remove("./fixtures/v4_store.json") -} - -func TestInvalidVersion(t *testing.T) { - store, err := NewJSON(JSONConfig{ - Filesystem: getFS(t), - Filepath: "./fixtures/v3_empty.json", - }) - require.NoError(t, err) - - data, err := store.Load() - require.Error(t, err) - require.Equal(t, true, data.IsEmpty()) -} diff --git a/restream/store/store.go b/restream/store/store.go index 22fed6f4..7a33a903 100644 --- a/restream/store/store.go +++ b/restream/store/store.go @@ -1,9 +1,42 @@ package store +import "github.com/datarhei/core/v16/restream/app" + +type Process struct { + Process *app.Process + Metadata map[string]interface{} +} + +type Data struct { + Process map[string]map[string]Process + Metadata map[string]interface{} +} + +func (d *Data) IsEmpty() bool { + if len(d.Process) != 0 { + return false + } + + if len(d.Metadata) != 0 { + return false + } + + return true +} + type Store interface { // Load data from the store - Load() (StoreData, error) + Load() (Data, error) // Save data to the store - Store(data StoreData) error + Store(Data) error +} + +func NewData() Data { + c := Data{ + Process: 
make(map[string]map[string]Process), + Metadata: make(map[string]interface{}), + } + + return c } diff --git a/rtmp/channel.go b/rtmp/channel.go index 26206e5a..666dca83 100644 --- a/rtmp/channel.go +++ b/rtmp/channel.go @@ -93,7 +93,7 @@ type channel struct { isProxy bool } -func newChannel(conn connection, playPath string, reference string, remote net.Addr, streams []av.CodecData, isProxy bool, collector session.Collector) *channel { +func newChannel(conn connection, playPath, reference string, remote net.Addr, streams []av.CodecData, isProxy bool, collector session.Collector) *channel { ch := &channel{ path: playPath, reference: reference, diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index ea56e49c..38dd276d 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -12,6 +12,7 @@ import ( "time" "github.com/datarhei/core/v16/cluster/proxy" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" "github.com/datarhei/core/v16/session" @@ -59,6 +60,8 @@ type Config struct { TLSConfig *tls.Config Proxy proxy.ProxyReader + + IAM iam.IAM } // Server represents a RTMP server @@ -94,6 +97,8 @@ type server struct { lock sync.RWMutex proxy proxy.ProxyReader + + iam iam.IAM } // New creates a new RTMP server according to the given config @@ -107,11 +112,12 @@ func New(config Config) (Server, error) { } s := &server{ - app: config.App, + app: filepath.Join("/", config.App), token: config.Token, logger: config.Logger, collector: config.Collector, proxy: config.Proxy, + iam: config.IAM, } if s.collector == nil { @@ -189,81 +195,97 @@ func (s *server) Channels() []string { return channels } -func (s *server) log(who, action, path, message string, client net.Addr) { +func (s *server) log(who, what, action, path, message string, client net.Addr) { s.logger.Info().WithFields(log.Fields{ "who": who, + "what": what, "action": action, "path": path, "client": client.String(), }).Log(message) } -// getToken returns the path and the token found in the URL. 
If the token -// was part of the path, the token is removed from the path. The token in -// the query string takes precedence. The token in the path is assumed to -// be the last path element. -func getToken(u *url.URL) (string, string) { +// GetToken returns the path without the token and the token found in the URL. If the token +// was part of the path, the token is removed from the path. The token in the query string +// takes precedence. The token in the path is assumed to be the last path element. +func GetToken(u *url.URL) (string, string) { q := u.Query() - token := q.Get("token") - - if len(token) != 0 { + if q.Has("token") { // The token was in the query. Return the unmomdified path and the token - return u.Path, token + return u.Path, q.Get("token") } - pathElements := strings.Split(u.EscapedPath(), "/") + pathElements := splitPath(u.EscapedPath()) nPathElements := len(pathElements) - if nPathElements == 0 { + if nPathElements <= 1 { return u.Path, "" } // Return the path without the token - return strings.Join(pathElements[:nPathElements-1], "/"), pathElements[nPathElements-1] + return "/" + strings.Join(pathElements[:nPathElements-1], "/"), pathElements[nPathElements-1] +} + +func splitPath(path string) []string { + pathElements := strings.Split(filepath.Clean(path), "/") + + if len(pathElements) == 0 { + return pathElements + } + + if len(pathElements[0]) == 0 { + pathElements = pathElements[1:] + } + + return pathElements +} + +func removePathPrefix(path, prefix string) (string, string) { + prefix = filepath.Join("/", prefix) + return filepath.Join("/", strings.TrimPrefix(path, prefix+"/")), prefix } // handlePlay is called when a RTMP client wants to play a stream func (s *server) handlePlay(conn *rtmp.Conn) { - client := conn.NetConn().RemoteAddr() - defer conn.Close() - playPath := conn.URL.Path + remote := conn.NetConn().RemoteAddr() + playpath, token := GetToken(conn.URL) - // Check the token in the URL if one is required - if len(s.token) != 0 { 
- path, token := getToken(conn.URL) + playpath, _ = removePathPrefix(playpath, s.app) - if len(token) == 0 { - s.log("PLAY", "FORBIDDEN", path, "no streamkey provided", client) - return - } + identity, err := s.findIdentityFromStreamKey(token) + if err != nil { + s.logger.Debug().WithError(err).Log("invalid streamkey") + s.log(identity, "PLAY", "FORBIDDEN", playpath, "invalid streamkey ("+token+")", remote) + return + } - if s.token != token { - s.log("PLAY", "FORBIDDEN", path, "invalid streamkey ("+token+")", client) - return - } + domain := s.findDomainFromPlaypath(playpath) + resource := "rtmp:" + playpath - playPath = path + if !s.iam.Enforce(identity, domain, resource, "PLAY") { + s.log(identity, "PLAY", "FORBIDDEN", playpath, "access denied", remote) + return } // Look for the stream s.lock.RLock() - ch := s.channels[playPath] + ch := s.channels[playpath] s.lock.RUnlock() if ch == nil { // Check in the cluster for that stream url, err := s.proxy.GetURL("rtmp:" + conn.URL.Path) if err != nil { - s.log("PLAY", "NOTFOUND", conn.URL.Path, "", client) + s.log(identity, "PLAY", "NOTFOUND", conn.URL.Path, "", remote) return } src, err := avutil.Open(url) if err != nil { s.logger.Error().WithField("address", url).WithError(err).Log("Proxying address failed") - s.log("PLAY", "NOTFOUND", conn.URL.Path, "", client) + s.log(identity, "PLAY", "NOTFOUND", conn.URL.Path, "", remote) return } @@ -273,13 +295,13 @@ func (s *server) handlePlay(conn *rtmp.Conn) { wg.Add(1) go func() { - s.log("PLAY", "PROXYSTART", url, "", client) + s.log(identity, "PLAY", "PROXYSTART", url, "", remote) wg.Done() - err := s.publish(c, playPath, client, true) + err := s.publish(c, playpath, remote, identity, true) if err != nil { s.logger.Error().WithField("address", url).WithError(err).Log("Proxying address failed") } - s.log("PLAY", "PROXYSTOP", url, "", client) + s.log(identity, "PLAY", "PROXYSTOP", url, "", remote) }() // Wait for the goroutine to start @@ -310,7 +332,7 @@ func (s *server) 
handlePlay(conn *rtmp.Conn) { // Send the metadata to the client conn.WriteHeader(ch.streams) - s.log("PLAY", "START", playPath, "", client) + s.log(identity, "PLAY", "START", conn.URL.Path, "", remote) // Get a cursor and apply filters cursor := ch.queue.Oldest() @@ -337,64 +359,65 @@ func (s *server) handlePlay(conn *rtmp.Conn) { ch.RemoveSubscriber(id) - s.log("PLAY", "STOP", playPath, "", client) + s.log(identity, "PLAY", "STOP", playpath, "", remote) } else { - s.log("PLAY", "NOTFOUND", playPath, "", client) + s.log(identity, "PLAY", "NOTFOUND", playpath, "", remote) } } // handlePublish is called when a RTMP client wants to publish a stream func (s *server) handlePublish(conn *rtmp.Conn) { - client := conn.NetConn().RemoteAddr() - defer conn.Close() - playPath := conn.URL.Path + remote := conn.NetConn().RemoteAddr() + playpath, token := GetToken(conn.URL) - if len(s.token) != 0 { - path, token := getToken(conn.URL) + playpath, app := removePathPrefix(playpath, s.app) - if len(token) == 0 { - s.log("PLAY", "FORBIDDEN", path, "no streamkey provided", client) - return - } - - if s.token != token { - s.log("PLAY", "FORBIDDEN", path, "invalid streamkey ("+token+")", client) - return - } - - playPath = path - } - - // Check the app patch - if !strings.HasPrefix(playPath, s.app) { - s.log("PUBLISH", "FORBIDDEN", playPath, "invalid app", client) + identity, err := s.findIdentityFromStreamKey(token) + if err != nil { + s.logger.Debug().WithError(err).Log("invalid streamkey") + s.log(identity, "PUBLISH", "FORBIDDEN", playpath, "invalid streamkey ("+token+")", remote) return } - err := s.publish(conn, playPath, client, false) + // Check the app patch + if app != s.app { + s.log(identity, "PUBLISH", "FORBIDDEN", playpath, "invalid app", remote) + return + } + + domain := s.findDomainFromPlaypath(playpath) + resource := "rtmp:" + playpath + + if !s.iam.Enforce(identity, domain, resource, "PUBLISH") { + s.log(identity, "PUBLISH", "FORBIDDEN", playpath, "access denied", 
remote) + return + } + + err = s.publish(conn, playpath, remote, identity, false) if err != nil { - s.logger.WithField("path", playPath).WithError(err).Log("") + s.logger.WithField("path", conn.URL.Path).WithError(err).Log("") } } -func (s *server) publish(src connection, playPath string, client net.Addr, isProxy bool) error { +func (s *server) publish(src connection, playpath string, remote net.Addr, identity string, isProxy bool) error { // Check the streams if it contains any valid/known streams streams, _ := src.Streams() if len(streams) == 0 { - s.log("PUBLISH", "INVALID", playPath, "no streams available", client) + s.log(identity, "PUBLISH", "INVALID", playpath, "no streams available", remote) + return fmt.Errorf("no streams are available") } s.lock.Lock() - ch := s.channels[playPath] + ch := s.channels[playpath] if ch == nil { - reference := strings.TrimPrefix(strings.TrimSuffix(playPath, filepath.Ext(playPath)), s.app+"/") + reference := strings.TrimPrefix(strings.TrimSuffix(playpath, filepath.Ext(playpath)), s.app+"/") // Create a new channel - ch = newChannel(src, playPath, reference, client, streams, isProxy, s.collector) + ch = newChannel(src, playpath, reference, remote, streams, isProxy, s.collector) for _, stream := range streams { typ := stream.Type() @@ -407,7 +430,7 @@ func (s *server) publish(src connection, playPath string, client net.Addr, isPro } } - s.channels[playPath] = ch + s.channels[playpath] = ch } else { ch = nil } @@ -415,26 +438,75 @@ func (s *server) publish(src connection, playPath string, client net.Addr, isPro s.lock.Unlock() if ch == nil { - s.log("PUBLISH", "CONFLICT", playPath, "already publishing", client) + s.log(identity, "PUBLISH", "CONFLICT", playpath, "already publishing", remote) return fmt.Errorf("already publishing") } - s.log("PUBLISH", "START", playPath, "", client) + s.log(identity, "PUBLISH", "START", playpath, "", remote) for _, stream := range streams { - s.log("PUBLISH", "STREAM", playPath, 
stream.Type().String(), client) + s.log(identity, "PUBLISH", "STREAM", playpath, stream.Type().String(), remote) } // Ingest the data, blocks until done avutil.CopyPackets(ch.queue, src) s.lock.Lock() - delete(s.channels, playPath) + delete(s.channels, playpath) s.lock.Unlock() ch.Close() - s.log("PUBLISH", "STOP", playPath, "", client) + s.log(identity, "PUBLISH", "STOP", playpath, "", remote) return nil } + +func (s *server) findIdentityFromStreamKey(key string) (string, error) { + if len(key) == 0 { + return "$anon", nil + } + + var identity iam.IdentityVerifier + var err error + + var token string + + elements := strings.Split(key, ":") + if len(elements) == 1 { + identity = s.iam.GetDefaultVerifier() + token = elements[0] + } else { + identity, err = s.iam.GetVerifier(elements[0]) + token = elements[1] + } + + if err != nil { + return "$anon", nil + } + + if ok, err := identity.VerifyServiceToken(token); !ok { + return "$anon", fmt.Errorf("invalid token: %w", err) + } + + return identity.Name(), nil +} + +// findDomainFromPlaypath finds the domain in the path. The domain is +// the first path element. If there's only one path element, it is not +// considered the domain. It is assumed that the app is not part of +// the provided path. 
+func (s *server) findDomainFromPlaypath(path string) string { + elements := splitPath(path) + if len(elements) == 1 { + return "$none" + } + + domain := elements[0] + + if s.iam.HasDomain(domain) { + return domain + } + + return "$none" +} diff --git a/rtmp/rtmp_test.go b/rtmp/rtmp_test.go index 20bb5274..ec372690 100644 --- a/rtmp/rtmp_test.go +++ b/rtmp/rtmp_test.go @@ -18,9 +18,37 @@ func TestToken(t *testing.T) { u, err := url.Parse(d[0]) require.NoError(t, err) - path, token := getToken(u) + path, token := GetToken(u) require.Equal(t, d[1], path, "url=%s", u.String()) require.Equal(t, d[2], token, "url=%s", u.String()) } } + +func TestSplitPath(t *testing.T) { + data := map[string][]string{ + "/foo/bar": {"foo", "bar"}, + "foo/bar": {"foo", "bar"}, + "/foo/bar/": {"foo", "bar"}, + } + + for path, split := range data { + elms := splitPath(path) + + require.ElementsMatch(t, split, elms, "%s", path) + } +} + +func TestRemovePathPrefix(t *testing.T) { + data := [][]string{ + {"/foo/bar", "/foo", "/bar"}, + {"/foo/bar", "/fo", "/foo/bar"}, + {"/foo/bar/abc", "/foo/bar", "/abc"}, + } + + for _, d := range data { + x, _ := removePathPrefix(d[0], d[1]) + + require.Equal(t, d[2], x, "path=%s prefix=%s", d[0], d[1]) + } +} diff --git a/srt/channel.go b/srt/channel.go index 801d4f90..ee2fbad9 100644 --- a/srt/channel.go +++ b/srt/channel.go @@ -44,12 +44,13 @@ func (c *client) ticker(ctx context.Context) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() + stats := &srt.Statistics{} + for { select { case <-ctx.Done(): return case <-ticker.C: - stats := &srt.Statistics{} c.conn.Stats(stats) rxbytes := stats.Accumulated.ByteRecv diff --git a/srt/srt.go b/srt/srt.go index 208b9b29..77dbd5e9 100644 --- a/srt/srt.go +++ b/srt/srt.go @@ -5,14 +5,15 @@ import ( "context" "fmt" "net" - "regexp" "strings" "sync" "time" "github.com/datarhei/core/v16/cluster/proxy" + "github.com/datarhei/core/v16/iam" "github.com/datarhei/core/v16/log" 
"github.com/datarhei/core/v16/session" + "github.com/datarhei/core/v16/srt/url" srt "github.com/datarhei/gosrt" ) @@ -41,6 +42,8 @@ type Config struct { SRTLogTopics []string Proxy proxy.ProxyReader + + IAM iam.IAM } // Server represents a SRT server @@ -78,6 +81,8 @@ type server struct { srtlogLock sync.RWMutex proxy proxy.ProxyReader + + iam iam.IAM } func New(config Config) (Server, error) { @@ -86,6 +91,7 @@ func New(config Config) (Server, error) { token: config.Token, passphrase: config.Passphrase, collector: config.Collector, + iam: config.IAM, logger: config.Logger, proxy: config.Proxy, } @@ -270,177 +276,133 @@ func (s *server) log(handler, action, resource, message string, client net.Addr) }).Log(message) } -type streamInfo struct { - mode string - resource string - token string -} - -// parseStreamId parses a streamid of the form "#!:key=value,key=value,..." and -// returns a streamInfo. In case the stream couldn't be parsed, an error is returned. -func parseStreamId(streamid string) (streamInfo, error) { - si := streamInfo{} - - if strings.HasPrefix(streamid, "#!:") { - return parseOldStreamId(streamid) - } - - re := regexp.MustCompile(`,(token|mode):(.+)`) - - results := map[string]string{} - - idEnd := -1 - value := streamid - key := "" - - for { - matches := re.FindStringSubmatchIndex(value) - if matches == nil { - break - } - - if idEnd < 0 { - idEnd = matches[2] - 1 - } - - if len(key) != 0 { - results[key] = value[:matches[2]-1] - } - - key = value[matches[2]:matches[3]] - value = value[matches[4]:matches[5]] - - results[key] = value - } - - if idEnd < 0 { - idEnd = len(streamid) - } - - si.resource = streamid[:idEnd] - if token, ok := results["token"]; ok { - si.token = token - } - - if mode, ok := results["mode"]; ok { - si.mode = mode - } else { - si.mode = "request" - } - - return si, nil -} - -func parseOldStreamId(streamid string) (streamInfo, error) { - si := streamInfo{} - - if !strings.HasPrefix(streamid, "#!:") { - return si, 
fmt.Errorf("unknown streamid format") - } - - streamid = strings.TrimPrefix(streamid, "#!:") - - kvs := strings.Split(streamid, ",") - - splitFn := func(s, sep string) (string, string, error) { - splitted := strings.SplitN(s, sep, 2) - - if len(splitted) != 2 { - return "", "", fmt.Errorf("invalid key/value pair") - } - - return splitted[0], splitted[1], nil - } - - for _, kv := range kvs { - key, value, err := splitFn(kv, "=") - if err != nil { - continue - } - - switch key { - case "m": - si.mode = value - case "r": - si.resource = value - case "token": - si.token = value - default: - } - } - - return si, nil -} - func (s *server) handleConnect(req srt.ConnRequest) srt.ConnType { mode := srt.REJECT client := req.RemoteAddr() streamId := req.StreamId() - si, err := parseStreamId(streamId) + si, err := url.ParseStreamId(streamId) if err != nil { s.log("CONNECT", "INVALID", "", err.Error(), client) return srt.REJECT } - if len(si.resource) == 0 { + if len(si.Resource) == 0 { s.log("CONNECT", "INVALID", "", "stream resource not provided", client) return srt.REJECT } - if si.mode == "publish" { + if si.Mode == "publish" { mode = srt.PUBLISH - } else if si.mode == "request" { + } else if si.Mode == "request" { mode = srt.SUBSCRIBE } else { - s.log("CONNECT", "INVALID", si.resource, "invalid connection mode", client) + s.log("CONNECT", "INVALID", si.Resource, "invalid connection mode", client) return srt.REJECT } if len(s.passphrase) != 0 { if !req.IsEncrypted() { - s.log("CONNECT", "FORBIDDEN", si.resource, "connection has to be encrypted", client) + s.log("CONNECT", "FORBIDDEN", si.Resource, "connection has to be encrypted", client) return srt.REJECT } if err := req.SetPassphrase(s.passphrase); err != nil { - s.log("CONNECT", "FORBIDDEN", si.resource, err.Error(), client) + s.log("CONNECT", "FORBIDDEN", si.Resource, err.Error(), client) return srt.REJECT } } else { if req.IsEncrypted() { - s.log("CONNECT", "INVALID", si.resource, "connection must not be encrypted", 
client) + s.log("CONNECT", "INVALID", si.Resource, "connection must not be encrypted", client) return srt.REJECT } } - // Check the token - if len(s.token) != 0 && s.token != si.token { - s.log("CONNECT", "FORBIDDEN", si.resource, "invalid token ("+si.token+")", client) + identity, err := s.findIdentityFromToken(si.Token) + if err != nil { + s.logger.Debug().WithError(err).Log("invalid token") + s.log("PUBLISH", "FORBIDDEN", si.Resource, "invalid token", client) + return srt.REJECT + } + + domain := s.findDomainFromPlaypath(si.Resource) + resource := "srt:" + si.Resource + action := "PLAY" + if mode == srt.PUBLISH { + action = "PUBLISH" + } + + if !s.iam.Enforce(identity, domain, resource, action) { + s.log("PUBLISH", "FORBIDDEN", si.Resource, "access denied", client) return srt.REJECT } return mode } +func (s *server) handlePublish(conn srt.Conn) { + s.publish(conn, false) +} + +func (s *server) publish(conn srt.Conn, isProxy bool) error { + streamId := conn.StreamId() + client := conn.RemoteAddr() + + si, _ := url.ParseStreamId(streamId) + + // Look for the stream + s.lock.Lock() + ch := s.channels[si.Resource] + if ch == nil { + ch = newChannel(conn, si.Resource, isProxy, s.collector) + s.channels[si.Resource] = ch + } else { + ch = nil + } + s.lock.Unlock() + + if ch == nil { + s.log("PUBLISH", "CONFLICT", si.Resource, "already publishing", client) + conn.Close() + return fmt.Errorf("already publishing this resource") + } + + s.log("PUBLISH", "START", si.Resource, "", client) + + // Blocks until connection closes + ch.pubsub.Publish(conn) + + s.lock.Lock() + delete(s.channels, si.Resource) + s.lock.Unlock() + + ch.Close() + + s.log("PUBLISH", "STOP", si.Resource, "", client) + + conn.Close() + + return nil +} + func (s *server) handleSubscribe(conn srt.Conn) { defer conn.Close() streamId := conn.StreamId() client := conn.RemoteAddr() - si, _ := parseStreamId(streamId) + si, _ := url.ParseStreamId(streamId) // Look for the stream locally s.lock.RLock() - ch := 
s.channels[si.resource] + ch := s.channels[si.Resource] s.lock.RUnlock() if ch == nil { // Check in the cluster for the stream and proxy it - srturl, err := s.proxy.GetURL("srt:" + si.resource) + srturl, err := s.proxy.GetURL("srt:" + si.Resource) if err != nil { - s.log("SUBSCRIBE", "NOTFOUND", si.resource, "no publisher for this resource found", client) + s.log("SUBSCRIBE", "NOTFOUND", si.Resource, "no publisher for this resource found", client) return } @@ -449,13 +411,13 @@ func (s *server) handleSubscribe(conn srt.Conn) { host, err := config.UnmarshalURL(srturl) if err != nil { s.logger.Error().WithField("address", srturl).WithError(err).Log("Parsing proxy address failed") - s.log("SUBSCRIBE", "NOTFOUND", si.resource, "no publisher for this resource found", client) + s.log("SUBSCRIBE", "NOTFOUND", si.Resource, "no publisher for this resource found", client) return } src, err := srt.Dial("srt", host, config) if err != nil { s.logger.Error().WithField("address", srturl).WithError(err).Log("Proxying address failed") - s.log("SUBSCRIBE", "NOTFOUND", si.resource, "no publisher for this resource found", client) + s.log("SUBSCRIBE", "NOTFOUND", si.Resource, "no publisher for this resource found", client) return } @@ -481,7 +443,7 @@ func (s *server) handleSubscribe(conn srt.Conn) { for range ticker.C { s.lock.RLock() - ch = s.channels[si.resource] + ch = s.channels[si.Resource] s.lock.RUnlock() if ch != nil { @@ -497,62 +459,60 @@ func (s *server) handleSubscribe(conn srt.Conn) { } if ch != nil { - s.log("SUBSCRIBE", "START", si.resource, "", client) + s.log("SUBSCRIBE", "START", si.Resource, "", client) - id := ch.AddSubscriber(conn, si.resource) + id := ch.AddSubscriber(conn, si.Resource) // Blocks until connection closes ch.pubsub.Subscribe(conn) - s.log("SUBSCRIBE", "STOP", si.resource, "", client) + s.log("SUBSCRIBE", "STOP", si.Resource, "", client) ch.RemoveSubscriber(id) - - return } } -func (s *server) handlePublish(conn srt.Conn) { - s.publish(conn, false) 
-} +func (s *server) findIdentityFromToken(key string) (string, error) { + if len(key) == 0 { + return "$anon", nil + } -func (s *server) publish(conn srt.Conn, isProxy bool) error { - streamId := conn.StreamId() - client := conn.RemoteAddr() + var identity iam.IdentityVerifier + var err error - si, _ := parseStreamId(streamId) + var token string - // Look for the stream - s.lock.Lock() - ch := s.channels[si.resource] - if ch == nil { - ch = newChannel(conn, si.resource, isProxy, s.collector) - s.channels[si.resource] = ch + elements := strings.Split(key, ":") + if len(elements) == 1 { + identity = s.iam.GetDefaultVerifier() + token = elements[0] } else { - ch = nil - } - s.lock.Unlock() - - if ch == nil { - s.log("PUBLISH", "CONFLICT", si.resource, "already publishing", client) - conn.Close() - return fmt.Errorf("already publishing this resource") + identity, err = s.iam.GetVerifier(elements[0]) + token = elements[1] } - s.log("PUBLISH", "START", si.resource, "", client) + if err != nil { + return "$anon", nil + } - // Blocks until connection closes - ch.pubsub.Publish(conn) + if ok, err := identity.VerifyServiceToken(token); !ok { + return "$anon", fmt.Errorf("invalid token: %w", err) + } - s.lock.Lock() - delete(s.channels, si.resource) - s.lock.Unlock() - - ch.Close() - - s.log("PUBLISH", "STOP", si.resource, "", client) - - conn.Close() - - return nil + return identity.Name(), nil +} + +func (s *server) findDomainFromPlaypath(path string) string { + elements := strings.Split(path, "/") + if len(elements) == 1 { + return "$none" + } + + domain := elements[0] + + if s.iam.HasDomain(domain) { + return domain + } + + return "$none" } diff --git a/srt/srt_test.go b/srt/srt_test.go deleted file mode 100644 index 91ae7ed1..00000000 --- a/srt/srt_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package srt - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseStreamId(t *testing.T) { - streamids := map[string]streamInfo{ - "bla": {resource: 
"bla", mode: "request"}, - "bla,mode:publish": {resource: "bla", mode: "publish"}, - "123456789": {resource: "123456789", mode: "request"}, - "bla,token:foobar": {resource: "bla", token: "foobar", mode: "request"}, - "bla,token:foo,bar": {resource: "bla", token: "foo,bar", mode: "request"}, - "123456789,mode:publish,token:foobar": {resource: "123456789", token: "foobar", mode: "publish"}, - "mode:publish": {resource: "mode:publish", mode: "request"}, - } - - for streamid, wantsi := range streamids { - si, err := parseStreamId(streamid) - - require.NoError(t, err) - require.Equal(t, wantsi, si) - } -} - -func TestParseOldStreamId(t *testing.T) { - streamids := map[string]streamInfo{ - "#!:": {}, - "#!:key=value": {}, - "#!:m=publish": {mode: "publish"}, - "#!:r=123456789": {resource: "123456789"}, - "#!:token=foobar": {token: "foobar"}, - "#!:token=foo,bar": {token: "foo"}, - "#!:m=publish,r=123456789,token=foobar": {mode: "publish", resource: "123456789", token: "foobar"}, - } - - for streamid, wantsi := range streamids { - si, _ := parseOldStreamId(streamid) - - require.Equal(t, wantsi, si) - } -} diff --git a/srt/url/url.go b/srt/url/url.go new file mode 100644 index 00000000..0ec890b1 --- /dev/null +++ b/srt/url/url.go @@ -0,0 +1,206 @@ +package url + +import ( + "fmt" + neturl "net/url" + "regexp" + "strings" +) + +type URL struct { + Scheme string + Host string + StreamId string + Options neturl.Values +} + +func Parse(srturl string) (*URL, error) { + u, err := neturl.Parse(srturl) + if err != nil { + return nil, err + } + + if u.Scheme != "srt" { + return nil, fmt.Errorf("invalid SRT url") + } + + options := u.Query() + streamid := options.Get("streamid") + options.Del("streamid") + + su := &URL{ + Scheme: "srt", + Host: u.Host, + StreamId: streamid, + Options: options, + } + + return su, nil +} + +func (su *URL) String() string { + options, _ := neturl.ParseQuery(su.Options.Encode()) + options.Set("streamid", su.StreamId) + + u := neturl.URL{ + Scheme: 
su.Scheme, + Host: su.Host, + RawQuery: options.Encode(), + } + + return u.String() +} + +func (su *URL) StreamInfo() (*StreamInfo, error) { + s, err := ParseStreamId(su.StreamId) + if err != nil { + return nil, err + } + + return &s, nil +} + +func (su *URL) SetStreamInfo(si *StreamInfo) { + su.StreamId = si.String() +} + +func (su *URL) Hostname() string { + u := neturl.URL{ + Host: su.Host, + } + + return u.Hostname() +} + +func (su *URL) Port() string { + u := neturl.URL{ + Host: su.Host, + } + + return u.Port() +} + +type StreamInfo struct { + Mode string + Resource string + Token string +} + +func (si *StreamInfo) String() string { + streamid := si.Resource + + if si.Mode != "request" { + streamid += ",mode:" + si.Mode + } + + if len(si.Token) != 0 { + streamid += ",token:" + si.Token + } + + return streamid +} + +// ParseStreamId parses a streamid. If the streamid is in the old format +// it is detected and parsed accordingly. Otherwith the new simplified +// format will be assumed. +// +// resource[,token:{token}]?[,mode:(publish|*request)]? +// +// If the mode is not provided, "request" will be assumed. 
+func ParseStreamId(streamid string) (StreamInfo, error) { + si := StreamInfo{} + + if strings.HasPrefix(streamid, "#!:") { + return ParseDeprecatedStreamId(streamid) + } + + re := regexp.MustCompile(`,(token|mode):(.+)`) + + results := map[string]string{} + + idEnd := -1 + value := streamid + key := "" + + for { + matches := re.FindStringSubmatchIndex(value) + if matches == nil { + break + } + + if idEnd < 0 { + idEnd = matches[2] - 1 + } + + if len(key) != 0 { + results[key] = value[:matches[2]-1] + } + + key = value[matches[2]:matches[3]] + value = value[matches[4]:matches[5]] + + results[key] = value + } + + if idEnd < 0 { + idEnd = len(streamid) + } + + si.Resource = streamid[:idEnd] + if token, ok := results["token"]; ok { + si.Token = token + } + + if mode, ok := results["mode"]; ok { + si.Mode = mode + } else { + si.Mode = "request" + } + + return si, nil +} + +// ParseDeprecatedStreamId parses a streamid in the old format. The old format +// is based on the recommendation of the SRT specs, but with the special +// character it contains it can cause some trouble in clients (e.g. kiloview +// doesn't like the = character). 
+func ParseDeprecatedStreamId(streamid string) (StreamInfo, error) { + si := StreamInfo{Mode: "request"} + + if !strings.HasPrefix(streamid, "#!:") { + return si, fmt.Errorf("unknown streamid format") + } + + streamid = strings.TrimPrefix(streamid, "#!:") + + kvs := strings.Split(streamid, ",") + + split := func(s, sep string) (string, string, error) { + splitted := strings.SplitN(s, sep, 2) + + if len(splitted) != 2 { + return "", "", fmt.Errorf("invalid key/value pair") + } + + return splitted[0], splitted[1], nil + } + + for _, kv := range kvs { + key, value, err := split(kv, "=") + if err != nil { + continue + } + + switch key { + case "m": + si.Mode = value + case "r": + si.Resource = value + case "token": + si.Token = value + default: + } + } + + return si, nil +} diff --git a/srt/url/url_test.go b/srt/url/url_test.go new file mode 100644 index 00000000..4118b638 --- /dev/null +++ b/srt/url/url_test.go @@ -0,0 +1,67 @@ +package url + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + srturl := "srt://127.0.0.1:6000?mode=caller&passphrase=foobar&streamid=" + url.QueryEscape("#!:m=publish,r=123456,token=bla") + + u, err := Parse(srturl) + require.NoError(t, err) + + require.Equal(t, "srt", u.Scheme) + require.Equal(t, "127.0.0.1:6000", u.Host) + require.Equal(t, "#!:m=publish,r=123456,token=bla", u.StreamId) + + si, err := u.StreamInfo() + require.NoError(t, err) + require.Equal(t, "publish", si.Mode) + require.Equal(t, "123456", si.Resource) + require.Equal(t, "bla", si.Token) + + require.Equal(t, srturl, u.String()) + + srturl = "srt://127.0.0.1:6000?mode=caller&passphrase=foobar&streamid=" + url.QueryEscape("123456,mode:publish,token:bla") + + u, err = Parse(srturl) + require.NoError(t, err) + + require.Equal(t, "srt", u.Scheme) + require.Equal(t, "127.0.0.1:6000", u.Host) + require.Equal(t, "123456,mode:publish,token:bla", u.StreamId) + + si, err = u.StreamInfo() + require.NoError(t, err) + 
require.Equal(t, "publish", si.Mode) + require.Equal(t, "123456", si.Resource) + require.Equal(t, "bla", si.Token) + + require.Equal(t, srturl, u.String()) +} + +func TestParseStreamId(t *testing.T) { + streamids := map[string]StreamInfo{ + "": {Mode: "request"}, + "bla": {Mode: "request", Resource: "bla"}, + "bla,token=foobar": {Mode: "request", Resource: "bla,token=foobar"}, + "bla,token:foobar": {Mode: "request", Resource: "bla", Token: "foobar"}, + "bla,token:foobar,mode:publish": {Mode: "publish", Resource: "bla", Token: "foobar"}, + "#!:": {Mode: "request"}, + "#!:key=value": {Mode: "request"}, + "#!:m=publish": {Mode: "publish"}, + "#!:r=123456789": {Mode: "request", Resource: "123456789"}, + "#!:token=foobar": {Mode: "request", Token: "foobar"}, + "#!:token=foo,bar": {Mode: "request", Token: "foo"}, + "#!:m=publish,r=123456789,token=foobar": {Mode: "publish", Resource: "123456789", Token: "foobar"}, + } + + for streamid, wantsi := range streamids { + si, err := ParseStreamId(streamid) + require.NoError(t, err) + require.Equal(t, wantsi, si, streamid) + } +} diff --git a/vendor/github.com/Knetic/govaluate/.gitignore b/vendor/github.com/Knetic/govaluate/.gitignore new file mode 100644 index 00000000..5ac0c3fc --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/.gitignore @@ -0,0 +1,28 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +coverage.out + +manual_test.go +*.out +*.err diff --git a/vendor/github.com/Knetic/govaluate/.travis.yml b/vendor/github.com/Knetic/govaluate/.travis.yml new file mode 100644 index 00000000..f6c430f1 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/.travis.yml @@ -0,0 +1,10 @@ +language: go + +script: ./test.sh + +go: + - 1.2 + - 1.3 + - 1.4 + - 1.5 + - 1.6 diff --git 
a/vendor/github.com/Knetic/govaluate/CONTRIBUTORS b/vendor/github.com/Knetic/govaluate/CONTRIBUTORS new file mode 100644 index 00000000..c1a7fe42 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/CONTRIBUTORS @@ -0,0 +1,15 @@ +This library was authored by George Lester, and contains contributions from: + +vjeantet (regex support) +iasci (ternary operator) +oxtoacart (parameter structures, deferred parameter retrieval) +wmiller848 (bitwise operators) +prashantv (optimization of bools) +dpaolella (exposure of variables used in an expression) +benpaxton (fix for missing type checks during literal elide process) +abrander (panic-finding testing tool, float32 conversions) +xfennec (fix for dates being parsed in the current Location) +bgaifullin (lifting restriction on complex/struct types) +gautambt (hexadecimal literals) +felixonmars (fix multiple typos in test names) +sambonfire (automatic type conversion for accessor function calls) \ No newline at end of file diff --git a/vendor/github.com/Knetic/govaluate/EvaluableExpression.go b/vendor/github.com/Knetic/govaluate/EvaluableExpression.go new file mode 100644 index 00000000..a5fe50d4 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/EvaluableExpression.go @@ -0,0 +1,276 @@ +package govaluate + +import ( + "errors" + "fmt" +) + +const isoDateFormat string = "2006-01-02T15:04:05.999999999Z0700" +const shortCircuitHolder int = -1 + +var DUMMY_PARAMETERS = MapParameters(map[string]interface{}{}) + +/* + EvaluableExpression represents a set of ExpressionTokens which, taken together, + are an expression that can be evaluated down into a single value. +*/ +type EvaluableExpression struct { + + /* + Represents the query format used to output dates. Typically only used when creating SQL or Mongo queries from an expression. + Defaults to the complete ISO8601 format, including nanoseconds. + */ + QueryDateFormat string + + /* + Whether or not to safely check types when evaluating. 
+ If true, this library will return error messages when invalid types are used. + If false, the library will panic when operators encounter types they can't use. + + This is exclusively for users who need to squeeze every ounce of speed out of the library as they can, + and you should only set this to false if you know exactly what you're doing. + */ + ChecksTypes bool + + tokens []ExpressionToken + evaluationStages *evaluationStage + inputExpression string +} + +/* + Parses a new EvaluableExpression from the given [expression] string. + Returns an error if the given expression has invalid syntax. +*/ +func NewEvaluableExpression(expression string) (*EvaluableExpression, error) { + + functions := make(map[string]ExpressionFunction) + return NewEvaluableExpressionWithFunctions(expression, functions) +} + +/* + Similar to [NewEvaluableExpression], except that instead of a string, an already-tokenized expression is given. + This is useful in cases where you may be generating an expression automatically, or using some other parser (e.g., to parse from a query language) +*/ +func NewEvaluableExpressionFromTokens(tokens []ExpressionToken) (*EvaluableExpression, error) { + + var ret *EvaluableExpression + var err error + + ret = new(EvaluableExpression) + ret.QueryDateFormat = isoDateFormat + + err = checkBalance(tokens) + if err != nil { + return nil, err + } + + err = checkExpressionSyntax(tokens) + if err != nil { + return nil, err + } + + ret.tokens, err = optimizeTokens(tokens) + if err != nil { + return nil, err + } + + ret.evaluationStages, err = planStages(ret.tokens) + if err != nil { + return nil, err + } + + ret.ChecksTypes = true + return ret, nil +} + +/* + Similar to [NewEvaluableExpression], except enables the use of user-defined functions. + Functions passed into this will be available to the expression. 
+*/ +func NewEvaluableExpressionWithFunctions(expression string, functions map[string]ExpressionFunction) (*EvaluableExpression, error) { + + var ret *EvaluableExpression + var err error + + ret = new(EvaluableExpression) + ret.QueryDateFormat = isoDateFormat + ret.inputExpression = expression + + ret.tokens, err = parseTokens(expression, functions) + if err != nil { + return nil, err + } + + err = checkBalance(ret.tokens) + if err != nil { + return nil, err + } + + err = checkExpressionSyntax(ret.tokens) + if err != nil { + return nil, err + } + + ret.tokens, err = optimizeTokens(ret.tokens) + if err != nil { + return nil, err + } + + ret.evaluationStages, err = planStages(ret.tokens) + if err != nil { + return nil, err + } + + ret.ChecksTypes = true + return ret, nil +} + +/* + Same as `Eval`, but automatically wraps a map of parameters into a `govalute.Parameters` structure. +*/ +func (this EvaluableExpression) Evaluate(parameters map[string]interface{}) (interface{}, error) { + + if parameters == nil { + return this.Eval(nil) + } + + return this.Eval(MapParameters(parameters)) +} + +/* + Runs the entire expression using the given [parameters]. + e.g., If the expression contains a reference to the variable "foo", it will be taken from `parameters.Get("foo")`. + + This function returns errors if the combination of expression and parameters cannot be run, + such as if a variable in the expression is not present in [parameters]. + + In all non-error circumstances, this returns the single value result of the expression and parameters given. + e.g., if the expression is "1 + 1", this will return 2.0. 
+ e.g., if the expression is "foo + 1" and parameters contains "foo" = 2, this will return 3.0 +*/ +func (this EvaluableExpression) Eval(parameters Parameters) (interface{}, error) { + + if this.evaluationStages == nil { + return nil, nil + } + + if parameters != nil { + parameters = &sanitizedParameters{parameters} + } else { + parameters = DUMMY_PARAMETERS + } + + return this.evaluateStage(this.evaluationStages, parameters) +} + +func (this EvaluableExpression) evaluateStage(stage *evaluationStage, parameters Parameters) (interface{}, error) { + + var left, right interface{} + var err error + + if stage.leftStage != nil { + left, err = this.evaluateStage(stage.leftStage, parameters) + if err != nil { + return nil, err + } + } + + if stage.isShortCircuitable() { + switch stage.symbol { + case AND: + if left == false { + return false, nil + } + case OR: + if left == true { + return true, nil + } + case COALESCE: + if left != nil { + return left, nil + } + + case TERNARY_TRUE: + if left == false { + right = shortCircuitHolder + } + case TERNARY_FALSE: + if left != nil { + right = shortCircuitHolder + } + } + } + + if right != shortCircuitHolder && stage.rightStage != nil { + right, err = this.evaluateStage(stage.rightStage, parameters) + if err != nil { + return nil, err + } + } + + if this.ChecksTypes { + if stage.typeCheck == nil { + + err = typeCheck(stage.leftTypeCheck, left, stage.symbol, stage.typeErrorFormat) + if err != nil { + return nil, err + } + + err = typeCheck(stage.rightTypeCheck, right, stage.symbol, stage.typeErrorFormat) + if err != nil { + return nil, err + } + } else { + // special case where the type check needs to know both sides to determine if the operator can handle it + if !stage.typeCheck(left, right) { + errorMsg := fmt.Sprintf(stage.typeErrorFormat, left, stage.symbol.String()) + return nil, errors.New(errorMsg) + } + } + } + + return stage.operator(left, right, parameters) +} + +func typeCheck(check stageTypeCheck, value interface{}, 
symbol OperatorSymbol, format string) error { + + if check == nil { + return nil + } + + if check(value) { + return nil + } + + errorMsg := fmt.Sprintf(format, value, symbol.String()) + return errors.New(errorMsg) +} + +/* + Returns an array representing the ExpressionTokens that make up this expression. +*/ +func (this EvaluableExpression) Tokens() []ExpressionToken { + + return this.tokens +} + +/* + Returns the original expression used to create this EvaluableExpression. +*/ +func (this EvaluableExpression) String() string { + + return this.inputExpression +} + +/* + Returns an array representing the variables contained in this EvaluableExpression. +*/ +func (this EvaluableExpression) Vars() []string { + var varlist []string + for _, val := range this.Tokens() { + if val.Kind == VARIABLE { + varlist = append(varlist, val.Value.(string)) + } + } + return varlist +} diff --git a/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go b/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go new file mode 100644 index 00000000..7e0ad1c8 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go @@ -0,0 +1,167 @@ +package govaluate + +import ( + "errors" + "fmt" + "regexp" + "time" +) + +/* + Returns a string representing this expression as if it were written in SQL. + This function assumes that all parameters exist within the same table, and that the table essentially represents + a serialized object of some sort (e.g., hibernate). + If your data model is more normalized, you may need to consider iterating through each actual token given by `Tokens()` + to create your query. + + Boolean values are considered to be "1" for true, "0" for false. + + Times are formatted according to this.QueryDateFormat. 
+*/ +func (this EvaluableExpression) ToSQLQuery() (string, error) { + + var stream *tokenStream + var transactions *expressionOutputStream + var transaction string + var err error + + stream = newTokenStream(this.tokens) + transactions = new(expressionOutputStream) + + for stream.hasNext() { + + transaction, err = this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + transactions.add(transaction) + } + + return transactions.createString(" "), nil +} + +func (this EvaluableExpression) findNextSQLString(stream *tokenStream, transactions *expressionOutputStream) (string, error) { + + var token ExpressionToken + var ret string + + token = stream.next() + + switch token.Kind { + + case STRING: + ret = fmt.Sprintf("'%v'", token.Value) + case PATTERN: + ret = fmt.Sprintf("'%s'", token.Value.(*regexp.Regexp).String()) + case TIME: + ret = fmt.Sprintf("'%s'", token.Value.(time.Time).Format(this.QueryDateFormat)) + + case LOGICALOP: + switch logicalSymbols[token.Value.(string)] { + + case AND: + ret = "AND" + case OR: + ret = "OR" + } + + case BOOLEAN: + if token.Value.(bool) { + ret = "1" + } else { + ret = "0" + } + + case VARIABLE: + ret = fmt.Sprintf("[%s]", token.Value.(string)) + + case NUMERIC: + ret = fmt.Sprintf("%g", token.Value.(float64)) + + case COMPARATOR: + switch comparatorSymbols[token.Value.(string)] { + + case EQ: + ret = "=" + case NEQ: + ret = "<>" + case REQ: + ret = "RLIKE" + case NREQ: + ret = "NOT RLIKE" + default: + ret = fmt.Sprintf("%s", token.Value.(string)) + } + + case TERNARY: + + switch ternarySymbols[token.Value.(string)] { + + case COALESCE: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("COALESCE(%v, %v)", left, right) + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + return "", errors.New("Ternary operators are unsupported in SQL output") + } + case PREFIX: + switch 
prefixSymbols[token.Value.(string)] { + + case INVERT: + ret = fmt.Sprintf("NOT") + default: + + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("%s%s", token.Value.(string), right) + } + case MODIFIER: + + switch modifierSymbols[token.Value.(string)] { + + case EXPONENT: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("POW(%s, %s)", left, right) + case MODULUS: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("MOD(%s, %s)", left, right) + default: + ret = fmt.Sprintf("%s", token.Value.(string)) + } + case CLAUSE: + ret = "(" + case CLAUSE_CLOSE: + ret = ")" + case SEPARATOR: + ret = "," + + default: + errorMsg := fmt.Sprintf("Unrecognized query token '%s' of kind '%s'", token.Value, token.Kind) + return "", errors.New(errorMsg) + } + + return ret, nil +} diff --git a/vendor/github.com/Knetic/govaluate/ExpressionToken.go b/vendor/github.com/Knetic/govaluate/ExpressionToken.go new file mode 100644 index 00000000..f849f381 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/ExpressionToken.go @@ -0,0 +1,9 @@ +package govaluate + +/* + Represents a single parsed token. 
+*/ +type ExpressionToken struct { + Kind TokenKind + Value interface{} +} diff --git a/vendor/github.com/Knetic/govaluate/LICENSE b/vendor/github.com/Knetic/govaluate/LICENSE new file mode 100644 index 00000000..0ef0f41e --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014-2016 George Lester + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/Knetic/govaluate/MANUAL.md b/vendor/github.com/Knetic/govaluate/MANUAL.md new file mode 100644 index 00000000..e0658285 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/MANUAL.md @@ -0,0 +1,176 @@ +govaluate +==== + +This library contains quite a lot of functionality, this document is meant to be formal documentation on the operators and features of it. +Some of this documentation may duplicate what's in README.md, but should never conflict. + +# Types + +This library only officially deals with four types; `float64`, `bool`, `string`, and arrays. 
+ +All numeric literals, with or without a radix, will be converted to `float64` for evaluation. For instance; in practice, there is no difference between the literals "1.0" and "1", they both end up as `float64`. This matters to users because if you intend to return numeric values from your expressions, then the returned value will be `float64`, not any other numeric type. + +Any string _literal_ (not parameter) which is interpretable as a date will be converted to a `float64` representation of that date's unix time. Any `time.Time` parameters will not be operable with these date literals; such parameters will need to use the `time.Time.Unix()` method to get a numeric representation. + +Arrays are untyped, and can be mixed-type. Internally they're all just `interface{}`. Only two operators can interact with arrays, `IN` and `,`. All other operators will refuse to operate on arrays. + +# Operators + +## Modifiers + +### Addition, concatenation `+` + +If either left or right sides of the `+` operator are a `string`, then this operator will perform string concatenation and return that result. If neither are string, then both must be numeric, and this will return a numeric result. + +Any other case is invalid. + +### Arithmetic `-` `*` `/` `**` `%` + +`**` refers to "take to the power of". For instance, `3 ** 4` == 81. + +* _Left side_: numeric +* _Right side_: numeric +* _Returns_: numeric + +### Bitwise shifts, masks `>>` `<<` `|` `&` `^` + +All of these operators convert their `float64` left and right sides to `int64`, perform their operation, and then convert back. +Given how this library assumes numeric are represented (as `float64`), it is unlikely that this behavior will change, even though it may cause havoc with extremely large or small numbers. + +* _Left side_: numeric +* _Right side_: numeric +* _Returns_: numeric + +### Negation `-` + +Prefix only. This can never have a left-hand value. 
+ +* _Right side_: numeric +* _Returns_: numeric + +### Inversion `!` + +Prefix only. This can never have a left-hand value. + +* _Right side_: bool +* _Returns_: bool + +### Bitwise NOT `~` + +Prefix only. This can never have a left-hand value. + +* _Right side_: numeric +* _Returns_: numeric + +## Logical Operators + +For all logical operators, this library will short-circuit the operation if the left-hand side is sufficient to determine what to do. For instance, `true || expensiveOperation()` will not actually call `expensiveOperation()`, since it knows the left-hand side is `true`. + +### Logical AND/OR `&&` `||` + +* _Left side_: bool +* _Right side_: bool +* _Returns_: bool + +### Ternary true `?` + +Checks if the left side is `true`. If so, returns the right side. If the left side is `false`, returns `nil`. +In practice, this is commonly used with the other ternary operator. + +* _Left side_: bool +* _Right side_: Any type. +* _Returns_: Right side or `nil` + +### Ternary false `:` + +Checks if the left side is `nil`. If so, returns the right side. If the left side is non-nil, returns the left side. +In practice, this is commonly used with the other ternary operator. + +* _Left side_: Any type. +* _Right side_: Any type. +* _Returns_: Right side or `nil` + +### Null coalescence `??` + +Similar to the C# operator. If the left value is non-nil, it returns that. If not, then the right-value is returned. + +* _Left side_: Any type. +* _Right side_: Any type. +* _Returns_: No specific type - whichever is passed to it. + +## Comparators + +### Numeric/lexicographic comparators `>` `<` `>=` `<=` + +If both sides are numeric, this returns the usual greater/lesser behavior that would be expected. +If both sides are string, this returns the lexicographic comparison of the strings. This uses Go's standard lexicographic compare. + +* _Accepts_: Left and right side must either be both string, or both numeric. 
+* _Returns_: bool + +### Regex comparators `=~` `!~` + +These use go's standard `regexp` flavor of regex. The left side is expected to be the candidate string, the right side is the pattern. `=~` returns whether or not the candidate string matches the regex pattern given on the right. `!~` is the inverted version of the same logic. + +* _Left side_: string +* _Right side_: string +* _Returns_: bool + +## Arrays + +### Separator `,` + +The separator, always paired with parenthesis, creates arrays. It must always have both a left and right-hand value, so for instance `(, 0)` and `(0,)` are invalid uses of it. + +Again, this should always be used with parenthesis; like `(1, 2, 3, 4)`. + +### Membership `IN` + +The only operator with a text name, this operator checks the right-hand side array to see if it contains a value that is equal to the left-side value. +Equality is determined by the use of the `==` operator, and this library doesn't check types between the values. Any two values, when cast to `interface{}`, and can still be checked for equality with `==` will act as expected. + +Note that you can use a parameter for the array, but it must be an `[]interface{}`. + +* _Left side_: Any type. +* _Right side_: array +* _Returns_: bool + +# Parameters + +Parameters must be passed in every time the expression is evaluated. Parameters can be of any type, but will not cause errors unless actually used in an erroneous way. There is no difference in behavior for any of the above operators for parameters - they are type checked when used. + +All `int` and `float` values of any width will be converted to `float64` before use. + +At no point is the parameter structure, or any value thereof, modified by this library. + +## Alternates to maps + +The default form of parameters as a map may not serve your use case. 
You may have parameters in some other structure, you may want to change the no-parameter-found behavior, or maybe even just have some debugging print statements invoked when a parameter is accessed. + +To do this, define a type that implements the `govaluate.Parameters` interface. When you want to evaluate, instead call `EvaluableExpression.Eval` and pass your parameter structure. + +# Functions + +During expression parsing (_not_ evaluation), a map of functions can be given to `govaluate.NewEvaluableExpressionWithFunctions` (the lengthiest and finest of function names). The resultant expression will be able to invoke those functions during evaluation. Once parsed, an expression cannot have functions added or removed - a new expression will need to be created if you want to change the functions, or behavior of said functions. + +Functions always take the form `()`, including parens. Functions can have an empty list of parameters, like `()`, but still must have parens. + +If the expression contains something that looks like it ought to be a function (such as `foo()`), but no such function was given to it, it will error on parsing. + +Functions must be of type `map[string]govaluate.ExpressionFunction`. `ExpressionFunction`, for brevity, has the following signature: + +`func(args ...interface{}) (interface{}, error)` + +Where `args` is whatever is passed to the function when called. If a non-nil error is returned from a function during evaluation, the evaluation stops and ultimately returns that error to the caller of `Evaluate()` or `Eval()`. + +## Built-in functions + +There aren't any builtin functions. The author is opposed to maintaining a standard library of functions to be used. + +Every use case of this library is different, and even in simple use cases (such as parameters, see above) different users need different behavior, naming, or even functionality. 
The author prefers that users make their own decisions about what functions they need, and how they operate. + +# Equality + +The `==` and `!=` operators involve a moderately complex workflow. They use [`reflect.DeepEqual`](https://golang.org/pkg/reflect/#DeepEqual). This is for complicated reasons, but there are some types in Go that cannot be compared with the native `==` operator. Arrays, in particular, cannot be compared - Go will panic if you try. One might assume this could be handled with the type checking system in `govaluate`, but unfortunately without reflection there is no way to know if a variable is a slice/array. Worse, structs can be incomparable if they _contain incomparable types_. + +It's all very complicated. Fortunately, Go includes the `reflect.DeepEqual` function to handle all the edge cases. Currently, `govaluate` uses that for all equality/inequality. diff --git a/vendor/github.com/Knetic/govaluate/OperatorSymbol.go b/vendor/github.com/Knetic/govaluate/OperatorSymbol.go new file mode 100644 index 00000000..4b810658 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/OperatorSymbol.go @@ -0,0 +1,309 @@ +package govaluate + +/* + Represents the valid symbols for operators. 
+ +*/ +type OperatorSymbol int + +const ( + VALUE OperatorSymbol = iota + LITERAL + NOOP + EQ + NEQ + GT + LT + GTE + LTE + REQ + NREQ + IN + + AND + OR + + PLUS + MINUS + BITWISE_AND + BITWISE_OR + BITWISE_XOR + BITWISE_LSHIFT + BITWISE_RSHIFT + MULTIPLY + DIVIDE + MODULUS + EXPONENT + + NEGATE + INVERT + BITWISE_NOT + + TERNARY_TRUE + TERNARY_FALSE + COALESCE + + FUNCTIONAL + ACCESS + SEPARATE +) + +type operatorPrecedence int + +const ( + noopPrecedence operatorPrecedence = iota + valuePrecedence + functionalPrecedence + prefixPrecedence + exponentialPrecedence + additivePrecedence + bitwisePrecedence + bitwiseShiftPrecedence + multiplicativePrecedence + comparatorPrecedence + ternaryPrecedence + logicalAndPrecedence + logicalOrPrecedence + separatePrecedence +) + +func findOperatorPrecedenceForSymbol(symbol OperatorSymbol) operatorPrecedence { + + switch symbol { + case NOOP: + return noopPrecedence + case VALUE: + return valuePrecedence + case EQ: + fallthrough + case NEQ: + fallthrough + case GT: + fallthrough + case LT: + fallthrough + case GTE: + fallthrough + case LTE: + fallthrough + case REQ: + fallthrough + case NREQ: + fallthrough + case IN: + return comparatorPrecedence + case AND: + return logicalAndPrecedence + case OR: + return logicalOrPrecedence + case BITWISE_AND: + fallthrough + case BITWISE_OR: + fallthrough + case BITWISE_XOR: + return bitwisePrecedence + case BITWISE_LSHIFT: + fallthrough + case BITWISE_RSHIFT: + return bitwiseShiftPrecedence + case PLUS: + fallthrough + case MINUS: + return additivePrecedence + case MULTIPLY: + fallthrough + case DIVIDE: + fallthrough + case MODULUS: + return multiplicativePrecedence + case EXPONENT: + return exponentialPrecedence + case BITWISE_NOT: + fallthrough + case NEGATE: + fallthrough + case INVERT: + return prefixPrecedence + case COALESCE: + fallthrough + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + return ternaryPrecedence + case ACCESS: + fallthrough + case FUNCTIONAL: + return 
functionalPrecedence + case SEPARATE: + return separatePrecedence + } + + return valuePrecedence +} + +/* + Map of all valid comparators, and their string equivalents. + Used during parsing of expressions to determine if a symbol is, in fact, a comparator. + Also used during evaluation to determine exactly which comparator is being used. +*/ +var comparatorSymbols = map[string]OperatorSymbol{ + "==": EQ, + "!=": NEQ, + ">": GT, + ">=": GTE, + "<": LT, + "<=": LTE, + "=~": REQ, + "!~": NREQ, + "in": IN, +} + +var logicalSymbols = map[string]OperatorSymbol{ + "&&": AND, + "||": OR, +} + +var bitwiseSymbols = map[string]OperatorSymbol{ + "^": BITWISE_XOR, + "&": BITWISE_AND, + "|": BITWISE_OR, +} + +var bitwiseShiftSymbols = map[string]OperatorSymbol{ + ">>": BITWISE_RSHIFT, + "<<": BITWISE_LSHIFT, +} + +var additiveSymbols = map[string]OperatorSymbol{ + "+": PLUS, + "-": MINUS, +} + +var multiplicativeSymbols = map[string]OperatorSymbol{ + "*": MULTIPLY, + "/": DIVIDE, + "%": MODULUS, +} + +var exponentialSymbolsS = map[string]OperatorSymbol{ + "**": EXPONENT, +} + +var prefixSymbols = map[string]OperatorSymbol{ + "-": NEGATE, + "!": INVERT, + "~": BITWISE_NOT, +} + +var ternarySymbols = map[string]OperatorSymbol{ + "?": TERNARY_TRUE, + ":": TERNARY_FALSE, + "??": COALESCE, +} + +// this is defined separately from additiveSymbols et al because it's needed for parsing, not stage planning. +var modifierSymbols = map[string]OperatorSymbol{ + "+": PLUS, + "-": MINUS, + "*": MULTIPLY, + "/": DIVIDE, + "%": MODULUS, + "**": EXPONENT, + "&": BITWISE_AND, + "|": BITWISE_OR, + "^": BITWISE_XOR, + ">>": BITWISE_RSHIFT, + "<<": BITWISE_LSHIFT, +} + +var separatorSymbols = map[string]OperatorSymbol{ + ",": SEPARATE, +} + +/* + Returns true if this operator is contained by the given array of candidate symbols. + False otherwise. 
+*/ +func (this OperatorSymbol) IsModifierType(candidate []OperatorSymbol) bool { + + for _, symbolType := range candidate { + if this == symbolType { + return true + } + } + + return false +} + +/* + Generally used when formatting type check errors. + We could store the stringified symbol somewhere else and not require a duplicated codeblock to translate + OperatorSymbol to string, but that would require more memory, and another field somewhere. + Adding operators is rare enough that we just stringify it here instead. +*/ +func (this OperatorSymbol) String() string { + + switch this { + case NOOP: + return "NOOP" + case VALUE: + return "VALUE" + case EQ: + return "=" + case NEQ: + return "!=" + case GT: + return ">" + case LT: + return "<" + case GTE: + return ">=" + case LTE: + return "<=" + case REQ: + return "=~" + case NREQ: + return "!~" + case AND: + return "&&" + case OR: + return "||" + case IN: + return "in" + case BITWISE_AND: + return "&" + case BITWISE_OR: + return "|" + case BITWISE_XOR: + return "^" + case BITWISE_LSHIFT: + return "<<" + case BITWISE_RSHIFT: + return ">>" + case PLUS: + return "+" + case MINUS: + return "-" + case MULTIPLY: + return "*" + case DIVIDE: + return "/" + case MODULUS: + return "%" + case EXPONENT: + return "**" + case NEGATE: + return "-" + case INVERT: + return "!" + case BITWISE_NOT: + return "~" + case TERNARY_TRUE: + return "?" + case TERNARY_FALSE: + return ":" + case COALESCE: + return "??" 
+ } + return "" +} diff --git a/vendor/github.com/Knetic/govaluate/README.md b/vendor/github.com/Knetic/govaluate/README.md new file mode 100644 index 00000000..2e5716d4 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/README.md @@ -0,0 +1,233 @@ +govaluate +==== + +[![Build Status](https://travis-ci.org/Knetic/govaluate.svg?branch=master)](https://travis-ci.org/Knetic/govaluate) +[![Godoc](https://img.shields.io/badge/godoc-reference-5272B4.svg)](https://godoc.org/github.com/Knetic/govaluate) +[![Go Report Card](https://goreportcard.com/badge/github.com/Knetic/govaluate)](https://goreportcard.com/report/github.com/Knetic/govaluate) +[![Gocover](https://gocover.io/_badge/github.com/Knetic/govaluate)](https://gocover.io/github.com/Knetic/govaluate) + +Provides support for evaluating arbitrary C-like artithmetic/string expressions. + +Why can't you just write these expressions in code? +-- + +Sometimes, you can't know ahead-of-time what an expression will look like, or you want those expressions to be configurable. +Perhaps you've got a set of data running through your application, and you want to allow your users to specify some validations to run on it before committing it to a database. Or maybe you've written a monitoring framework which is capable of gathering a bunch of metrics, then evaluating a few expressions to see if any metrics should be alerted upon, but the conditions for alerting are different for each monitor. + +A lot of people wind up writing their own half-baked style of evaluation language that fits their needs, but isn't complete. Or they wind up baking the expression into the actual executable, even if they know it's subject to change. These strategies may work, but they take time to implement, time for users to learn, and induce technical debt as requirements change. This library is meant to cover all the normal C-like expressions, so that you don't have to reinvent one of the oldest wheels on a computer. + +How do I use it? 
+-- + +You create a new EvaluableExpression, then call "Evaluate" on it. + +```go + expression, err := govaluate.NewEvaluableExpression("10 > 0"); + result, err := expression.Evaluate(nil); + // result is now set to "true", the bool value. +``` + +Cool, but how about with parameters? + +```go + expression, err := govaluate.NewEvaluableExpression("foo > 0"); + + parameters := make(map[string]interface{}, 8) + parameters["foo"] = -1; + + result, err := expression.Evaluate(parameters); + // result is now set to "false", the bool value. +``` + +That's cool, but we can almost certainly have done all that in code. What about a complex use case that involves some math? + +```go + expression, err := govaluate.NewEvaluableExpression("(requests_made * requests_succeeded / 100) >= 90"); + + parameters := make(map[string]interface{}, 8) + parameters["requests_made"] = 100; + parameters["requests_succeeded"] = 80; + + result, err := expression.Evaluate(parameters); + // result is now set to "false", the bool value. +``` + +Or maybe you want to check the status of an alive check ("smoketest") page, which will be a string? + +```go + expression, err := govaluate.NewEvaluableExpression("http_response_body == 'service is ok'"); + + parameters := make(map[string]interface{}, 8) + parameters["http_response_body"] = "service is ok"; + + result, err := expression.Evaluate(parameters); + // result is now set to "true", the bool value. +``` + +These examples have all returned boolean values, but it's equally possible to return numeric ones. + +```go + expression, err := govaluate.NewEvaluableExpression("(mem_used / total_mem) * 100"); + + parameters := make(map[string]interface{}, 8) + parameters["total_mem"] = 1024; + parameters["mem_used"] = 512; + + result, err := expression.Evaluate(parameters); + // result is now set to "50.0", the float64 value. +``` + +You can also do date parsing, though the formats are somewhat limited. Stick to RF3339, ISO8061, unix date, or ruby date formats. 
If you're having trouble getting a date string to parse, check the list of formats actually used: [parsing.go:248](https://github.com/Knetic/govaluate/blob/0580e9b47a69125afa0e4ebd1cf93c49eb5a43ec/parsing.go#L258). + +```go + expression, err := govaluate.NewEvaluableExpression("'2014-01-02' > '2014-01-01 23:59:59'"); + result, err := expression.Evaluate(nil); + + // result is now set to true +``` + +Expressions are parsed once, and can be re-used multiple times. Parsing is the compute-intensive phase of the process, so if you intend to use the same expression with different parameters, just parse it once. Like so; + +```go + expression, err := govaluate.NewEvaluableExpression("response_time <= 100"); + parameters := make(map[string]interface{}, 8) + + for { + parameters["response_time"] = pingSomething(); + result, err := expression.Evaluate(parameters) + } +``` + +The normal C-standard order of operators is respected. When writing an expression, be sure that you either order the operators correctly, or use parenthesis to clarify which portions of an expression should be run first. + +Escaping characters +-- + +Sometimes you'll have parameters that have spaces, slashes, pluses, ampersands or some other character +that this library interprets as something special. For example, the following expression will not +act as one might expect: + + "response-time < 100" + +As written, the library will parse it as "[response] minus [time] is less than 100". In reality, +"response-time" is meant to be one variable that just happens to have a dash in it. + +There are two ways to work around this. First, you can escape the entire parameter name: + + "[response-time] < 100" + +Or you can use backslashes to escape only the minus sign. + + "response\\-time < 100" + +Backslashes can be used anywhere in an expression to escape the very next character. Square bracketed parameter names can be used instead of plain parameter names at any time. 
+ +Functions +-- + +You may have cases where you want to call a function on a parameter during execution of the expression. Perhaps you want to aggregate some set of data, but don't know the exact aggregation you want to use until you're writing the expression itself. Or maybe you have a mathematical operation you want to perform, for which there is no operator; like `log` or `tan` or `sqrt`. For cases like this, you can provide a map of functions to `NewEvaluableExpressionWithFunctions`, which will then be able to use them during execution. For instance; + +```go + functions := map[string]govaluate.ExpressionFunction { + "strlen": func(args ...interface{}) (interface{}, error) { + length := len(args[0].(string)) + return (float64)(length), nil + }, + } + + expString := "strlen('someReallyLongInputString') <= 16" + expression, _ := govaluate.NewEvaluableExpressionWithFunctions(expString, functions) + + result, _ := expression.Evaluate(nil) + // result is now "false", the boolean value +``` + +Functions can accept any number of arguments, correctly handles nested functions, and arguments can be of any type (even if none of this library's operators support evaluation of that type). For instance, each of these usages of functions in an expression are valid (assuming that the appropriate functions and parameters are given): + +```go +"sqrt(x1 ** y1, x2 ** y2)" +"max(someValue, abs(anotherValue), 10 * lastValue)" +``` + +Functions cannot be passed as parameters, they must be known at the time when the expression is parsed, and are unchangeable after parsing. + +Accessors +-- + +If you have structs in your parameters, you can access their fields and methods in the usual way. For instance, given a struct that has a method "Echo", present in the parameters as `foo`, the following is valid: + + "foo.Echo('hello world')" + +Fields are accessed in a similar way. 
Assuming `foo` has a field called "Length": + + "foo.Length > 9000" + +Accessors can be nested to any depth, like the following + + "foo.Bar.Baz.SomeFunction()" + +However it is not _currently_ supported to access values in `map`s. So the following will not work + + "foo.SomeMap['key']" + +This may be convenient, but note that using accessors involves a _lot_ of reflection. This makes the expression about four times slower than just using a parameter (consult the benchmarks for more precise measurements on your system). +If at all reasonable, the author recommends extracting the values you care about into a parameter map beforehand, or defining a struct that implements the `Parameters` interface, and which grabs fields as required. If there are functions you want to use, it's better to pass them as expression functions (see the above section). These approaches use no reflection, and are designed to be fast and clean. + +What operators and types does this support? +-- + +* Modifiers: `+` `-` `/` `*` `&` `|` `^` `**` `%` `>>` `<<` +* Comparators: `>` `>=` `<` `<=` `==` `!=` `=~` `!~` +* Logical ops: `||` `&&` +* Numeric constants, as 64-bit floating point (`12345.678`) +* String constants (single quotes: `'foobar'`) +* Date constants (single quotes, using any permutation of RFC3339, ISO8601, ruby date, or unix date; date parsing is automatically tried with any string constant) +* Boolean constants: `true` `false` +* Parenthesis to control order of evaluation `(` `)` +* Arrays (anything separated by `,` within parenthesis: `(1, 2, 'foo')`) +* Prefixes: `!` `-` `~` +* Ternary conditional: `?` `:` +* Null coalescence: `??` + +See [MANUAL.md](https://github.com/Knetic/govaluate/blob/master/MANUAL.md) for exacting details on what types each operator supports. + +Types +-- + +Some operators don't make sense when used with some types. For instance, what does it mean to get the modulo of a string? What happens if you check to see if two numbers are logically AND'ed together? 
+ +Everyone has a different intuition about the answers to these questions. To prevent confusion, this library will _refuse to operate_ upon types for which there is not an unambiguous meaning for the operation. See [MANUAL.md](https://github.com/Knetic/govaluate/blob/master/MANUAL.md) for details about what operators are valid for which types. + +Benchmarks +-- + +If you're concerned about the overhead of this library, a good range of benchmarks are built into this repo. You can run them with `go test -bench=.`. The library is built with an eye towards being quick, but has not been aggressively profiled and optimized. For most applications, though, it is completely fine. + +For a very rough idea of performance, here are the results output from a benchmark run on a 3rd-gen Macbook Pro (Linux Mint 17.1). + +``` +BenchmarkSingleParse-12 1000000 1382 ns/op +BenchmarkSimpleParse-12 200000 10771 ns/op +BenchmarkFullParse-12 30000 49383 ns/op +BenchmarkEvaluationSingle-12 50000000 30.1 ns/op +BenchmarkEvaluationNumericLiteral-12 10000000 119 ns/op +BenchmarkEvaluationLiteralModifiers-12 10000000 236 ns/op +BenchmarkEvaluationParameters-12 5000000 260 ns/op +BenchmarkEvaluationParametersModifiers-12 3000000 547 ns/op +BenchmarkComplexExpression-12 2000000 963 ns/op +BenchmarkRegexExpression-12 100000 20357 ns/op +BenchmarkConstantRegexExpression-12 1000000 1392 ns/op +ok +``` + +API Breaks +-- + +While this library has very few cases which will ever result in an API break, it can (and [has](https://github.com/Knetic/govaluate/releases/tag/v2.0.0)) happened. If you are using this in production, vendor the commit you've tested against, or use gopkg.in to redirect your import (e.g., `import "gopkg.in/Knetic/govaluate.v2"`). Master branch (while infrequent) _may_ at some point contain API breaking changes, and the author will have no way to communicate these to downstreams, other than creating a new major release. 
+ +Releases will explicitly state when an API break happens, and if they do not specify an API break it should be safe to upgrade. + +License +-- + +This project is licensed under the MIT general use license. You're free to integrate, fork, and play with this code as you feel fit without consulting the author, as long as you provide proper credit to the author in your works. diff --git a/vendor/github.com/Knetic/govaluate/TokenKind.go b/vendor/github.com/Knetic/govaluate/TokenKind.go new file mode 100644 index 00000000..7c9516d2 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/TokenKind.go @@ -0,0 +1,75 @@ +package govaluate + +/* + Represents all valid types of tokens that a token can be. +*/ +type TokenKind int + +const ( + UNKNOWN TokenKind = iota + + PREFIX + NUMERIC + BOOLEAN + STRING + PATTERN + TIME + VARIABLE + FUNCTION + SEPARATOR + ACCESSOR + + COMPARATOR + LOGICALOP + MODIFIER + + CLAUSE + CLAUSE_CLOSE + + TERNARY +) + +/* + GetTokenKindString returns a string that describes the given TokenKind. + e.g., when passed the NUMERIC TokenKind, this returns the string "NUMERIC". 
+*/ +func (kind TokenKind) String() string { + + switch kind { + + case PREFIX: + return "PREFIX" + case NUMERIC: + return "NUMERIC" + case BOOLEAN: + return "BOOLEAN" + case STRING: + return "STRING" + case PATTERN: + return "PATTERN" + case TIME: + return "TIME" + case VARIABLE: + return "VARIABLE" + case FUNCTION: + return "FUNCTION" + case SEPARATOR: + return "SEPARATOR" + case COMPARATOR: + return "COMPARATOR" + case LOGICALOP: + return "LOGICALOP" + case MODIFIER: + return "MODIFIER" + case CLAUSE: + return "CLAUSE" + case CLAUSE_CLOSE: + return "CLAUSE_CLOSE" + case TERNARY: + return "TERNARY" + case ACCESSOR: + return "ACCESSOR" + } + + return "UNKNOWN" +} diff --git a/vendor/github.com/Knetic/govaluate/evaluationStage.go b/vendor/github.com/Knetic/govaluate/evaluationStage.go new file mode 100644 index 00000000..11ea5872 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/evaluationStage.go @@ -0,0 +1,516 @@ +package govaluate + +import ( + "errors" + "fmt" + "math" + "reflect" + "regexp" + "strings" +) + +const ( + logicalErrorFormat string = "Value '%v' cannot be used with the logical operator '%v', it is not a bool" + modifierErrorFormat string = "Value '%v' cannot be used with the modifier '%v', it is not a number" + comparatorErrorFormat string = "Value '%v' cannot be used with the comparator '%v', it is not a number" + ternaryErrorFormat string = "Value '%v' cannot be used with the ternary operator '%v', it is not a bool" + prefixErrorFormat string = "Value '%v' cannot be used with the prefix '%v'" +) + +type evaluationOperator func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) +type stageTypeCheck func(value interface{}) bool +type stageCombinedTypeCheck func(left interface{}, right interface{}) bool + +type evaluationStage struct { + symbol OperatorSymbol + + leftStage, rightStage *evaluationStage + + // the operation that will be used to evaluate this stage (such as adding [left] to [right] and return the 
result) + operator evaluationOperator + + // ensures that both left and right values are appropriate for this stage. Returns an error if they aren't operable. + leftTypeCheck stageTypeCheck + rightTypeCheck stageTypeCheck + + // if specified, will override whatever is used in "leftTypeCheck" and "rightTypeCheck". + // primarily used for specific operators that don't care which side a given type is on, but still requires one side to be of a given type + // (like string concat) + typeCheck stageCombinedTypeCheck + + // regardless of which type check is used, this string format will be used as the error message for type errors + typeErrorFormat string +} + +var ( + _true = interface{}(true) + _false = interface{}(false) +) + +func (this *evaluationStage) swapWith(other *evaluationStage) { + + temp := *other + other.setToNonStage(*this) + this.setToNonStage(temp) +} + +func (this *evaluationStage) setToNonStage(other evaluationStage) { + + this.symbol = other.symbol + this.operator = other.operator + this.leftTypeCheck = other.leftTypeCheck + this.rightTypeCheck = other.rightTypeCheck + this.typeCheck = other.typeCheck + this.typeErrorFormat = other.typeErrorFormat +} + +func (this *evaluationStage) isShortCircuitable() bool { + + switch this.symbol { + case AND: + fallthrough + case OR: + fallthrough + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + fallthrough + case COALESCE: + return true + } + + return false +} + +func noopStageRight(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return right, nil +} + +func addStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + // string concat if either are strings + if isString(left) || isString(right) { + return fmt.Sprintf("%v%v", left, right), nil + } + + return left.(float64) + right.(float64), nil +} +func subtractStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) - 
right.(float64), nil +} +func multiplyStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) * right.(float64), nil +} +func divideStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) / right.(float64), nil +} +func exponentStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return math.Pow(left.(float64), right.(float64)), nil +} +func modulusStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return math.Mod(left.(float64), right.(float64)), nil +} +func gteStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) >= right.(string)), nil + } + return boolIface(left.(float64) >= right.(float64)), nil +} +func gtStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) > right.(string)), nil + } + return boolIface(left.(float64) > right.(float64)), nil +} +func lteStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) <= right.(string)), nil + } + return boolIface(left.(float64) <= right.(float64)), nil +} +func ltStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) < right.(string)), nil + } + return boolIface(left.(float64) < right.(float64)), nil +} +func equalStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(reflect.DeepEqual(left, right)), nil +} +func notEqualStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(!reflect.DeepEqual(left, 
right)), nil +} +func andStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(left.(bool) && right.(bool)), nil +} +func orStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(left.(bool) || right.(bool)), nil +} +func negateStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return -right.(float64), nil +} +func invertStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(!right.(bool)), nil +} +func bitwiseNotStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(^int64(right.(float64))), nil +} +func ternaryIfStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if left.(bool) { + return right, nil + } + return nil, nil +} +func ternaryElseStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if left != nil { + return left, nil + } + return right, nil +} + +func regexStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + var pattern *regexp.Regexp + var err error + + switch right.(type) { + case string: + pattern, err = regexp.Compile(right.(string)) + if err != nil { + return nil, errors.New(fmt.Sprintf("Unable to compile regexp pattern '%v': %v", right, err)) + } + case *regexp.Regexp: + pattern = right.(*regexp.Regexp) + } + + return pattern.Match([]byte(left.(string))), nil +} + +func notRegexStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + ret, err := regexStage(left, right, parameters) + if err != nil { + return nil, err + } + + return !(ret.(bool)), nil +} + +func bitwiseOrStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) | int64(right.(float64))), nil +} +func 
bitwiseAndStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) & int64(right.(float64))), nil +} +func bitwiseXORStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) ^ int64(right.(float64))), nil +} +func leftShiftStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(uint64(left.(float64)) << uint64(right.(float64))), nil +} +func rightShiftStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(uint64(left.(float64)) >> uint64(right.(float64))), nil +} + +func makeParameterStage(parameterName string) evaluationOperator { + + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + value, err := parameters.Get(parameterName) + if err != nil { + return nil, err + } + + return value, nil + } +} + +func makeLiteralStage(literal interface{}) evaluationOperator { + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return literal, nil + } +} + +func makeFunctionStage(function ExpressionFunction) evaluationOperator { + + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + if right == nil { + return function() + } + + switch right.(type) { + case []interface{}: + return function(right.([]interface{})...) 
+ default: + return function(right) + } + } +} + +func typeConvertParam(p reflect.Value, t reflect.Type) (ret reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + errorMsg := fmt.Sprintf("Argument type conversion failed: failed to convert '%s' to '%s'", p.Kind().String(), t.Kind().String()) + err = errors.New(errorMsg) + ret = p + } + }() + + return p.Convert(t), nil +} + +func typeConvertParams(method reflect.Value, params []reflect.Value) ([]reflect.Value, error) { + + methodType := method.Type() + numIn := methodType.NumIn() + numParams := len(params) + + if numIn != numParams { + if numIn > numParams { + return nil, fmt.Errorf("Too few arguments to parameter call: got %d arguments, expected %d", len(params), numIn) + } + return nil, fmt.Errorf("Too many arguments to parameter call: got %d arguments, expected %d", len(params), numIn) + } + + for i := 0; i < numIn; i++ { + t := methodType.In(i) + p := params[i] + pt := p.Type() + + if t.Kind() != pt.Kind() { + np, err := typeConvertParam(p, t) + if err != nil { + return nil, err + } + params[i] = np + } + } + + return params, nil +} + +func makeAccessorStage(pair []string) evaluationOperator { + + reconstructed := strings.Join(pair, ".") + + return func(left interface{}, right interface{}, parameters Parameters) (ret interface{}, err error) { + + var params []reflect.Value + + value, err := parameters.Get(pair[0]) + if err != nil { + return nil, err + } + + // while this library generally tries to handle panic-inducing cases on its own, + // accessors are a sticky case which have a lot of possible ways to fail. + // therefore every call to an accessor sets up a defer that tries to recover from panics, converting them to errors. 
+ defer func() { + if r := recover(); r != nil { + errorMsg := fmt.Sprintf("Failed to access '%s': %v", reconstructed, r.(string)) + err = errors.New(errorMsg) + ret = nil + } + }() + + for i := 1; i < len(pair); i++ { + + coreValue := reflect.ValueOf(value) + + var corePtrVal reflect.Value + + // if this is a pointer, resolve it. + if coreValue.Kind() == reflect.Ptr { + corePtrVal = coreValue + coreValue = coreValue.Elem() + } + + if coreValue.Kind() != reflect.Struct { + return nil, errors.New("Unable to access '" + pair[i] + "', '" + pair[i-1] + "' is not a struct") + } + + field := coreValue.FieldByName(pair[i]) + if field != (reflect.Value{}) { + value = field.Interface() + continue + } + + method := coreValue.MethodByName(pair[i]) + if method == (reflect.Value{}) { + if corePtrVal.IsValid() { + method = corePtrVal.MethodByName(pair[i]) + } + if method == (reflect.Value{}) { + return nil, errors.New("No method or field '" + pair[i] + "' present on parameter '" + pair[i-1] + "'") + } + } + + switch right.(type) { + case []interface{}: + + givenParams := right.([]interface{}) + params = make([]reflect.Value, len(givenParams)) + for idx, _ := range givenParams { + params[idx] = reflect.ValueOf(givenParams[idx]) + } + + default: + + if right == nil { + params = []reflect.Value{} + break + } + + params = []reflect.Value{reflect.ValueOf(right.(interface{}))} + } + + params, err = typeConvertParams(method, params) + + if err != nil { + return nil, errors.New("Method call failed - '" + pair[0] + "." + pair[1] + "': " + err.Error()) + } + + returned := method.Call(params) + retLength := len(returned) + + if retLength == 0 { + return nil, errors.New("Method call '" + pair[i-1] + "." 
+ pair[i] + "' did not return any values.") + } + + if retLength == 1 { + + value = returned[0].Interface() + continue + } + + if retLength == 2 { + + errIface := returned[1].Interface() + err, validType := errIface.(error) + + if validType && errIface != nil { + return returned[0].Interface(), err + } + + value = returned[0].Interface() + continue + } + + return nil, errors.New("Method call '" + pair[0] + "." + pair[1] + "' did not return either one value, or a value and an error. Cannot interpret meaning.") + } + + value = castToFloat64(value) + return value, nil + } +} + +func separatorStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + var ret []interface{} + + switch left.(type) { + case []interface{}: + ret = append(left.([]interface{}), right) + default: + ret = []interface{}{left, right} + } + + return ret, nil +} + +func inStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + for _, value := range right.([]interface{}) { + if left == value { + return true, nil + } + } + return false, nil +} + +// + +func isString(value interface{}) bool { + + switch value.(type) { + case string: + return true + } + return false +} + +func isRegexOrString(value interface{}) bool { + + switch value.(type) { + case string: + return true + case *regexp.Regexp: + return true + } + return false +} + +func isBool(value interface{}) bool { + switch value.(type) { + case bool: + return true + } + return false +} + +func isFloat64(value interface{}) bool { + switch value.(type) { + case float64: + return true + } + return false +} + +/* + Addition usually means between numbers, but can also mean string concat. + String concat needs one (or both) of the sides to be a string. 
+*/ +func additionTypeCheck(left interface{}, right interface{}) bool { + + if isFloat64(left) && isFloat64(right) { + return true + } + if !isString(left) && !isString(right) { + return false + } + return true +} + +/* + Comparison can either be between numbers, or lexicographic between two strings, + but never between the two. +*/ +func comparatorTypeCheck(left interface{}, right interface{}) bool { + + if isFloat64(left) && isFloat64(right) { + return true + } + if isString(left) && isString(right) { + return true + } + return false +} + +func isArray(value interface{}) bool { + switch value.(type) { + case []interface{}: + return true + } + return false +} + +/* + Converting a boolean to an interface{} requires an allocation. + We can use interned bools to avoid this cost. +*/ +func boolIface(b bool) interface{} { + if b { + return _true + } + return _false +} diff --git a/vendor/github.com/Knetic/govaluate/expressionFunctions.go b/vendor/github.com/Knetic/govaluate/expressionFunctions.go new file mode 100644 index 00000000..ac6592b3 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/expressionFunctions.go @@ -0,0 +1,8 @@ +package govaluate + +/* + Represents a function that can be called from within an expression. + This method must return an error if, for any reason, it is unable to produce exactly one unambiguous result. + An error returned will halt execution of the expression. +*/ +type ExpressionFunction func(arguments ...interface{}) (interface{}, error) diff --git a/vendor/github.com/Knetic/govaluate/expressionOutputStream.go b/vendor/github.com/Knetic/govaluate/expressionOutputStream.go new file mode 100644 index 00000000..88a84163 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/expressionOutputStream.go @@ -0,0 +1,46 @@ +package govaluate + +import ( + "bytes" +) + +/* + Holds a series of "transactions" which represent each token as it is output by an outputter (such as ToSQLQuery()). 
+ Some outputs (such as SQL) require a function call or non-c-like syntax to represent an expression. + To accomplish this, this struct keeps track of each translated token as it is output, and can return and rollback those transactions. +*/ +type expressionOutputStream struct { + transactions []string +} + +func (this *expressionOutputStream) add(transaction string) { + this.transactions = append(this.transactions, transaction) +} + +func (this *expressionOutputStream) rollback() string { + + index := len(this.transactions) - 1 + ret := this.transactions[index] + + this.transactions = this.transactions[:index] + return ret +} + +func (this *expressionOutputStream) createString(delimiter string) string { + + var retBuffer bytes.Buffer + var transaction string + + penultimate := len(this.transactions) - 1 + + for i := 0; i < penultimate; i++ { + + transaction = this.transactions[i] + + retBuffer.WriteString(transaction) + retBuffer.WriteString(delimiter) + } + retBuffer.WriteString(this.transactions[penultimate]) + + return retBuffer.String() +} diff --git a/vendor/github.com/Knetic/govaluate/lexerState.go b/vendor/github.com/Knetic/govaluate/lexerState.go new file mode 100644 index 00000000..6726e909 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/lexerState.go @@ -0,0 +1,373 @@ +package govaluate + +import ( + "errors" + "fmt" +) + +type lexerState struct { + isEOF bool + isNullable bool + kind TokenKind + validNextKinds []TokenKind +} + +// lexer states. +// Constant for all purposes except compiler. 
+var validLexerStates = []lexerState{ + + lexerState{ + kind: UNKNOWN, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + PATTERN, + FUNCTION, + ACCESSOR, + STRING, + TIME, + CLAUSE, + }, + }, + + lexerState{ + + kind: CLAUSE, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + PATTERN, + FUNCTION, + ACCESSOR, + STRING, + TIME, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + + lexerState{ + + kind: CLAUSE_CLOSE, + isEOF: true, + isNullable: true, + validNextKinds: []TokenKind{ + + COMPARATOR, + MODIFIER, + NUMERIC, + BOOLEAN, + VARIABLE, + STRING, + PATTERN, + TIME, + CLAUSE, + CLAUSE_CLOSE, + LOGICALOP, + TERNARY, + SEPARATOR, + }, + }, + + lexerState{ + + kind: NUMERIC, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: BOOLEAN, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: STRING, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: TIME, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: PATTERN, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: VARIABLE, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: MODIFIER, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, 
+ NUMERIC, + VARIABLE, + FUNCTION, + ACCESSOR, + STRING, + BOOLEAN, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + lexerState{ + + kind: COMPARATOR, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + ACCESSOR, + STRING, + TIME, + CLAUSE, + CLAUSE_CLOSE, + PATTERN, + }, + }, + lexerState{ + + kind: LOGICALOP, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + ACCESSOR, + STRING, + TIME, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + lexerState{ + + kind: PREFIX, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + ACCESSOR, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + + lexerState{ + + kind: TERNARY, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + STRING, + TIME, + VARIABLE, + FUNCTION, + ACCESSOR, + CLAUSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: FUNCTION, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + CLAUSE, + }, + }, + lexerState{ + + kind: ACCESSOR, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + CLAUSE, + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: SEPARATOR, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + STRING, + TIME, + VARIABLE, + FUNCTION, + ACCESSOR, + CLAUSE, + }, + }, +} + +func (this lexerState) canTransitionTo(kind TokenKind) bool { + + for _, validKind := range this.validNextKinds { + + if validKind == kind { + return true + } + } + + return false +} + +func checkExpressionSyntax(tokens []ExpressionToken) error { + + var state lexerState + var lastToken ExpressionToken + var err error + + state = validLexerStates[0] + + for _, token := range tokens { + + if !state.canTransitionTo(token.Kind) { + + // call out a specific error for 
tokens looking like they want to be functions. + if lastToken.Kind == VARIABLE && token.Kind == CLAUSE { + return errors.New("Undefined function " + lastToken.Value.(string)) + } + + firstStateName := fmt.Sprintf("%s [%v]", state.kind.String(), lastToken.Value) + nextStateName := fmt.Sprintf("%s [%v]", token.Kind.String(), token.Value) + + return errors.New("Cannot transition token types from " + firstStateName + " to " + nextStateName) + } + + state, err = getLexerStateForToken(token.Kind) + if err != nil { + return err + } + + if !state.isNullable && token.Value == nil { + + errorMsg := fmt.Sprintf("Token kind '%v' cannot have a nil value", token.Kind.String()) + return errors.New(errorMsg) + } + + lastToken = token + } + + if !state.isEOF { + return errors.New("Unexpected end of expression") + } + return nil +} + +func getLexerStateForToken(kind TokenKind) (lexerState, error) { + + for _, possibleState := range validLexerStates { + + if possibleState.kind == kind { + return possibleState, nil + } + } + + errorMsg := fmt.Sprintf("No lexer state found for token kind '%v'\n", kind.String()) + return validLexerStates[0], errors.New(errorMsg) +} diff --git a/vendor/github.com/Knetic/govaluate/lexerStream.go b/vendor/github.com/Knetic/govaluate/lexerStream.go new file mode 100644 index 00000000..b72e6bdb --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/lexerStream.go @@ -0,0 +1,39 @@ +package govaluate + +type lexerStream struct { + source []rune + position int + length int +} + +func newLexerStream(source string) *lexerStream { + + var ret *lexerStream + var runes []rune + + for _, character := range source { + runes = append(runes, character) + } + + ret = new(lexerStream) + ret.source = runes + ret.length = len(runes) + return ret +} + +func (this *lexerStream) readCharacter() rune { + + var character rune + + character = this.source[this.position] + this.position += 1 + return character +} + +func (this *lexerStream) rewind(amount int) { + this.position -= 
amount +} + +func (this lexerStream) canRead() bool { + return this.position < this.length +} diff --git a/vendor/github.com/Knetic/govaluate/parameters.go b/vendor/github.com/Knetic/govaluate/parameters.go new file mode 100644 index 00000000..6c5b9ecb --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/parameters.go @@ -0,0 +1,32 @@ +package govaluate + +import ( + "errors" +) + +/* + Parameters is a collection of named parameters that can be used by an EvaluableExpression to retrieve parameters + when an expression tries to use them. +*/ +type Parameters interface { + + /* + Get gets the parameter of the given name, or an error if the parameter is unavailable. + Failure to find the given parameter should be indicated by returning an error. + */ + Get(name string) (interface{}, error) +} + +type MapParameters map[string]interface{} + +func (p MapParameters) Get(name string) (interface{}, error) { + + value, found := p[name] + + if !found { + errorMessage := "No parameter '" + name + "' found." 
+ return nil, errors.New(errorMessage) + } + + return value, nil +} diff --git a/vendor/github.com/Knetic/govaluate/parsing.go b/vendor/github.com/Knetic/govaluate/parsing.go new file mode 100644 index 00000000..40c7ed2c --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/parsing.go @@ -0,0 +1,526 @@ +package govaluate + +import ( + "bytes" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" + "unicode" +) + +func parseTokens(expression string, functions map[string]ExpressionFunction) ([]ExpressionToken, error) { + + var ret []ExpressionToken + var token ExpressionToken + var stream *lexerStream + var state lexerState + var err error + var found bool + + stream = newLexerStream(expression) + state = validLexerStates[0] + + for stream.canRead() { + + token, err, found = readToken(stream, state, functions) + + if err != nil { + return ret, err + } + + if !found { + break + } + + state, err = getLexerStateForToken(token.Kind) + if err != nil { + return ret, err + } + + // append this valid token + ret = append(ret, token) + } + + err = checkBalance(ret) + if err != nil { + return nil, err + } + + return ret, nil +} + +func readToken(stream *lexerStream, state lexerState, functions map[string]ExpressionFunction) (ExpressionToken, error, bool) { + + var function ExpressionFunction + var ret ExpressionToken + var tokenValue interface{} + var tokenTime time.Time + var tokenString string + var kind TokenKind + var character rune + var found bool + var completed bool + var err error + + // numeric is 0-9, or . 
or 0x followed by digits + // string starts with ' + // variable is alphanumeric, always starts with a letter + // bracket always means variable + // symbols are anything non-alphanumeric + // all others read into a buffer until they reach the end of the stream + for stream.canRead() { + + character = stream.readCharacter() + + if unicode.IsSpace(character) { + continue + } + + kind = UNKNOWN + + // numeric constant + if isNumeric(character) { + + if stream.canRead() && character == '0' { + character = stream.readCharacter() + + if stream.canRead() && character == 'x' { + tokenString, _ = readUntilFalse(stream, false, true, true, isHexDigit) + tokenValueInt, err := strconv.ParseUint(tokenString, 16, 64) + + if err != nil { + errorMsg := fmt.Sprintf("Unable to parse hex value '%v' to uint64\n", tokenString) + return ExpressionToken{}, errors.New(errorMsg), false + } + + kind = NUMERIC + tokenValue = float64(tokenValueInt) + break + } else { + stream.rewind(1) + } + } + + tokenString = readTokenUntilFalse(stream, isNumeric) + tokenValue, err = strconv.ParseFloat(tokenString, 64) + + if err != nil { + errorMsg := fmt.Sprintf("Unable to parse numeric value '%v' to float64\n", tokenString) + return ExpressionToken{}, errors.New(errorMsg), false + } + kind = NUMERIC + break + } + + // comma, separator + if character == ',' { + + tokenValue = "," + kind = SEPARATOR + break + } + + // escaped variable + if character == '[' { + + tokenValue, completed = readUntilFalse(stream, true, false, true, isNotClosingBracket) + kind = VARIABLE + + if !completed { + return ExpressionToken{}, errors.New("Unclosed parameter bracket"), false + } + + // above method normally rewinds us to the closing bracket, which we want to skip. + stream.rewind(-1) + break + } + + // regular variable - or function? + if unicode.IsLetter(character) { + + tokenString = readTokenUntilFalse(stream, isVariableName) + + tokenValue = tokenString + kind = VARIABLE + + // boolean? 
+ if tokenValue == "true" { + + kind = BOOLEAN + tokenValue = true + } else { + + if tokenValue == "false" { + + kind = BOOLEAN + tokenValue = false + } + } + + // textual operator? + if tokenValue == "in" || tokenValue == "IN" { + + // force lower case for consistency + tokenValue = "in" + kind = COMPARATOR + } + + // function? + function, found = functions[tokenString] + if found { + kind = FUNCTION + tokenValue = function + } + + // accessor? + accessorIndex := strings.Index(tokenString, ".") + if accessorIndex > 0 { + + // check that it doesn't end with a hanging period + if tokenString[len(tokenString)-1] == '.' { + errorMsg := fmt.Sprintf("Hanging accessor on token '%s'", tokenString) + return ExpressionToken{}, errors.New(errorMsg), false + } + + kind = ACCESSOR + splits := strings.Split(tokenString, ".") + tokenValue = splits + + // check that none of them are unexported + for i := 1; i < len(splits); i++ { + + firstCharacter := getFirstRune(splits[i]) + + if unicode.ToUpper(firstCharacter) != firstCharacter { + errorMsg := fmt.Sprintf("Unable to access unexported field '%s' in token '%s'", splits[i], tokenString) + return ExpressionToken{}, errors.New(errorMsg), false + } + } + } + break + } + + if !isNotQuote(character) { + tokenValue, completed = readUntilFalse(stream, true, false, true, isNotQuote) + + if !completed { + return ExpressionToken{}, errors.New("Unclosed string literal"), false + } + + // advance the stream one position, since reading until false assumes the terminator is a real token + stream.rewind(-1) + + // check to see if this can be parsed as a time. 
+ tokenTime, found = tryParseTime(tokenValue.(string)) + if found { + kind = TIME + tokenValue = tokenTime + } else { + kind = STRING + } + break + } + + if character == '(' { + tokenValue = character + kind = CLAUSE + break + } + + if character == ')' { + tokenValue = character + kind = CLAUSE_CLOSE + break + } + + // must be a known symbol + tokenString = readTokenUntilFalse(stream, isNotAlphanumeric) + tokenValue = tokenString + + // quick hack for the case where "-" can mean "prefixed negation" or "minus", which are used + // very differently. + if state.canTransitionTo(PREFIX) { + _, found = prefixSymbols[tokenString] + if found { + + kind = PREFIX + break + } + } + _, found = modifierSymbols[tokenString] + if found { + + kind = MODIFIER + break + } + + _, found = logicalSymbols[tokenString] + if found { + + kind = LOGICALOP + break + } + + _, found = comparatorSymbols[tokenString] + if found { + + kind = COMPARATOR + break + } + + _, found = ternarySymbols[tokenString] + if found { + + kind = TERNARY + break + } + + errorMessage := fmt.Sprintf("Invalid token: '%s'", tokenString) + return ret, errors.New(errorMessage), false + } + + ret.Kind = kind + ret.Value = tokenValue + + return ret, nil, (kind != UNKNOWN) +} + +func readTokenUntilFalse(stream *lexerStream, condition func(rune) bool) string { + + var ret string + + stream.rewind(1) + ret, _ = readUntilFalse(stream, false, true, true, condition) + return ret +} + +/* + Returns the string that was read until the given [condition] was false, or whitespace was broken. + Returns false if the stream ended before whitespace was broken or condition was met. 
+*/ +func readUntilFalse(stream *lexerStream, includeWhitespace bool, breakWhitespace bool, allowEscaping bool, condition func(rune) bool) (string, bool) { + + var tokenBuffer bytes.Buffer + var character rune + var conditioned bool + + conditioned = false + + for stream.canRead() { + + character = stream.readCharacter() + + // Use backslashes to escape anything + if allowEscaping && character == '\\' { + + character = stream.readCharacter() + tokenBuffer.WriteString(string(character)) + continue + } + + if unicode.IsSpace(character) { + + if breakWhitespace && tokenBuffer.Len() > 0 { + conditioned = true + break + } + if !includeWhitespace { + continue + } + } + + if condition(character) { + tokenBuffer.WriteString(string(character)) + } else { + conditioned = true + stream.rewind(1) + break + } + } + + return tokenBuffer.String(), conditioned +} + +/* + Checks to see if any optimizations can be performed on the given [tokens], which form a complete, valid expression. + The returns slice will represent the optimized (or unmodified) list of tokens to use. +*/ +func optimizeTokens(tokens []ExpressionToken) ([]ExpressionToken, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var err error + var index int + + for index, token = range tokens { + + // if we find a regex operator, and the right-hand value is a constant, precompile and replace with a pattern. + if token.Kind != COMPARATOR { + continue + } + + symbol = comparatorSymbols[token.Value.(string)] + if symbol != REQ && symbol != NREQ { + continue + } + + index++ + token = tokens[index] + if token.Kind == STRING { + + token.Kind = PATTERN + token.Value, err = regexp.Compile(token.Value.(string)) + + if err != nil { + return tokens, err + } + + tokens[index] = token + } + } + return tokens, nil +} + +/* + Checks the balance of tokens which have multiple parts, such as parenthesis. 
+*/ +func checkBalance(tokens []ExpressionToken) error { + + var stream *tokenStream + var token ExpressionToken + var parens int + + stream = newTokenStream(tokens) + + for stream.hasNext() { + + token = stream.next() + if token.Kind == CLAUSE { + parens++ + continue + } + if token.Kind == CLAUSE_CLOSE { + parens-- + continue + } + } + + if parens != 0 { + return errors.New("Unbalanced parenthesis") + } + return nil +} + +func isDigit(character rune) bool { + return unicode.IsDigit(character) +} + +func isHexDigit(character rune) bool { + + character = unicode.ToLower(character) + + return unicode.IsDigit(character) || + character == 'a' || + character == 'b' || + character == 'c' || + character == 'd' || + character == 'e' || + character == 'f' +} + +func isNumeric(character rune) bool { + + return unicode.IsDigit(character) || character == '.' +} + +func isNotQuote(character rune) bool { + + return character != '\'' && character != '"' +} + +func isNotAlphanumeric(character rune) bool { + + return !(unicode.IsDigit(character) || + unicode.IsLetter(character) || + character == '(' || + character == ')' || + character == '[' || + character == ']' || // starting to feel like there needs to be an `isOperation` func (#59) + !isNotQuote(character)) +} + +func isVariableName(character rune) bool { + + return unicode.IsLetter(character) || + unicode.IsDigit(character) || + character == '_' || + character == '.' +} + +func isNotClosingBracket(character rune) bool { + + return character != ']' +} + +/* + Attempts to parse the [candidate] as a Time. + Tries a series of standardized date formats, returns the Time if one applies, + otherwise returns false through the second return. 
+*/ +func tryParseTime(candidate string) (time.Time, bool) { + + var ret time.Time + var found bool + + timeFormats := [...]string{ + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.Kitchen, + time.RFC3339, + time.RFC3339Nano, + "2006-01-02", // RFC 3339 + "2006-01-02 15:04", // RFC 3339 with minutes + "2006-01-02 15:04:05", // RFC 3339 with seconds + "2006-01-02 15:04:05-07:00", // RFC 3339 with seconds and timezone + "2006-01-02T15Z0700", // ISO8601 with hour + "2006-01-02T15:04Z0700", // ISO8601 with minutes + "2006-01-02T15:04:05Z0700", // ISO8601 with seconds + "2006-01-02T15:04:05.999999999Z0700", // ISO8601 with nanoseconds + } + + for _, format := range timeFormats { + + ret, found = tryParseExactTime(candidate, format) + if found { + return ret, true + } + } + + return time.Now(), false +} + +func tryParseExactTime(candidate string, format string) (time.Time, bool) { + + var ret time.Time + var err error + + ret, err = time.ParseInLocation(format, candidate, time.Local) + if err != nil { + return time.Now(), false + } + + return ret, true +} + +func getFirstRune(candidate string) rune { + + for _, character := range candidate { + return character + } + + return 0 +} diff --git a/vendor/github.com/Knetic/govaluate/sanitizedParameters.go b/vendor/github.com/Knetic/govaluate/sanitizedParameters.go new file mode 100644 index 00000000..28bd795d --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/sanitizedParameters.go @@ -0,0 +1,43 @@ +package govaluate + +// sanitizedParameters is a wrapper for Parameters that does sanitization as +// parameters are accessed. 
+type sanitizedParameters struct { + orig Parameters +} + +func (p sanitizedParameters) Get(key string) (interface{}, error) { + value, err := p.orig.Get(key) + if err != nil { + return nil, err + } + + return castToFloat64(value), nil +} + +func castToFloat64(value interface{}) interface{} { + switch value.(type) { + case uint8: + return float64(value.(uint8)) + case uint16: + return float64(value.(uint16)) + case uint32: + return float64(value.(uint32)) + case uint64: + return float64(value.(uint64)) + case int8: + return float64(value.(int8)) + case int16: + return float64(value.(int16)) + case int32: + return float64(value.(int32)) + case int64: + return float64(value.(int64)) + case int: + return float64(value.(int)) + case float32: + return float64(value.(float32)) + } + + return value +} diff --git a/vendor/github.com/Knetic/govaluate/stagePlanner.go b/vendor/github.com/Knetic/govaluate/stagePlanner.go new file mode 100644 index 00000000..d71ed129 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/stagePlanner.go @@ -0,0 +1,724 @@ +package govaluate + +import ( + "errors" + "fmt" + "time" +) + +var stageSymbolMap = map[OperatorSymbol]evaluationOperator{ + EQ: equalStage, + NEQ: notEqualStage, + GT: gtStage, + LT: ltStage, + GTE: gteStage, + LTE: lteStage, + REQ: regexStage, + NREQ: notRegexStage, + AND: andStage, + OR: orStage, + IN: inStage, + BITWISE_OR: bitwiseOrStage, + BITWISE_AND: bitwiseAndStage, + BITWISE_XOR: bitwiseXORStage, + BITWISE_LSHIFT: leftShiftStage, + BITWISE_RSHIFT: rightShiftStage, + PLUS: addStage, + MINUS: subtractStage, + MULTIPLY: multiplyStage, + DIVIDE: divideStage, + MODULUS: modulusStage, + EXPONENT: exponentStage, + NEGATE: negateStage, + INVERT: invertStage, + BITWISE_NOT: bitwiseNotStage, + TERNARY_TRUE: ternaryIfStage, + TERNARY_FALSE: ternaryElseStage, + COALESCE: ternaryElseStage, + SEPARATE: separatorStage, +} + +/* + A "precedent" is a function which will recursively parse new evaluateionStages from a given stream of 
tokens. + It's called a `precedent` because it is expected to handle exactly what precedence of operator, + and defer to other `precedent`s for other operators. +*/ +type precedent func(stream *tokenStream) (*evaluationStage, error) + +/* + A convenience function for specifying the behavior of a `precedent`. + Most `precedent` functions can be described by the same function, just with different type checks, symbols, and error formats. + This struct is passed to `makePrecedentFromPlanner` to create a `precedent` function. +*/ +type precedencePlanner struct { + validSymbols map[string]OperatorSymbol + validKinds []TokenKind + + typeErrorFormat string + + next precedent + nextRight precedent +} + +var planPrefix precedent +var planExponential precedent +var planMultiplicative precedent +var planAdditive precedent +var planBitwise precedent +var planShift precedent +var planComparator precedent +var planLogicalAnd precedent +var planLogicalOr precedent +var planTernary precedent +var planSeparator precedent + +func init() { + + // all these stages can use the same code (in `planPrecedenceLevel`) to execute, + // they simply need different type checks, symbols, and recursive precedents. + // While not all precedent phases are listed here, most can be represented this way. 
+ planPrefix = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: prefixSymbols, + validKinds: []TokenKind{PREFIX}, + typeErrorFormat: prefixErrorFormat, + nextRight: planFunction, + }) + planExponential = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: exponentialSymbolsS, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planFunction, + }) + planMultiplicative = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: multiplicativeSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planExponential, + }) + planAdditive = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: additiveSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planMultiplicative, + }) + planShift = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: bitwiseShiftSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planAdditive, + }) + planBitwise = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: bitwiseSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planShift, + }) + planComparator = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: comparatorSymbols, + validKinds: []TokenKind{COMPARATOR}, + typeErrorFormat: comparatorErrorFormat, + next: planBitwise, + }) + planLogicalAnd = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: map[string]OperatorSymbol{"&&": AND}, + validKinds: []TokenKind{LOGICALOP}, + typeErrorFormat: logicalErrorFormat, + next: planComparator, + }) + planLogicalOr = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: map[string]OperatorSymbol{"||": OR}, + validKinds: []TokenKind{LOGICALOP}, + typeErrorFormat: logicalErrorFormat, + next: planLogicalAnd, + }) + planTernary = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: ternarySymbols, + validKinds: []TokenKind{TERNARY}, + 
typeErrorFormat: ternaryErrorFormat, + next: planLogicalOr, + }) + planSeparator = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: separatorSymbols, + validKinds: []TokenKind{SEPARATOR}, + next: planTernary, + }) +} + +/* + Given a planner, creates a function which will evaluate a specific precedence level of operators, + and link it to other `precedent`s which recurse to parse other precedence levels. +*/ +func makePrecedentFromPlanner(planner *precedencePlanner) precedent { + + var generated precedent + var nextRight precedent + + generated = func(stream *tokenStream) (*evaluationStage, error) { + return planPrecedenceLevel( + stream, + planner.typeErrorFormat, + planner.validSymbols, + planner.validKinds, + nextRight, + planner.next, + ) + } + + if planner.nextRight != nil { + nextRight = planner.nextRight + } else { + nextRight = generated + } + + return generated +} + +/* + Creates a `evaluationStageList` object which represents an execution plan (or tree) + which is used to completely evaluate a set of tokens at evaluation-time. + The three stages of evaluation can be thought of as parsing strings to tokens, then tokens to a stage list, then evaluation with parameters. +*/ +func planStages(tokens []ExpressionToken) (*evaluationStage, error) { + + stream := newTokenStream(tokens) + + stage, err := planTokens(stream) + if err != nil { + return nil, err + } + + // while we're now fully-planned, we now need to re-order same-precedence operators. + // this could probably be avoided with a different planning method + reorderStages(stage) + + stage = elideLiterals(stage) + return stage, nil +} + +func planTokens(stream *tokenStream) (*evaluationStage, error) { + + if !stream.hasNext() { + return nil, nil + } + + return planSeparator(stream) +} + +/* + The most usual method of parsing an evaluation stage for a given precedence. 
+ Most stages use the same logic +*/ +func planPrecedenceLevel( + stream *tokenStream, + typeErrorFormat string, + validSymbols map[string]OperatorSymbol, + validKinds []TokenKind, + rightPrecedent precedent, + leftPrecedent precedent) (*evaluationStage, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var leftStage, rightStage *evaluationStage + var checks typeChecks + var err error + var keyFound bool + + if leftPrecedent != nil { + + leftStage, err = leftPrecedent(stream) + if err != nil { + return nil, err + } + } + + for stream.hasNext() { + + token = stream.next() + + if len(validKinds) > 0 { + + keyFound = false + for _, kind := range validKinds { + if kind == token.Kind { + keyFound = true + break + } + } + + if !keyFound { + break + } + } + + if validSymbols != nil { + + if !isString(token.Value) { + break + } + + symbol, keyFound = validSymbols[token.Value.(string)] + if !keyFound { + break + } + } + + if rightPrecedent != nil { + rightStage, err = rightPrecedent(stream) + if err != nil { + return nil, err + } + } + + checks = findTypeChecks(symbol) + + return &evaluationStage{ + + symbol: symbol, + leftStage: leftStage, + rightStage: rightStage, + operator: stageSymbolMap[symbol], + + leftTypeCheck: checks.left, + rightTypeCheck: checks.right, + typeCheck: checks.combined, + typeErrorFormat: typeErrorFormat, + }, nil + } + + stream.rewind() + return leftStage, nil +} + +/* + A special case where functions need to be of higher precedence than values, and need a special wrapped execution stage operator. 
+*/ +func planFunction(stream *tokenStream) (*evaluationStage, error) { + + var token ExpressionToken + var rightStage *evaluationStage + var err error + + token = stream.next() + + if token.Kind != FUNCTION { + stream.rewind() + return planAccessor(stream) + } + + rightStage, err = planAccessor(stream) + if err != nil { + return nil, err + } + + return &evaluationStage{ + + symbol: FUNCTIONAL, + rightStage: rightStage, + operator: makeFunctionStage(token.Value.(ExpressionFunction)), + typeErrorFormat: "Unable to run function '%v': %v", + }, nil +} + +func planAccessor(stream *tokenStream) (*evaluationStage, error) { + + var token, otherToken ExpressionToken + var rightStage *evaluationStage + var err error + + if !stream.hasNext() { + return nil, nil + } + + token = stream.next() + + if token.Kind != ACCESSOR { + stream.rewind() + return planValue(stream) + } + + // check if this is meant to be a function or a field. + // fields have a clause next to them, functions do not. + // if it's a function, parse the arguments. Otherwise leave the right stage null. + if stream.hasNext() { + + otherToken = stream.next() + if otherToken.Kind == CLAUSE { + + stream.rewind() + + rightStage, err = planTokens(stream) + if err != nil { + return nil, err + } + } else { + stream.rewind() + } + } + + return &evaluationStage{ + + symbol: ACCESS, + rightStage: rightStage, + operator: makeAccessorStage(token.Value.([]string)), + typeErrorFormat: "Unable to access parameter field or method '%v': %v", + }, nil +} + +/* + A truly special precedence function, this handles all the "lowest-case" errata of the process, including literals, parmeters, + clauses, and prefixes. 
+*/ +func planValue(stream *tokenStream) (*evaluationStage, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var ret *evaluationStage + var operator evaluationOperator + var err error + + if !stream.hasNext() { + return nil, nil + } + + token = stream.next() + + switch token.Kind { + + case CLAUSE: + + ret, err = planTokens(stream) + if err != nil { + return nil, err + } + + // advance past the CLAUSE_CLOSE token. We know that it's a CLAUSE_CLOSE, because at parse-time we check for unbalanced parens. + stream.next() + + // the stage we got represents all of the logic contained within the parens + // but for technical reasons, we need to wrap this stage in a "noop" stage which breaks long chains of precedence. + // see github #33. + ret = &evaluationStage{ + rightStage: ret, + operator: noopStageRight, + symbol: NOOP, + } + + return ret, nil + + case CLAUSE_CLOSE: + + // when functions have empty params, this will be hit. In this case, we don't have any evaluation stage to do, + // so we just return nil so that the stage planner continues on its way. + stream.rewind() + return nil, nil + + case VARIABLE: + operator = makeParameterStage(token.Value.(string)) + + case NUMERIC: + fallthrough + case STRING: + fallthrough + case PATTERN: + fallthrough + case BOOLEAN: + symbol = LITERAL + operator = makeLiteralStage(token.Value) + case TIME: + symbol = LITERAL + operator = makeLiteralStage(float64(token.Value.(time.Time).Unix())) + + case PREFIX: + stream.rewind() + return planPrefix(stream) + } + + if operator == nil { + errorMsg := fmt.Sprintf("Unable to plan token kind: '%s', value: '%v'", token.Kind.String(), token.Value) + return nil, errors.New(errorMsg) + } + + return &evaluationStage{ + symbol: symbol, + operator: operator, + }, nil +} + +/* + Convenience function to pass a triplet of typechecks between `findTypeChecks` and `planPrecedenceLevel`. + Each of these members may be nil, which indicates that type does not matter for that value. 
+*/ +type typeChecks struct { + left stageTypeCheck + right stageTypeCheck + combined stageCombinedTypeCheck +} + +/* + Maps a given [symbol] to a set of typechecks to be used during runtime. +*/ +func findTypeChecks(symbol OperatorSymbol) typeChecks { + + switch symbol { + case GT: + fallthrough + case LT: + fallthrough + case GTE: + fallthrough + case LTE: + return typeChecks{ + combined: comparatorTypeCheck, + } + case REQ: + fallthrough + case NREQ: + return typeChecks{ + left: isString, + right: isRegexOrString, + } + case AND: + fallthrough + case OR: + return typeChecks{ + left: isBool, + right: isBool, + } + case IN: + return typeChecks{ + right: isArray, + } + case BITWISE_LSHIFT: + fallthrough + case BITWISE_RSHIFT: + fallthrough + case BITWISE_OR: + fallthrough + case BITWISE_AND: + fallthrough + case BITWISE_XOR: + return typeChecks{ + left: isFloat64, + right: isFloat64, + } + case PLUS: + return typeChecks{ + combined: additionTypeCheck, + } + case MINUS: + fallthrough + case MULTIPLY: + fallthrough + case DIVIDE: + fallthrough + case MODULUS: + fallthrough + case EXPONENT: + return typeChecks{ + left: isFloat64, + right: isFloat64, + } + case NEGATE: + return typeChecks{ + right: isFloat64, + } + case INVERT: + return typeChecks{ + right: isBool, + } + case BITWISE_NOT: + return typeChecks{ + right: isFloat64, + } + case TERNARY_TRUE: + return typeChecks{ + left: isBool, + } + + // unchecked cases + case EQ: + fallthrough + case NEQ: + return typeChecks{} + case TERNARY_FALSE: + fallthrough + case COALESCE: + fallthrough + default: + return typeChecks{} + } +} + +/* + During stage planning, stages of equal precedence are parsed such that they'll be evaluated in reverse order. + For commutative operators like "+" or "-", it's no big deal. But for order-specific operators, it ruins the expected result. +*/ +func reorderStages(rootStage *evaluationStage) { + + // traverse every rightStage until we find multiples in a row of the same precedence. 
+ var identicalPrecedences []*evaluationStage + var currentStage, nextStage *evaluationStage + var precedence, currentPrecedence operatorPrecedence + + nextStage = rootStage + precedence = findOperatorPrecedenceForSymbol(rootStage.symbol) + + for nextStage != nil { + + currentStage = nextStage + nextStage = currentStage.rightStage + + // left depth first, since this entire method only looks for precedences down the right side of the tree + if currentStage.leftStage != nil { + reorderStages(currentStage.leftStage) + } + + currentPrecedence = findOperatorPrecedenceForSymbol(currentStage.symbol) + + if currentPrecedence == precedence { + identicalPrecedences = append(identicalPrecedences, currentStage) + continue + } + + // precedence break. + // See how many in a row we had, and reorder if there's more than one. + if len(identicalPrecedences) > 1 { + mirrorStageSubtree(identicalPrecedences) + } + + identicalPrecedences = []*evaluationStage{currentStage} + precedence = currentPrecedence + } + + if len(identicalPrecedences) > 1 { + mirrorStageSubtree(identicalPrecedences) + } +} + +/* + Performs a "mirror" on a subtree of stages. + This mirror functionally inverts the order of execution for all members of the [stages] list. + That list is assumed to be a root-to-leaf (ordered) list of evaluation stages, where each is a right-hand stage of the last. 
+*/ +func mirrorStageSubtree(stages []*evaluationStage) { + + var rootStage, inverseStage, carryStage, frontStage *evaluationStage + + stagesLength := len(stages) + + // reverse all right/left + for _, frontStage = range stages { + + carryStage = frontStage.rightStage + frontStage.rightStage = frontStage.leftStage + frontStage.leftStage = carryStage + } + + // end left swaps with root right + rootStage = stages[0] + frontStage = stages[stagesLength-1] + + carryStage = frontStage.leftStage + frontStage.leftStage = rootStage.rightStage + rootStage.rightStage = carryStage + + // for all non-root non-end stages, right is swapped with inverse stage right in list + for i := 0; i < (stagesLength-2)/2+1; i++ { + + frontStage = stages[i+1] + inverseStage = stages[stagesLength-i-1] + + carryStage = frontStage.rightStage + frontStage.rightStage = inverseStage.rightStage + inverseStage.rightStage = carryStage + } + + // swap all other information with inverse stages + for i := 0; i < stagesLength/2; i++ { + + frontStage = stages[i] + inverseStage = stages[stagesLength-i-1] + frontStage.swapWith(inverseStage) + } +} + +/* + Recurses through all operators in the entire tree, eliding operators where both sides are literals. +*/ +func elideLiterals(root *evaluationStage) *evaluationStage { + + if root.leftStage != nil { + root.leftStage = elideLiterals(root.leftStage) + } + + if root.rightStage != nil { + root.rightStage = elideLiterals(root.rightStage) + } + + return elideStage(root) +} + +/* + Elides a specific stage, if possible. + Returns the unmodified [root] stage if it cannot or should not be elided. + Otherwise, returns a new stage representing the condensed value from the elided stages. +*/ +func elideStage(root *evaluationStage) *evaluationStage { + + var leftValue, rightValue, result interface{} + var err error + + // right side must be a non-nil value. Left side must be nil or a value. 
+ if root.rightStage == nil || + root.rightStage.symbol != LITERAL || + root.leftStage == nil || + root.leftStage.symbol != LITERAL { + return root + } + + // don't elide some operators + switch root.symbol { + case SEPARATE: + fallthrough + case IN: + return root + } + + // both sides are values, get their actual values. + // errors should be near-impossible here. If we encounter them, just abort this optimization. + leftValue, err = root.leftStage.operator(nil, nil, nil) + if err != nil { + return root + } + + rightValue, err = root.rightStage.operator(nil, nil, nil) + if err != nil { + return root + } + + // typcheck, since the grammar checker is a bit loose with which operator symbols go together. + err = typeCheck(root.leftTypeCheck, leftValue, root.symbol, root.typeErrorFormat) + if err != nil { + return root + } + + err = typeCheck(root.rightTypeCheck, rightValue, root.symbol, root.typeErrorFormat) + if err != nil { + return root + } + + if root.typeCheck != nil && !root.typeCheck(leftValue, rightValue) { + return root + } + + // pre-calculate, and return a new stage representing the result. + result, err = root.operator(leftValue, rightValue, nil) + if err != nil { + return root + } + + return &evaluationStage{ + symbol: LITERAL, + operator: makeLiteralStage(result), + } +} diff --git a/vendor/github.com/Knetic/govaluate/test.sh b/vendor/github.com/Knetic/govaluate/test.sh new file mode 100644 index 00000000..11aa8b33 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/test.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Script that runs tests, code coverage, and benchmarks all at once. +# Builds a symlink in /tmp, mostly to avoid messing with GOPATH at the user's shell level. 
+ +TEMPORARY_PATH="/tmp/govaluate_test" +SRC_PATH="${TEMPORARY_PATH}/src" +FULL_PATH="${TEMPORARY_PATH}/src/govaluate" + +# set up temporary directory +rm -rf "${FULL_PATH}" +mkdir -p "${SRC_PATH}" + +ln -s $(pwd) "${FULL_PATH}" +export GOPATH="${TEMPORARY_PATH}" + +pushd "${TEMPORARY_PATH}/src/govaluate" + +# run the actual tests. +export GOVALUATE_TORTURE_TEST="true" +go test -bench=. -benchmem #-coverprofile coverage.out +status=$? + +if [ "${status}" != 0 ]; +then + exit $status +fi + +# coverage +# disabled because travis go1.4 seems not to support it suddenly? +#go tool cover -func=coverage.out + +popd diff --git a/vendor/github.com/Knetic/govaluate/tokenStream.go b/vendor/github.com/Knetic/govaluate/tokenStream.go new file mode 100644 index 00000000..d0029209 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/tokenStream.go @@ -0,0 +1,36 @@ +package govaluate + +type tokenStream struct { + tokens []ExpressionToken + index int + tokenLength int +} + +func newTokenStream(tokens []ExpressionToken) *tokenStream { + + var ret *tokenStream + + ret = new(tokenStream) + ret.tokens = tokens + ret.tokenLength = len(tokens) + return ret +} + +func (this *tokenStream) rewind() { + this.index -= 1 +} + +func (this *tokenStream) next() ExpressionToken { + + var token ExpressionToken + + token = this.tokens[this.index] + + this.index += 1 + return token +} + +func (this tokenStream) hasNext() bool { + + return this.index < this.tokenLength +} diff --git a/vendor/github.com/casbin/casbin/v2/.gitignore b/vendor/github.com/casbin/casbin/v2/.gitignore new file mode 100644 index 00000000..da27805f --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/.gitignore @@ -0,0 +1,30 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +.idea/ 
+*.iml + +# vendor files +vendor diff --git a/vendor/github.com/casbin/casbin/v2/.releaserc.json b/vendor/github.com/casbin/casbin/v2/.releaserc.json new file mode 100644 index 00000000..58cb0bb4 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/.releaserc.json @@ -0,0 +1,16 @@ +{ + "debug": true, + "branches": [ + "+([0-9])?(.{+([0-9]),x}).x", + "master", + { + "name": "beta", + "prerelease": true + } + ], + "plugins": [ + "@semantic-release/commit-analyzer", + "@semantic-release/release-notes-generator", + "@semantic-release/github" + ] +} diff --git a/vendor/github.com/casbin/casbin/v2/.travis.yml b/vendor/github.com/casbin/casbin/v2/.travis.yml new file mode 100644 index 00000000..a35e0622 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/.travis.yml @@ -0,0 +1,15 @@ +language: go + +sudo: false + +env: + - GO111MODULE=on + +go: + - "1.11.13" + - "1.12" + - "1.13" + - "1.14" + +script: + - make test diff --git a/vendor/github.com/casbin/casbin/v2/CONTRIBUTING.md b/vendor/github.com/casbin/casbin/v2/CONTRIBUTING.md new file mode 100644 index 00000000..c5f0ddb5 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# How to contribute + +The following is a set of guidelines for contributing to casbin and its libraries, which are hosted at [casbin organization at Github](https://github.com/casbin). + +This project adheres to the [Contributor Covenant 1.2.](https://www.contributor-covenant.org/version/1/2/0/code-of-conduct.html) By participating, you are expected to uphold this code. Please report unacceptable behavior to info@casbin.com. + +## Questions + +- We do our best to have an [up-to-date documentation](https://casbin.org/docs/overview) +- [Stack Overflow](https://stackoverflow.com) is the best place to start if you have a question. Please use the [casbin tag](https://stackoverflow.com/tags/casbin/info) we are actively monitoring. 
We encourage you to use Stack Overflow specially for Modeling Access Control Problems, in order to build a shared knowledge base. +- You can also join our [Gitter community](https://gitter.im/casbin/Lobby). + +## Reporting issues + +Reporting issues are a great way to contribute to the project. We are perpetually grateful about a well-written, thorough bug report. + +Before raising a new issue, check our [issue list](https://github.com/casbin/casbin/issues) to determine if it already contains the problem that you are facing. + +A good bug report shouldn't leave others needing to chase you for more information. Please be as detailed as possible. The following questions might serve as a template for writing a detailed report: + +What were you trying to achieve? +What are the expected results? +What are the received results? +What are the steps to reproduce the issue? +In what environment did you encounter the issue? + +Feature requests can also be submitted as issues. + +## Pull requests + +Good pull requests (e.g. patches, improvements, new features) are a fantastic help. They should remain focused in scope and avoid unrelated commits. + +Please ask first before embarking on any significant pull request (e.g. implementing new features, refactoring code etc.), otherwise you risk spending a lot of time working on something that the maintainers might not want to merge into the project. + +First add an issue to the project to discuss the improvement. Please adhere to the coding conventions used throughout the project. If in doubt, consult the [Effective Go style guide](https://golang.org/doc/effective_go.html). 
diff --git a/vendor/github.com/casbin/casbin/v2/LICENSE b/vendor/github.com/casbin/casbin/v2/LICENSE new file mode 100644 index 00000000..8dada3ed --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/casbin/casbin/v2/Makefile b/vendor/github.com/casbin/casbin/v2/Makefile new file mode 100644 index 00000000..6db2b920 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/Makefile @@ -0,0 +1,18 @@ +SHELL = /bin/bash +export PATH := $(shell yarn global bin):$(PATH) + +default: lint test + +test: + go test -race -v ./... + +benchmark: + go test -bench=. 
+ +lint: + golangci-lint run --verbose + +release: + yarn global add semantic-release@17.2.4 + semantic-release + diff --git a/vendor/github.com/casbin/casbin/v2/README.md b/vendor/github.com/casbin/casbin/v2/README.md new file mode 100644 index 00000000..98723bff --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/README.md @@ -0,0 +1,286 @@ +Casbin +==== + +[![Go Report Card](https://goreportcard.com/badge/github.com/casbin/casbin)](https://goreportcard.com/report/github.com/casbin/casbin) +[![Build Status](https://travis-ci.com/casbin/casbin.svg?branch=master)](https://travis-ci.com/casbin/casbin) +[![Coverage Status](https://coveralls.io/repos/github/casbin/casbin/badge.svg?branch=master)](https://coveralls.io/github/casbin/casbin?branch=master) +[![Godoc](https://godoc.org/github.com/casbin/casbin?status.svg)](https://pkg.go.dev/github.com/casbin/casbin/v2) +[![Release](https://img.shields.io/github/release/casbin/casbin.svg)](https://github.com/casbin/casbin/releases/latest) +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/casbin/lobby) +[![Sourcegraph](https://sourcegraph.com/github.com/casbin/casbin/-/badge.svg)](https://sourcegraph.com/github.com/casbin/casbin?badge) + +💖 [**Looking for an open-source identity and access management solution like Okta, Auth0, Keycloak ? Learn more about: Casdoor**](https://casdoor.org/) + +casdoor + +**News**: still worry about how to write the correct Casbin policy? ``Casbin online editor`` is coming to help! Try it at: https://casbin.org/editor/ + +![casbin Logo](casbin-logo.png) + +Casbin is a powerful and efficient open-source access control library for Golang projects. It provides support for enforcing authorization based on various [access control models](https://en.wikipedia.org/wiki/Computer_security_model). 
+ +## All the languages supported by Casbin: + +[![golang](https://casbin.org/img/langs/golang.png)](https://github.com/casbin/casbin) | [![java](https://casbin.org/img/langs/java.png)](https://github.com/casbin/jcasbin) | [![nodejs](https://casbin.org/img/langs/nodejs.png)](https://github.com/casbin/node-casbin) | [![php](https://casbin.org/img/langs/php.png)](https://github.com/php-casbin/php-casbin) +----|----|----|---- +[Casbin](https://github.com/casbin/casbin) | [jCasbin](https://github.com/casbin/jcasbin) | [node-Casbin](https://github.com/casbin/node-casbin) | [PHP-Casbin](https://github.com/php-casbin/php-casbin) +production-ready | production-ready | production-ready | production-ready + +[![python](https://casbin.org/img/langs/python.png)](https://github.com/casbin/pycasbin) | [![dotnet](https://casbin.org/img/langs/dotnet.png)](https://github.com/casbin-net/Casbin.NET) | [![c++](https://casbin.org/img/langs/cpp.png)](https://github.com/casbin/casbin-cpp) | [![rust](https://casbin.org/img/langs/rust.png)](https://github.com/casbin/casbin-rs) +----|----|----|---- +[PyCasbin](https://github.com/casbin/pycasbin) | [Casbin.NET](https://github.com/casbin-net/Casbin.NET) | [Casbin-CPP](https://github.com/casbin/casbin-cpp) | [Casbin-RS](https://github.com/casbin/casbin-rs) +production-ready | production-ready | beta-test | production-ready + +## Table of contents + +- [Supported models](#supported-models) +- [How it works?](#how-it-works) +- [Features](#features) +- [Installation](#installation) +- [Documentation](#documentation) +- [Online editor](#online-editor) +- [Tutorials](#tutorials) +- [Get started](#get-started) +- [Policy management](#policy-management) +- [Policy persistence](#policy-persistence) +- [Policy consistence between multiple nodes](#policy-consistence-between-multiple-nodes) +- [Role manager](#role-manager) +- [Benchmarks](#benchmarks) +- [Examples](#examples) +- [Middlewares](#middlewares) +- [Our adopters](#our-adopters) + +## Supported 
models + +1. [**ACL (Access Control List)**](https://en.wikipedia.org/wiki/Access_control_list) +2. **ACL with [superuser](https://en.wikipedia.org/wiki/Superuser)** +3. **ACL without users**: especially useful for systems that don't have authentication or user log-ins. +3. **ACL without resources**: some scenarios may target for a type of resources instead of an individual resource by using permissions like ``write-article``, ``read-log``. It doesn't control the access to a specific article or log. +4. **[RBAC (Role-Based Access Control)](https://en.wikipedia.org/wiki/Role-based_access_control)** +5. **RBAC with resource roles**: both users and resources can have roles (or groups) at the same time. +6. **RBAC with domains/tenants**: users can have different role sets for different domains/tenants. +7. **[ABAC (Attribute-Based Access Control)](https://en.wikipedia.org/wiki/Attribute-Based_Access_Control)**: syntax sugar like ``resource.Owner`` can be used to get the attribute for a resource. +8. **[RESTful](https://en.wikipedia.org/wiki/Representational_state_transfer)**: supports paths like ``/res/*``, ``/res/:id`` and HTTP methods like ``GET``, ``POST``, ``PUT``, ``DELETE``. +9. **Deny-override**: both allow and deny authorizations are supported, deny overrides the allow. +10. **Priority**: the policy rules can be prioritized like firewall rules. + +## How it works? + +In Casbin, an access control model is abstracted into a CONF file based on the **PERM metamodel (Policy, Effect, Request, Matchers)**. So switching or upgrading the authorization mechanism for a project is just as simple as modifying a configuration. You can customize your own access control model by combining the available models. For example, you can get RBAC roles and ABAC attributes together inside one model and share one set of policy rules. + +The most basic and simplest model in Casbin is ACL. 
ACL's model CONF is: + +```ini +# Request definition +[request_definition] +r = sub, obj, act + +# Policy definition +[policy_definition] +p = sub, obj, act + +# Policy effect +[policy_effect] +e = some(where (p.eft == allow)) + +# Matchers +[matchers] +m = r.sub == p.sub && r.obj == p.obj && r.act == p.act + +``` + +An example policy for ACL model is like: + +``` +p, alice, data1, read +p, bob, data2, write +``` + +It means: + +- alice can read data1 +- bob can write data2 + +We also support multi-line mode by appending '\\' in the end: + +```ini +# Matchers +[matchers] +m = r.sub == p.sub && r.obj == p.obj \ + && r.act == p.act +``` + +Further more, if you are using ABAC, you can try operator `in` like following in Casbin **golang** edition (jCasbin and Node-Casbin are not supported yet): + +```ini +# Matchers +[matchers] +m = r.obj == p.obj && r.act == p.act || r.obj in ('data2', 'data3') +``` + +But you **SHOULD** make sure that the length of the array is **MORE** than **1**, otherwise there will cause it to panic. + +For more operators, you may take a look at [govaluate](https://github.com/Knetic/govaluate) + +## Features + +What Casbin does: + +1. enforce the policy in the classic ``{subject, object, action}`` form or a customized form as you defined, both allow and deny authorizations are supported. +2. handle the storage of the access control model and its policy. +3. manage the role-user mappings and role-role mappings (aka role hierarchy in RBAC). +4. support built-in superuser like ``root`` or ``administrator``. A superuser can do anything without explicit permissions. +5. multiple built-in operators to support the rule matching. For example, ``keyMatch`` can map a resource key ``/foo/bar`` to the pattern ``/foo*``. + +What Casbin does NOT do: + +1. authentication (aka verify ``username`` and ``password`` when a user logs in) +2. manage the list of users or roles. I believe it's more convenient for the project itself to manage these entities. 
Users usually have their passwords, and Casbin is not designed as a password container. However, Casbin stores the user-role mapping for the RBAC scenario. + +## Installation + +``` +go get github.com/casbin/casbin/v2 +``` + +## Documentation + +https://casbin.org/docs/overview + +## Online editor + +You can also use the online editor (https://casbin.org/editor/) to write your Casbin model and policy in your web browser. It provides functionality such as ``syntax highlighting`` and ``code completion``, just like an IDE for a programming language. + +## Tutorials + +https://casbin.org/docs/tutorials + +## Get started + +1. New a Casbin enforcer with a model file and a policy file: + + ```go + e, _ := casbin.NewEnforcer("path/to/model.conf", "path/to/policy.csv") + ``` + +Note: you can also initialize an enforcer with policy in DB instead of file, see [Policy-persistence](#policy-persistence) section for details. + +2. Add an enforcement hook into your code right before the access happens: + + ```go + sub := "alice" // the user that wants to access a resource. + obj := "data1" // the resource that is going to be accessed. + act := "read" // the operation that the user performs on the resource. + + if res, _ := e.Enforce(sub, obj, act); res { + // permit alice to read data1 + } else { + // deny the request, show an error + } + ``` + +3. Besides the static policy file, Casbin also provides API for permission management at run-time. For example, You can get all the roles assigned to a user as below: + + ```go + roles, _ := e.GetImplicitRolesForUser(sub) + ``` + +See [Policy management APIs](#policy-management) for more usage. + +## Policy management + +Casbin provides two sets of APIs to manage permissions: + +- [Management API](https://casbin.org/docs/management-api): the primitive API that provides full support for Casbin policy management. +- [RBAC API](https://casbin.org/docs/rbac-api): a more friendly API for RBAC. This API is a subset of Management API. 
The RBAC users could use this API to simplify the code. + +We also provide a [web-based UI](https://casbin.org/docs/admin-portal) for model management and policy management: + +![model editor](https://hsluoyz.github.io/casbin/ui_model_editor.png) + +![policy editor](https://hsluoyz.github.io/casbin/ui_policy_editor.png) + +## Policy persistence + +https://casbin.org/docs/adapters + +## Policy consistence between multiple nodes + +https://casbin.org/docs/watchers + +## Role manager + +https://casbin.org/docs/role-managers + +## Benchmarks + +https://casbin.org/docs/benchmark + +## Examples + +Model | Model file | Policy file +----|------|---- +ACL | [basic_model.conf](https://github.com/casbin/casbin/blob/master/examples/basic_model.conf) | [basic_policy.csv](https://github.com/casbin/casbin/blob/master/examples/basic_policy.csv) +ACL with superuser | [basic_model_with_root.conf](https://github.com/casbin/casbin/blob/master/examples/basic_with_root_model.conf) | [basic_policy.csv](https://github.com/casbin/casbin/blob/master/examples/basic_policy.csv) +ACL without users | [basic_model_without_users.conf](https://github.com/casbin/casbin/blob/master/examples/basic_without_users_model.conf) | [basic_policy_without_users.csv](https://github.com/casbin/casbin/blob/master/examples/basic_without_users_policy.csv) +ACL without resources | [basic_model_without_resources.conf](https://github.com/casbin/casbin/blob/master/examples/basic_without_resources_model.conf) | [basic_policy_without_resources.csv](https://github.com/casbin/casbin/blob/master/examples/basic_without_resources_policy.csv) +RBAC | [rbac_model.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_model.conf) | [rbac_policy.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_policy.csv) +RBAC with resource roles | [rbac_model_with_resource_roles.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_resource_roles_model.conf) | 
[rbac_policy_with_resource_roles.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_resource_roles_policy.csv) +RBAC with domains/tenants | [rbac_model_with_domains.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_domains_model.conf) | [rbac_policy_with_domains.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_domains_policy.csv) +ABAC | [abac_model.conf](https://github.com/casbin/casbin/blob/master/examples/abac_model.conf) | N/A +RESTful | [keymatch_model.conf](https://github.com/casbin/casbin/blob/master/examples/keymatch_model.conf) | [keymatch_policy.csv](https://github.com/casbin/casbin/blob/master/examples/keymatch_policy.csv) +Deny-override | [rbac_model_with_deny.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_deny_model.conf) | [rbac_policy_with_deny.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_deny_policy.csv) +Priority | [priority_model.conf](https://github.com/casbin/casbin/blob/master/examples/priority_model.conf) | [priority_policy.csv](https://github.com/casbin/casbin/blob/master/examples/priority_policy.csv) + +## Middlewares + +Authz middlewares for web frameworks: https://casbin.org/docs/middlewares + +## Our adopters + +https://casbin.org/docs/adopters + +## How to Contribute + +Please read the [contributing guide](CONTRIBUTING.md). + +## Contributors + +This project exists thanks to all the people who contribute. + + +## Backers + +Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/casbin#backer)] + + + +## Sponsors + +Support this project by becoming a sponsor. Your logo will show up here with a link to your website. 
[[Become a sponsor](https://opencollective.com/casbin#sponsor)] + + + + + + + + + + + + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=casbin/casbin&type=Date)](https://star-history.com/#casbin/casbin&Date) + +## License + +This project is licensed under the [Apache 2.0 license](LICENSE). + +## Contact + +If you have any issues or feature requests, please contact us. PR is welcomed. +- https://github.com/casbin/casbin/issues +- hsluoyz@gmail.com +- Tencent QQ group: [546057381](//shang.qq.com/wpa/qunwpa?idkey=8ac8b91fc97ace3d383d0035f7aa06f7d670fd8e8d4837347354a31c18fac885) diff --git a/vendor/github.com/casbin/casbin/v2/casbin-logo.png b/vendor/github.com/casbin/casbin/v2/casbin-logo.png new file mode 100644 index 00000000..7e5d1ecf Binary files /dev/null and b/vendor/github.com/casbin/casbin/v2/casbin-logo.png differ diff --git a/vendor/github.com/casbin/casbin/v2/config/config.go b/vendor/github.com/casbin/casbin/v2/config/config.go new file mode 100644 index 00000000..e9f83030 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/config/config.go @@ -0,0 +1,267 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package config + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" +) + +var ( + // DEFAULT_SECTION specifies the name of a section if no name provided + DEFAULT_SECTION = "default" + // DEFAULT_COMMENT defines what character(s) indicate a comment `#` + DEFAULT_COMMENT = []byte{'#'} + // DEFAULT_COMMENT_SEM defines what alternate character(s) indicate a comment `;` + DEFAULT_COMMENT_SEM = []byte{';'} + // DEFAULT_MULTI_LINE_SEPARATOR defines what character indicates a multi-line content + DEFAULT_MULTI_LINE_SEPARATOR = []byte{'\\'} +) + +// ConfigInterface defines the behavior of a Config implementation +type ConfigInterface interface { + String(key string) string + Strings(key string) []string + Bool(key string) (bool, error) + Int(key string) (int, error) + Int64(key string) (int64, error) + Float64(key string) (float64, error) + Set(key string, value string) error +} + +// Config represents an implementation of the ConfigInterface +type Config struct { + // Section:key=value + data map[string]map[string]string +} + +// NewConfig create an empty configuration representation from file. +func NewConfig(confName string) (ConfigInterface, error) { + c := &Config{ + data: make(map[string]map[string]string), + } + err := c.parse(confName) + return c, err +} + +// NewConfigFromText create an empty configuration representation from text. +func NewConfigFromText(text string) (ConfigInterface, error) { + c := &Config{ + data: make(map[string]map[string]string), + } + err := c.parseBuffer(bufio.NewReader(strings.NewReader(text))) + return c, err +} + +// AddConfig adds a new section->key:value to the configuration. 
+func (c *Config) AddConfig(section string, option string, value string) bool { + if section == "" { + section = DEFAULT_SECTION + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + + _, ok := c.data[section][option] + c.data[section][option] = value + + return !ok +} + +func (c *Config) parse(fname string) (err error) { + f, err := os.Open(fname) + if err != nil { + return err + } + defer f.Close() + + buf := bufio.NewReader(f) + return c.parseBuffer(buf) +} + +func (c *Config) parseBuffer(buf *bufio.Reader) error { + var section string + var lineNum int + var buffer bytes.Buffer + var canWrite bool + for { + if canWrite { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } else { + canWrite = false + } + } + lineNum++ + line, _, err := buf.ReadLine() + if err == io.EOF { + // force write when buffer is not flushed yet + if buffer.Len() > 0 { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } + } + break + } else if err != nil { + return err + } + + line = bytes.TrimSpace(line) + switch { + case bytes.Equal(line, []byte{}), bytes.HasPrefix(line, DEFAULT_COMMENT_SEM), + bytes.HasPrefix(line, DEFAULT_COMMENT): + canWrite = true + continue + case bytes.HasPrefix(line, []byte{'['}) && bytes.HasSuffix(line, []byte{']'}): + // force write when buffer is not flushed yet + if buffer.Len() > 0 { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } + canWrite = false + } + section = string(line[1 : len(line)-1]) + default: + var p []byte + if bytes.HasSuffix(line, DEFAULT_MULTI_LINE_SEPARATOR) { + p = bytes.TrimSpace(line[:len(line)-1]) + p = append(p, " "...) 
+ } else { + p = line + canWrite = true + } + + end := len(p) + for i, value := range p { + if value == DEFAULT_COMMENT[0] || value == DEFAULT_COMMENT_SEM[0] { + end = i + break + } + } + if _, err := buffer.Write(p[:end]); err != nil { + return err + } + } + } + + return nil +} + +func (c *Config) write(section string, lineNum int, b *bytes.Buffer) error { + if b.Len() <= 0 { + return nil + } + + optionVal := bytes.SplitN(b.Bytes(), []byte{'='}, 2) + if len(optionVal) != 2 { + return fmt.Errorf("parse the content error : line %d , %s = ? ", lineNum, optionVal[0]) + } + option := bytes.TrimSpace(optionVal[0]) + value := bytes.TrimSpace(optionVal[1]) + c.AddConfig(section, string(option), string(value)) + + // flush buffer after adding + b.Reset() + + return nil +} + +// Bool lookups up the value using the provided key and converts the value to a bool +func (c *Config) Bool(key string) (bool, error) { + return strconv.ParseBool(c.get(key)) +} + +// Int lookups up the value using the provided key and converts the value to a int +func (c *Config) Int(key string) (int, error) { + return strconv.Atoi(c.get(key)) +} + +// Int64 lookups up the value using the provided key and converts the value to a int64 +func (c *Config) Int64(key string) (int64, error) { + return strconv.ParseInt(c.get(key), 10, 64) +} + +// Float64 lookups up the value using the provided key and converts the value to a float64 +func (c *Config) Float64(key string) (float64, error) { + return strconv.ParseFloat(c.get(key), 64) +} + +// String lookups up the value using the provided key and converts the value to a string +func (c *Config) String(key string) string { + return c.get(key) +} + +// Strings lookups up the value using the provided key and converts the value to an array of string +// by splitting the string by comma +func (c *Config) Strings(key string) []string { + v := c.get(key) + if v == "" { + return nil + } + return strings.Split(v, ",") +} + +// Set sets the value for the specific key 
in the Config +func (c *Config) Set(key string, value string) error { + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section string + option string + ) + + keys := strings.Split(strings.ToLower(key), "::") + if len(keys) >= 2 { + section = keys[0] + option = keys[1] + } else { + option = keys[0] + } + + c.AddConfig(section, option, value) + return nil +} + +// section.key or key +func (c *Config) get(key string) string { + var ( + section string + option string + ) + + keys := strings.Split(strings.ToLower(key), "::") + if len(keys) >= 2 { + section = keys[0] + option = keys[1] + } else { + section = DEFAULT_SECTION + option = keys[0] + } + + if value, ok := c.data[section][option]; ok { + return value + } + + return "" +} diff --git a/vendor/github.com/casbin/casbin/v2/constant/constants.go b/vendor/github.com/casbin/casbin/v2/constant/constants.go new file mode 100644 index 00000000..7a454aec --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/constant/constants.go @@ -0,0 +1,30 @@ +// Copyright 2022 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package constant + +const ( + DomainIndex = "dom" + SubjectIndex = "sub" + ObjectIndex = "obj" + PriorityIndex = "priority" +) + +const ( + AllowOverrideEffect = "some(where (p_eft == allow))" + DenyOverrideEffect = "!some(where (p_eft == deny))" + AllowAndDenyEffect = "some(where (p_eft == allow)) && !some(where (p_eft == deny))" + PriorityEffect = "priority(p_eft) || deny" + SubjectPriorityEffect = "subjectPriority(p_eft) || deny" +) diff --git a/vendor/github.com/casbin/casbin/v2/effector/default_effector.go b/vendor/github.com/casbin/casbin/v2/effector/default_effector.go new file mode 100644 index 00000000..feb083a6 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/effector/default_effector.go @@ -0,0 +1,109 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package effector + +import ( + "errors" + + "github.com/casbin/casbin/v2/constant" +) + +// DefaultEffector is default effector for Casbin. +type DefaultEffector struct { +} + +// NewDefaultEffector is the constructor for DefaultEffector. +func NewDefaultEffector() *DefaultEffector { + e := DefaultEffector{} + return &e +} + +// MergeEffects merges all matching results collected by the enforcer into a single decision. 
+func (e *DefaultEffector) MergeEffects(expr string, effects []Effect, matches []float64, policyIndex int, policyLength int) (Effect, int, error) { + result := Indeterminate + explainIndex := -1 + + switch expr { + case constant.AllowOverrideEffect: + if matches[policyIndex] == 0 { + break + } + // only check the current policyIndex + if effects[policyIndex] == Allow { + result = Allow + explainIndex = policyIndex + break + } + case constant.DenyOverrideEffect: + // only check the current policyIndex + if matches[policyIndex] != 0 && effects[policyIndex] == Deny { + result = Deny + explainIndex = policyIndex + break + } + // if no deny rules are matched at last, then allow + if policyIndex == policyLength-1 { + result = Allow + } + case constant.AllowAndDenyEffect: + // short-circuit if matched deny rule + if matches[policyIndex] != 0 && effects[policyIndex] == Deny { + result = Deny + // set hit rule to the (first) matched deny rule + explainIndex = policyIndex + break + } + + // short-circuit some effects in the middle + if policyIndex < policyLength-1 { + // choose not to short-circuit + return result, explainIndex, nil + } + // merge all effects at last + for i, eft := range effects { + if matches[i] == 0 { + continue + } + + if eft == Allow { + result = Allow + // set hit rule to first matched allow rule + explainIndex = i + break + } + } + case constant.PriorityEffect, constant.SubjectPriorityEffect: + // reverse merge, short-circuit may be earlier + for i := len(effects) - 1; i >= 0; i-- { + if matches[i] == 0 { + continue + } + + if effects[i] != Indeterminate { + if effects[i] == Allow { + result = Allow + } else { + result = Deny + } + explainIndex = i + break + } + } + default: + return Deny, -1, errors.New("unsupported effect") + } + + return result, explainIndex, nil +} diff --git a/vendor/github.com/casbin/casbin/v2/effector/effector.go b/vendor/github.com/casbin/casbin/v2/effector/effector.go new file mode 100644 index 00000000..665848b5 --- 
/dev/null +++ b/vendor/github.com/casbin/casbin/v2/effector/effector.go @@ -0,0 +1,31 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package effector + +// Effect is the result for a policy rule. +type Effect int + +// Values for policy effect. +const ( + Allow Effect = iota + Indeterminate + Deny +) + +// Effector is the interface for Casbin effectors. +type Effector interface { + // MergeEffects merges all matching results collected by the enforcer into a single decision. + MergeEffects(expr string, effects []Effect, matches []float64, policyIndex int, policyLength int) (Effect, int, error) +} diff --git a/vendor/github.com/casbin/casbin/v2/enforcer.go b/vendor/github.com/casbin/casbin/v2/enforcer.go new file mode 100644 index 00000000..d3c0c727 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/enforcer.go @@ -0,0 +1,834 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

package casbin

import (
	"errors"
	"fmt"
	"runtime/debug"
	"strings"
	"sync"

	"github.com/Knetic/govaluate"
	"github.com/casbin/casbin/v2/effector"
	"github.com/casbin/casbin/v2/log"
	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	fileadapter "github.com/casbin/casbin/v2/persist/file-adapter"
	"github.com/casbin/casbin/v2/rbac"
	defaultrolemanager "github.com/casbin/casbin/v2/rbac/default-role-manager"
	"github.com/casbin/casbin/v2/util"
)

// Enforcer is the main interface for authorization enforcement and policy management.
type Enforcer struct {
	modelPath string            // path of the model CONF file, kept for LoadModel reloads
	model     model.Model       // the access-control model (r/p/g/e/m sections)
	fm        model.FunctionMap // matcher functions available to govaluate expressions
	eft       effector.Effector // merges per-rule results into a final decision

	adapter    persist.Adapter         // policy storage backend
	watcher    persist.Watcher         // optional change-notification channel between instances
	dispatcher persist.Dispatcher      // optional dispatcher for distributed policy changes
	rmMap      map[string]rbac.RoleManager // one role manager per "g" ptype
	matcherMap sync.Map                // cache of compiled matcher expressions, keyed by expression string

	enabled              bool // when false, Enforce() allows everything
	autoSave             bool // persist policy mutations to the adapter automatically
	autoBuildRoleLinks   bool // rebuild role inheritance on policy changes
	autoNotifyWatcher    bool // notify the watcher on policy changes
	autoNotifyDispatcher bool // notify the dispatcher on policy changes

	logger log.Logger
}

// EnforceContext is used as the first element of the parameter "rvals" in method "enforce"
// to select non-default section keys (e.g. "r2"/"p2"/"e2"/"m2") for a single call.
type EnforceContext struct {
	RType string
	PType string
	EType string
	MType string
}

// NewEnforcer creates an enforcer via file or DB.
//
// File:
//
//	e := casbin.NewEnforcer("path/to/basic_model.conf", "path/to/basic_policy.csv")
//
// MySQL DB:
//
//	a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
//	e := casbin.NewEnforcer("path/to/basic_model.conf", a)
//
func NewEnforcer(params ...interface{}) (*Enforcer, error) {
	e := &Enforcer{logger: &log.DefaultLogger{}}

	// Optional trailing parameters are consumed right-to-left:
	// a final bool enables/disables logging, and a log.Logger before it
	// replaces the default logger. What remains selects the init form.
	parsedParamLen := 0
	paramLen := len(params)
	if paramLen >= 1 {
		enableLog, ok := params[paramLen-1].(bool)
		if ok {
			e.EnableLog(enableLog)
			parsedParamLen++
		}
	}

	if paramLen-parsedParamLen >= 1 {
		logger, ok := params[paramLen-parsedParamLen-1].(log.Logger)
		if ok {
			e.logger = logger
			parsedParamLen++
		}
	}

	// Remaining params: (modelPath, policyPath), (modelPath, adapter),
	// (model, adapter), (modelPath), (model), or nothing (bare enforcer).
	if paramLen-parsedParamLen == 2 {
		switch p0 := params[0].(type) {
		case string:
			switch p1 := params[1].(type) {
			case string:
				err := e.InitWithFile(p0, p1)
				if err != nil {
					return nil, err
				}
			default:
				err := e.InitWithAdapter(p0, p1.(persist.Adapter))
				if err != nil {
					return nil, err
				}
			}
		default:
			switch params[1].(type) {
			case string:
				// (model, string) is meaningless: a model object cannot be
				// paired with a file path here.
				return nil, errors.New("invalid parameters for enforcer")
			default:
				err := e.InitWithModelAndAdapter(p0.(model.Model), params[1].(persist.Adapter))
				if err != nil {
					return nil, err
				}
			}
		}
	} else if paramLen-parsedParamLen == 1 {
		switch p0 := params[0].(type) {
		case string:
			err := e.InitWithFile(p0, "")
			if err != nil {
				return nil, err
			}
		default:
			err := e.InitWithModelAndAdapter(p0.(model.Model), nil)
			if err != nil {
				return nil, err
			}
		}
	} else if paramLen-parsedParamLen == 0 {
		// Zero-argument form: caller must set model/adapter later.
		return e, nil
	} else {
		return nil, errors.New("invalid parameters for enforcer")
	}

	return e, nil
}

// InitWithFile initializes an enforcer with a model file and a policy file.
func (e *Enforcer) InitWithFile(modelPath string, policyPath string) error {
	a := fileadapter.NewAdapter(policyPath)
	return e.InitWithAdapter(modelPath, a)
}

// InitWithAdapter initializes an enforcer with a database adapter.
func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) error {
	m, err := model.NewModelFromFile(modelPath)
	if err != nil {
		return err
	}

	err = e.InitWithModelAndAdapter(m, adapter)
	if err != nil {
		return err
	}

	// Remember the path so LoadModel can reload the same file later.
	e.modelPath = modelPath
	return nil
}

// InitWithModelAndAdapter initializes an enforcer with a model and a database adapter.
func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error {
	e.adapter = adapter

	e.model = m
	m.SetLogger(e.logger)
	e.model.PrintModel()
	e.fm = model.LoadFunctionMap()

	e.initialize()

	// Do not initialize the full policy when using a filtered adapter
	fa, ok := e.adapter.(persist.FilteredAdapter)
	if e.adapter != nil && (!ok || ok && !fa.IsFiltered()) {
		err := e.LoadPolicy()
		if err != nil {
			return err
		}
	}

	return nil
}

// SetLogger changes the current enforcer's logger.
func (e *Enforcer) SetLogger(logger log.Logger) {
	e.logger = logger
	e.model.SetLogger(e.logger)
	// Propagate the logger to every role manager as well.
	for k := range e.rmMap {
		e.rmMap[k].SetLogger(e.logger)
	}
}

// initialize resets the enforcer to its default runtime state: fresh role
// managers, default effector, no watcher, empty matcher cache, and all
// auto-* behaviors enabled.
func (e *Enforcer) initialize() {
	e.rmMap = map[string]rbac.RoleManager{}
	e.eft = effector.NewDefaultEffector()
	e.watcher = nil
	e.matcherMap = sync.Map{}

	e.enabled = true
	e.autoSave = true
	e.autoBuildRoleLinks = true
	e.autoNotifyWatcher = true
	e.autoNotifyDispatcher = true
	e.initRmMap()
}

// LoadModel reloads the model from the model CONF file.
// Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy().
func (e *Enforcer) LoadModel() error {
	var err error
	e.model, err = model.NewModelFromFile(e.modelPath)
	if err != nil {
		return err
	}
	e.model.SetLogger(e.logger)

	e.model.PrintModel()
	e.fm = model.LoadFunctionMap()

	// Resets role managers, effector, matcher cache, and auto-* flags.
	e.initialize()

	return nil
}

// GetModel gets the current model.
func (e *Enforcer) GetModel() model.Model {
	return e.model
}

// SetModel sets the current model.
func (e *Enforcer) SetModel(m model.Model) {
	e.model = m
	e.fm = model.LoadFunctionMap()

	e.model.SetLogger(e.logger)
	e.initialize()
}

// GetAdapter gets the current adapter.
func (e *Enforcer) GetAdapter() persist.Adapter {
	return e.adapter
}

// SetAdapter sets the current adapter.
func (e *Enforcer) SetAdapter(adapter persist.Adapter) {
	e.adapter = adapter
}

// SetWatcher sets the current watcher.
func (e *Enforcer) SetWatcher(watcher persist.Watcher) error {
	e.watcher = watcher
	if _, ok := e.watcher.(persist.WatcherEx); ok {
		// The callback of WatcherEx has no generic implementation.
		return nil
	} else {
		// In case the Watcher wants to use a customized callback function, call `SetUpdateCallback` after `SetWatcher`.
		return watcher.SetUpdateCallback(func(string) { _ = e.LoadPolicy() })
	}
}

// GetRoleManager gets the current role manager.
func (e *Enforcer) GetRoleManager() rbac.RoleManager {
	return e.rmMap["g"]
}

// GetNamedRoleManager gets the role manager for the named policy.
func (e *Enforcer) GetNamedRoleManager(ptype string) rbac.RoleManager {
	return e.rmMap[ptype]
}

// SetRoleManager sets the current role manager.
func (e *Enforcer) SetRoleManager(rm rbac.RoleManager) {
	e.rmMap["g"] = rm
}

// SetNamedRoleManager sets the role manager for the named policy.
func (e *Enforcer) SetNamedRoleManager(ptype string, rm rbac.RoleManager) {
	e.rmMap[ptype] = rm
}

// SetEffector sets the current effector.
func (e *Enforcer) SetEffector(eft effector.Effector) {
	e.eft = eft
}

// ClearPolicy clears all policy.
func (e *Enforcer) ClearPolicy() {
	// When a dispatcher is configured, the clear is routed through it so all
	// distributed instances stay in sync; the local model is then updated by
	// the dispatcher callback rather than here.
	if e.dispatcher != nil && e.autoNotifyDispatcher {
		_ = e.dispatcher.ClearPolicy()
		return
	}
	e.model.ClearPolicy()
}

// LoadPolicy reloads the policy from file/database.
func (e *Enforcer) LoadPolicy() error {
	// Load into a copy first so the live model stays intact if loading fails.
	needToRebuild := false
	newModel := e.model.Copy()
	newModel.ClearPolicy()

	var err error
	defer func() {
		if err != nil {
			// Role links were cleared below; restore them for the old model
			// on failure so the enforcer remains usable.
			if e.autoBuildRoleLinks && needToRebuild {
				_ = e.BuildRoleLinks()
			}
		}
	}()

	// The empty-path error is tolerated to support enforcers created without
	// a policy source (e.g. NewEnforcer(modelPath)).
	if err = e.adapter.LoadPolicy(newModel); err != nil && err.Error() != "invalid file path, file path cannot be empty" {
		return err
	}

	if err = newModel.SortPoliciesBySubjectHierarchy(); err != nil {
		return err
	}

	if err = newModel.SortPoliciesByPriority(); err != nil {
		return err
	}

	if e.autoBuildRoleLinks {
		needToRebuild = true
		for _, rm := range e.rmMap {
			err := rm.Clear()
			if err != nil {
				return err
			}
		}
		err = newModel.BuildRoleLinks(e.rmMap)
		if err != nil {
			return err
		}
	}
	// Swap in the fully-loaded model only after everything succeeded.
	e.model = newModel
	return nil
}

// loadFilteredPolicy loads only the policy rules selected by filter via a
// FilteredAdapter, then re-sorts and rebuilds role links. Unlike LoadPolicy,
// it mutates e.model in place; callers decide whether to clear first.
func (e *Enforcer) loadFilteredPolicy(filter interface{}) error {
	var filteredAdapter persist.FilteredAdapter

	// Attempt to cast the Adapter as a FilteredAdapter
	switch adapter := e.adapter.(type) {
	case persist.FilteredAdapter:
		filteredAdapter = adapter
	default:
		return errors.New("filtered policies are not supported by this adapter")
	}
	if err := filteredAdapter.LoadFilteredPolicy(e.model, filter); err != nil && err.Error() != "invalid file path, file path cannot be empty" {
		return err
	}

	if err := e.model.SortPoliciesBySubjectHierarchy(); err != nil {
		return err
	}

	if err := e.model.SortPoliciesByPriority(); err != nil {
		return err
	}

	e.initRmMap()
	e.model.PrintPolicy()
	if e.autoBuildRoleLinks {
		err := e.BuildRoleLinks()
		if err != nil {
			return err
		}
	}
	return nil
}

// LoadFilteredPolicy reloads a filtered policy from file/database.
func (e *Enforcer) LoadFilteredPolicy(filter interface{}) error {
	e.model.ClearPolicy()

	return e.loadFilteredPolicy(filter)
}

// LoadIncrementalFilteredPolicy append a filtered policy from file/database.
func (e *Enforcer) LoadIncrementalFilteredPolicy(filter interface{}) error {
	return e.loadFilteredPolicy(filter)
}

// IsFiltered returns true if the loaded policy has been filtered.
func (e *Enforcer) IsFiltered() bool {
	filteredAdapter, ok := e.adapter.(persist.FilteredAdapter)
	if !ok {
		return false
	}
	return filteredAdapter.IsFiltered()
}

// SavePolicy saves the current policy (usually after changed with Casbin API) back to file/database.
func (e *Enforcer) SavePolicy() error {
	// Saving a filtered view would silently drop the rules excluded by the filter.
	if e.IsFiltered() {
		return errors.New("cannot save a filtered policy")
	}
	if err := e.adapter.SavePolicy(e.model); err != nil {
		return err
	}
	if e.watcher != nil {
		var err error
		if watcher, ok := e.watcher.(persist.WatcherEx); ok {
			err = watcher.UpdateForSavePolicy(e.model)
		} else {
			err = e.watcher.Update()
		}
		return err
	}
	return nil
}

// initRmMap ensures there is one role manager per "g" ptype in the model,
// clearing existing managers and creating new ones (hierarchy depth 10) as needed.
func (e *Enforcer) initRmMap() {
	for ptype := range e.model["g"] {
		if rm, ok := e.rmMap[ptype]; ok {
			_ = rm.Clear()
		} else {
			e.rmMap[ptype] = defaultrolemanager.NewRoleManager(10)
		}
	}
}

// EnableEnforce changes the enforcing state of Casbin, when Casbin is disabled, all access will be allowed by the Enforce() function.
func (e *Enforcer) EnableEnforce(enable bool) {
	e.enabled = enable
}

// EnableLog changes whether Casbin will log messages to the Logger.
func (e *Enforcer) EnableLog(enable bool) {
	e.logger.EnableLog(enable)
}

// IsLogEnabled returns the current logger's enabled status.
func (e *Enforcer) IsLogEnabled() bool {
	return e.logger.IsEnabled()
}

// EnableAutoNotifyWatcher controls whether to save a policy rule automatically notify the Watcher when it is added or removed.
func (e *Enforcer) EnableAutoNotifyWatcher(enable bool) {
	e.autoNotifyWatcher = enable
}

// EnableAutoNotifyDispatcher controls whether to save a policy rule automatically notify the Dispatcher when it is added or removed.
func (e *Enforcer) EnableAutoNotifyDispatcher(enable bool) {
	e.autoNotifyDispatcher = enable
}

// EnableAutoSave controls whether to save a policy rule automatically to the adapter when it is added or removed.
func (e *Enforcer) EnableAutoSave(autoSave bool) {
	e.autoSave = autoSave
}

// EnableAutoBuildRoleLinks controls whether to rebuild the role inheritance relations when a role is added or deleted.
func (e *Enforcer) EnableAutoBuildRoleLinks(autoBuildRoleLinks bool) {
	e.autoBuildRoleLinks = autoBuildRoleLinks
}

// BuildRoleLinks manually rebuild the role inheritance relations.
func (e *Enforcer) BuildRoleLinks() error {
	// Clear every role manager first so stale links from a previous policy
	// load do not survive.
	for _, rm := range e.rmMap {
		err := rm.Clear()
		if err != nil {
			return err
		}
	}

	return e.model.BuildRoleLinks(e.rmMap)
}

// BuildIncrementalRoleLinks provides incremental build the role inheritance relations.
func (e *Enforcer) BuildIncrementalRoleLinks(op model.PolicyOp, ptype string, rules [][]string) error {
	// Role links can change which requests match, so drop cached matchers.
	e.invalidateMatcherMap()
	return e.model.BuildIncrementalRoleLinks(e.rmMap, op, "g", ptype, rules)
}

// NewEnforceContext Create a default structure based on the suffix
// (e.g. suffix "2" selects the r2/p2/e2/m2 model sections).
func NewEnforceContext(suffix string) EnforceContext {
	return EnforceContext{
		RType: "r" + suffix,
		PType: "p" + suffix,
		EType: "e" + suffix,
		MType: "m" + suffix,
	}
}

// invalidateMatcherMap drops all cached compiled matcher expressions.
func (e *Enforcer) invalidateMatcherMap() {
	e.matcherMap = sync.Map{}
}

// enforce use a custom matcher to decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (matcher, sub, obj, act), use model matcher by default when matcher is "".
func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interface{}) (ok bool, err error) {
	// Convert any panic (e.g. from a bad matcher expression or user function)
	// into an error instead of crashing the caller.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic: %v\n%s", r, debug.Stack())
		}
	}()

	// Enforcement disabled: allow everything (see EnableEnforce).
	if !e.enabled {
		return true, nil
	}

	// Register one g-function per role-definition ptype so the matcher can
	// call g(sub, role[, domain]).
	functions := e.fm.GetFunctions()
	if _, ok := e.model["g"]; ok {
		for key, ast := range e.model["g"] {
			rm := ast.RM
			functions[key] = util.GenerateGFunction(rm)
		}
	}

	// Default section keys; an EnforceContext as the first rval overrides them.
	var (
		rType = "r"
		pType = "p"
		eType = "e"
		mType = "m"
	)
	if len(rvals) != 0 {
		switch rvals[0].(type) {
		case EnforceContext:
			enforceContext := rvals[0].(EnforceContext)
			rType = enforceContext.RType
			pType = enforceContext.PType
			eType = enforceContext.EType
			mType = enforceContext.MType
			rvals = rvals[1:]
		default:
			break
		}
	}

	var expString string
	if matcher == "" {
		expString = e.model["m"][mType].Value
	} else {
		expString = util.RemoveComments(util.EscapeAssertion(matcher))
	}

	// Map token names (r_sub, p_obj, ...) to their positional index so
	// enforceParameters.Get can resolve them during expression evaluation.
	rTokens := make(map[string]int, len(e.model["r"][rType].Tokens))
	for i, token := range e.model["r"][rType].Tokens {
		rTokens[token] = i
	}
	pTokens := make(map[string]int, len(e.model["p"][pType].Tokens))
	for i, token := range e.model["p"][pType].Tokens {
		pTokens[token] = i
	}

	parameters := enforceParameters{
		rTokens: rTokens,
		rVals:   rvals,

		pTokens: pTokens,
	}

	hasEval := util.HasEval(expString)
	if hasEval {
		functions["eval"] = generateEvalFunction(functions, &parameters)
	}
	var expression *govaluate.EvaluableExpression
	expression, err = e.getAndStoreMatcherExpression(hasEval, expString, functions)
	if err != nil {
		return false, err
	}

	if len(e.model["r"][rType].Tokens) != len(rvals) {
		return false, fmt.Errorf(
			"invalid request size: expected %d, got %d, rvals: %v",
			len(e.model["r"][rType].Tokens),
			len(rvals),
			rvals)
	}

	var policyEffects []effector.Effect
	var matcherResults []float64

	var effect effector.Effect
	var explainIndex int

	// Main path: evaluate the matcher against each policy rule. The matcher
	// must actually reference p-tokens (pType+"_") for this to be meaningful;
	// otherwise fall through to the single-evaluation branch below.
	if policyLen := len(e.model["p"][pType].Policy); policyLen != 0 && strings.Contains(expString, pType+"_") {
		policyEffects = make([]effector.Effect, policyLen)
		matcherResults = make([]float64, policyLen)

		for policyIndex, pvals := range e.model["p"][pType].Policy {
			// log.LogPrint("Policy Rule: ", pvals)
			if len(e.model["p"][pType].Tokens) != len(pvals) {
				return false, fmt.Errorf(
					"invalid policy size: expected %d, got %d, pvals: %v",
					len(e.model["p"][pType].Tokens),
					len(pvals),
					pvals)
			}

			parameters.pVals = pvals

			result, err := expression.Eval(parameters)
			// log.LogPrint("Result: ", result)

			if err != nil {
				return false, err
			}

			// set to no-match at first
			matcherResults[policyIndex] = 0
			switch result := result.(type) {
			case bool:
				if result {
					matcherResults[policyIndex] = 1
				}
			case float64:
				if result != 0 {
					matcherResults[policyIndex] = 1
				}
			default:
				return false, errors.New("matcher result should be bool, int or float")
			}

			// A p_eft column, if present, overrides the default "allow" effect.
			if j, ok := parameters.pTokens[pType+"_eft"]; ok {
				eft := parameters.pVals[j]
				if eft == "allow" {
					policyEffects[policyIndex] = effector.Allow
				} else if eft == "deny" {
					policyEffects[policyIndex] = effector.Deny
				} else {
					policyEffects[policyIndex] = effector.Indeterminate
				}
			} else {
				policyEffects[policyIndex] = effector.Allow
			}

			//if e.model["e"]["e"].Value == "priority(p_eft) || deny" {
			//	break
			//}

			// Merge after every rule so the effector can short-circuit.
			effect, explainIndex, err = e.eft.MergeEffects(e.model["e"][eType].Value, policyEffects, matcherResults, policyIndex, policyLen)
			if err != nil {
				return false, err
			}
			if effect != effector.Indeterminate {
				break
			}
		}
	} else {

		if hasEval && len(e.model["p"][pType].Policy) == 0 {
			return false, errors.New("please make sure rule exists in policy when using eval() in matcher")
		}

		// No p-tokens in the matcher: evaluate it once with empty policy values.
		policyEffects = make([]effector.Effect, 1)
		matcherResults = make([]float64, 1)
		matcherResults[0] = 1

		parameters.pVals = make([]string, len(parameters.pTokens))

		result, err := expression.Eval(parameters)

		if err != nil {
			return false, err
		}

		if result.(bool) {
			policyEffects[0] = effector.Allow
		} else {
			policyEffects[0] = effector.Indeterminate
		}

		effect, explainIndex, err = e.eft.MergeEffects(e.model["e"][eType].Value, policyEffects, matcherResults, 0, 1)
		if err != nil {
			return false, err
		}
	}

	// Report the rule that decided the outcome back through explains, and
	// collect everything for the enforce log.
	var logExplains [][]string

	if explains != nil {
		if len(*explains) > 0 {
			logExplains = append(logExplains, *explains)
		}

		if explainIndex != -1 && len(e.model["p"][pType].Policy) > explainIndex {
			*explains = e.model["p"][pType].Policy[explainIndex]
			logExplains = append(logExplains, *explains)
		}
	}

	// effect -> result
	result := false
	if effect == effector.Allow {
		result = true
	}
	e.logger.LogEnforce(expString, rvals, result, logExplains)

	return result, nil
}

// getAndStoreMatcherExpression returns the compiled govaluate expression for
// expString, reusing the cached compilation when possible. Expressions that
// contain eval() are never served from cache because their "eval" function
// closes over the current request's parameters.
func (e *Enforcer) getAndStoreMatcherExpression(hasEval bool, expString string, functions map[string]govaluate.ExpressionFunction) (*govaluate.EvaluableExpression, error) {
	var expression *govaluate.EvaluableExpression
	var err error
	var cachedExpression, isPresent = e.matcherMap.Load(expString)

	if !hasEval && isPresent {
		expression = cachedExpression.(*govaluate.EvaluableExpression)
	} else {
		expression, err = govaluate.NewEvaluableExpressionWithFunctions(expString, functions)
		if err != nil {
			return nil, err
		}
		e.matcherMap.Store(expString, expression)
	}
	return expression, nil
}

// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act).
func (e *Enforcer) Enforce(rvals ...interface{}) (bool, error) {
	return e.enforce("", nil, rvals...)
}

// EnforceWithMatcher use a custom matcher to decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (matcher, sub, obj, act), use model matcher by default when matcher is "".
func (e *Enforcer) EnforceWithMatcher(matcher string, rvals ...interface{}) (bool, error) {
	return e.enforce(matcher, nil, rvals...)
}

// EnforceEx explain enforcement by informing matched rules
func (e *Enforcer) EnforceEx(rvals ...interface{}) (bool, []string, error) {
	explain := []string{}
	result, err := e.enforce("", &explain, rvals...)
	return result, explain, err
}

// EnforceExWithMatcher use a custom matcher and explain enforcement by informing matched rules
func (e *Enforcer) EnforceExWithMatcher(matcher string, rvals ...interface{}) (bool, []string, error) {
	explain := []string{}
	result, err := e.enforce(matcher, &explain, rvals...)
	return result, explain, err
}

// BatchEnforce enforce in batches
func (e *Enforcer) BatchEnforce(requests [][]interface{}) ([]bool, error) {
	var results []bool
	for _, request := range requests {
		result, err := e.enforce("", nil, request...)
		if err != nil {
			// Partial results gathered so far are returned with the error.
			return results, err
		}
		results = append(results, result)
	}
	return results, nil
}

// BatchEnforceWithMatcher enforce with matcher in batches
func (e *Enforcer) BatchEnforceWithMatcher(matcher string, requests [][]interface{}) ([]bool, error) {
	var results []bool
	for _, request := range requests {
		result, err := e.enforce(matcher, nil, request...)
		if err != nil {
			return results, err
		}
		results = append(results, result)
	}
	return results, nil
}

// AddNamedMatchingFunc add MatchingFunc by ptype RoleManager
func (e *Enforcer) AddNamedMatchingFunc(ptype, name string, fn rbac.MatchingFunc) bool {
	if rm, ok := e.rmMap[ptype]; ok {
		rm.AddMatchingFunc(name, fn)
		return true
	}
	return false
}

// AddNamedDomainMatchingFunc add MatchingFunc by ptype to RoleManager
func (e *Enforcer) AddNamedDomainMatchingFunc(ptype, name string, fn rbac.MatchingFunc) bool {
	if rm, ok := e.rmMap[ptype]; ok {
		rm.AddDomainMatchingFunc(name, fn)
		return true
	}
	return false
}

// enforceParameters resolves matcher token names (r_*, p_*) to the request
// and policy values of the evaluation in progress.
// assumes bounds have already been checked
type enforceParameters struct {
	rTokens map[string]int // request token name -> index into rVals
	rVals   []interface{}

	pTokens map[string]int // policy token name -> index into pVals
	pVals   []string
}

// implements govaluate.Parameters
func (p enforceParameters) Get(name string) (interface{}, error) {
	if name == "" {
		return nil, nil
	}

	// Token kind is distinguished by the first character: 'p' for policy
	// tokens, 'r' for request tokens.
	switch name[0] {
	case 'p':
		i, ok := p.pTokens[name]
		if !ok {
			return nil, errors.New("No parameter '" + name + "' found.")
		}
		return p.pVals[i], nil
	case 'r':
		i, ok := p.rTokens[name]
		if !ok {
			return nil, errors.New("No parameter '" + name + "' found.")
		}
		return p.rVals[i], nil
	default:
		return nil, errors.New("No parameter '" + name + "' found.")
	}
}

// generateEvalFunction builds the eval() matcher function, which compiles and
// evaluates a sub-rule (stored as a policy value) against the same parameters
// and functions as the enclosing matcher.
func generateEvalFunction(functions map[string]govaluate.ExpressionFunction, parameters *enforceParameters) govaluate.ExpressionFunction {
	return func(args ...interface{}) (interface{}, error) {
		if len(args) != 1 {
			return nil, fmt.Errorf("Function eval(subrule string) expected %d arguments, but got %d", 1, len(args))
		}

		expression, ok := args[0].(string)
		if !ok {
			return nil, errors.New("Argument of eval(subrule string) must be a string")
		}
		expression = util.EscapeAssertion(expression)
		expr, err := govaluate.NewEvaluableExpressionWithFunctions(expression, functions)
		if err != nil {
			return nil, fmt.Errorf("Error while parsing eval parameter: %s, %s", expression, err.Error())
		}
		return expr.Eval(parameters)
	}
}

// Copyright 2018 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package casbin

import (
	"strings"
	"sync"
	"sync/atomic"

	"github.com/casbin/casbin/v2/persist/cache"
)

// CachedEnforcer wraps Enforcer and provides decision cache
type CachedEnforcer struct {
	*Enforcer
	expireTime  uint        // TTL passed through to the cache on Set; 0 presumably means no expiry — confirm against the cache implementation
	cache       cache.Cache // stores request-key -> bool decision
	enableCache int32       // accessed atomically; 1 = caching on, 0 = off
	locker      *sync.RWMutex
}

// CacheableParam lets non-string request parameters participate in cache keys.
type CacheableParam interface {
	GetCacheKey() string
}

// NewCachedEnforcer creates a cached enforcer via file or DB.
func NewCachedEnforcer(params ...interface{}) (*CachedEnforcer, error) {
	e := &CachedEnforcer{}
	var err error
	e.Enforcer, err = NewEnforcer(params...)
	if err != nil {
		return nil, err
	}

	e.enableCache = 1
	// Note: the local variable shadows the imported "cache" package name.
	cache := cache.DefaultCache(make(map[string]bool))
	e.cache = &cache
	e.locker = new(sync.RWMutex)
	return e, nil
}

// EnableCache determines whether to enable cache on Enforce(). When enableCache is enabled, cached result (true | false) will be returned for previous decisions.
func (e *CachedEnforcer) EnableCache(enableCache bool) {
	var enabled int32
	if enableCache {
		enabled = 1
	}
	atomic.StoreInt32(&e.enableCache, enabled)
}

// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act).
// if rvals is not string , ingore the cache
func (e *CachedEnforcer) Enforce(rvals ...interface{}) (bool, error) {
	if atomic.LoadInt32(&e.enableCache) == 0 {
		return e.Enforcer.Enforce(rvals...)
	}

	// Only requests whose params are all string/CacheableParam are cacheable.
	key, ok := e.getKey(rvals...)
	if !ok {
		return e.Enforcer.Enforce(rvals...)
	}

	if res, err := e.getCachedResult(key); err == nil {
		return res, nil
	} else if err != cache.ErrNoSuchKey {
		return res, err
	}

	// Cache miss: compute and store the decision.
	res, err := e.Enforcer.Enforce(rvals...)
	if err != nil {
		return false, err
	}

	err = e.setCachedResult(key, res, e.expireTime)
	return res, err
}

// LoadPolicy reloads the policy and drops all cached decisions, which may be
// stale relative to the new policy.
func (e *CachedEnforcer) LoadPolicy() error {
	if atomic.LoadInt32(&e.enableCache) != 0 {
		if err := e.cache.Clear(); err != nil {
			return err
		}
	}
	return e.Enforcer.LoadPolicy()
}

// RemovePolicy removes a rule and invalidates the cache entry keyed by the
// same parameters, if any.
func (e *CachedEnforcer) RemovePolicy(params ...interface{}) (bool, error) {
	if atomic.LoadInt32(&e.enableCache) != 0 {
		key, ok := e.getKey(params...)
		if ok {
			if err := e.cache.Delete(key); err != nil && err != cache.ErrNoSuchKey {
				return false, err
			}
		}
	}
	return e.Enforcer.RemovePolicy(params...)
}

// RemovePolicies removes a set of rules and invalidates their cache entries.
func (e *CachedEnforcer) RemovePolicies(rules [][]string) (bool, error) {
	if len(rules) != 0 {
		if atomic.LoadInt32(&e.enableCache) != 0 {
			// NOTE(review): irule is sized from rules[0]; this assumes every
			// rule has the same length — confirm against upstream/callers.
			irule := make([]interface{}, len(rules[0]))
			for _, rule := range rules {
				for i, param := range rule {
					irule[i] = param
				}
				key, _ := e.getKey(irule...)
				if err := e.cache.Delete(key); err != nil && err != cache.ErrNoSuchKey {
					return false, err
				}
			}
		}
	}
	return e.Enforcer.RemovePolicies(rules)
}

// getCachedResult looks up a previous decision under the read lock.
func (e *CachedEnforcer) getCachedResult(key string) (res bool, err error) {
	e.locker.RLock()
	defer e.locker.RUnlock()
	return e.cache.Get(key)
}

// SetExpireTime sets the TTL forwarded to the cache when storing decisions.
func (e *CachedEnforcer) SetExpireTime(expireTime uint) {
	e.expireTime = expireTime
}

// SetCache replaces the decision cache implementation.
func (e *CachedEnforcer) SetCache(c cache.Cache) {
	e.cache = c
}

// setCachedResult stores a decision under the write lock.
func (e *CachedEnforcer) setCachedResult(key string, res bool, extra ...interface{}) error {
	e.locker.Lock()
	defer e.locker.Unlock()
	return e.cache.Set(key, res, extra...)
}

// getKey builds a cache key by joining all params with "$$". It returns
// ok=false when any param is neither a string nor a CacheableParam, in which
// case the request bypasses the cache.
func (e *CachedEnforcer) getKey(params ...interface{}) (string, bool) {
	key := strings.Builder{}
	for _, param := range params {
		switch typedParam := param.(type) {
		case string:
			key.WriteString(typedParam)
		case CacheableParam:
			key.WriteString(typedParam.GetCacheKey())
		default:
			return "", false
		}
		key.WriteString("$$")
	}
	return key.String(), true
}

// InvalidateCache deletes all the existing cached decisions.
func (e *CachedEnforcer) InvalidateCache() error {
	e.locker.Lock()
	defer e.locker.Unlock()
	return e.cache.Clear()
}

package casbin

import (
	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
)

// DistributedEnforcer wraps SyncedEnforcer for dispatcher.
type DistributedEnforcer struct {
	*SyncedEnforcer
}

// NewDistributedEnforcer creates a DistributedEnforcer; params are forwarded
// to NewSyncedEnforcer (and ultimately NewEnforcer).
func NewDistributedEnforcer(params ...interface{}) (*DistributedEnforcer, error) {
	e := &DistributedEnforcer{}
	var err error
	e.SyncedEnforcer, err = NewSyncedEnforcer(params...)
	if err != nil {
		return nil, err
	}

	return e, nil
}

// SetDispatcher sets the current dispatcher.
func (e *DistributedEnforcer) SetDispatcher(dispatcher persist.Dispatcher) {
	e.dispatcher = dispatcher
}

// AddPoliciesSelf provides a method for dispatcher to add authorization rules to the current policy.
// The function returns the rules affected and error.
func (d *DistributedEnforcer) AddPoliciesSelf(shouldPersist func() bool, sec string, ptype string, rules [][]string) (affected [][]string, err error) {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		// Only persist rules not already present in the model.
		var noExistsPolicy [][]string
		for _, rule := range rules {
			if !d.model.HasPolicy(sec, ptype, rule) {
				noExistsPolicy = append(noExistsPolicy, rule)
			}
		}

		// Adapters without batch support report notImplemented, which is
		// tolerated (notImplemented is a package-level constant defined
		// elsewhere in this package).
		if err := d.adapter.(persist.BatchAdapter).AddPolicies(sec, ptype, noExistsPolicy); err != nil {
			if err.Error() != notImplemented {
				return nil, err
			}
		}
	}

	affected = d.model.AddPoliciesWithAffected(sec, ptype, rules)

	// Role-definition ("g") changes require rebuilding inheritance links.
	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, affected)
		if err != nil {
			return affected, err
		}
	}

	return affected, nil
}

// RemovePoliciesSelf provides a method for dispatcher to remove a set of rules from current policy.
// The function returns the rules affected and error.
func (d *DistributedEnforcer) RemovePoliciesSelf(shouldPersist func() bool, sec string, ptype string, rules [][]string) (affected [][]string, err error) {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		// notImplemented errors from non-batch adapters are tolerated.
		if err := d.adapter.(persist.BatchAdapter).RemovePolicies(sec, ptype, rules); err != nil {
			if err.Error() != notImplemented {
				return nil, err
			}
		}
	}

	affected = d.model.RemovePoliciesWithAffected(sec, ptype, rules)

	// Role-definition ("g") changes require rebuilding inheritance links.
	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, affected)
		if err != nil {
			return affected, err
		}
	}

	return affected, err
}

// RemoveFilteredPolicySelf provides a method for dispatcher to remove an authorization rule from the current policy, field filters can be specified.
// The function returns the rules affected and error.
func (d *DistributedEnforcer) RemoveFilteredPolicySelf(shouldPersist func() bool, sec string, ptype string, fieldIndex int, fieldValues ...string) (affected [][]string, err error) {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		if err := d.adapter.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...); err != nil {
			if err.Error() != notImplemented {
				return nil, err
			}
		}
	}

	_, affected = d.model.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...)

	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, affected)
		if err != nil {
			return affected, err
		}
	}

	return affected, nil
}

// ClearPolicySelf provides a method for dispatcher to clear all rules from the current policy.
func (d *DistributedEnforcer) ClearPolicySelf(shouldPersist func() bool) error {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		// Saving a nil model is the adapter-level convention for "clear all"
		// here — confirm against the adapter implementations in use.
		err := d.adapter.SavePolicy(nil)
		if err != nil {
			return err
		}
	}

	d.model.ClearPolicy()

	return nil
}

// UpdatePolicySelf provides a method for dispatcher to update an authorization rule from the current policy.
func (d *DistributedEnforcer) UpdatePolicySelf(shouldPersist func() bool, sec string, ptype string, oldRule, newRule []string) (affected bool, err error) {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		err := d.adapter.(persist.UpdatableAdapter).UpdatePolicy(sec, ptype, oldRule, newRule)
		if err != nil {
			return false, err
		}
	}

	ruleUpdated := d.model.UpdatePolicy(sec, ptype, oldRule, newRule)
	if !ruleUpdated {
		return ruleUpdated, nil
	}

	// Role-definition ("g") updates are applied as remove-old + add-new
	// against the role links.
	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, [][]string{oldRule}) // remove the old rule
		if err != nil {
			return ruleUpdated, err
		}
		err = d.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, [][]string{newRule}) // add the new rule
		if err != nil {
			return ruleUpdated, err
		}
	}

	return ruleUpdated, nil
}

// UpdatePoliciesSelf provides a method for dispatcher to update a set of authorization rules from the current policy.
func (d *DistributedEnforcer) UpdatePoliciesSelf(shouldPersist func() bool, sec string, ptype string, oldRules, newRules [][]string) (affected bool, err error) {
	d.m.Lock()
	defer d.m.Unlock()
	if shouldPersist != nil && shouldPersist() {
		err := d.adapter.(persist.UpdatableAdapter).UpdatePolicies(sec, ptype, oldRules, newRules)
		if err != nil {
			return false, err
		}
	}

	ruleUpdated := d.model.UpdatePolicies(sec, ptype, oldRules, newRules)
	if !ruleUpdated {
		return ruleUpdated, nil
	}

	// Role-definition ("g") updates are applied as remove-old + add-new
	// against the role links.
	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, oldRules) // remove the old rule
		if err != nil {
			return ruleUpdated, err
		}
		err = d.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, newRules) // add the new rule
		if err != nil {
			return ruleUpdated, err
		}
	}

	return ruleUpdated, nil
}

// UpdateFilteredPoliciesSelf provides a method for dispatcher to update a set of authorization rules from the current policy.
func (d *DistributedEnforcer) UpdateFilteredPoliciesSelf(shouldPersist func() bool, sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) (bool, error) {
	d.m.Lock()
	defer d.m.Unlock()
	var (
		oldRules [][]string
		err      error
	)
	if shouldPersist != nil && shouldPersist() {
		// The adapter both performs the filtered update and reports which
		// rules it replaced; oldRules stays empty when persistence is skipped.
		oldRules, err = d.adapter.(persist.UpdatableAdapter).UpdateFilteredPolicies(sec, ptype, newRules, fieldIndex, fieldValues...)
		if err != nil {
			return false, err
		}
	}

	// NOTE(review): the negation of RemovePolicies looks inconsistent with the
	// sibling Update*Self methods (which use the un-negated result) — confirm
	// this matches the pinned upstream casbin release before changing anything;
	// vendored code must stay byte-identical to upstream.
	ruleChanged := !d.model.RemovePolicies(sec, ptype, oldRules)
	d.model.AddPolicies(sec, ptype, newRules)
	ruleChanged = ruleChanged && len(newRules) != 0
	if !ruleChanged {
		return ruleChanged, nil
	}

	if sec == "g" {
		err := d.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, oldRules) // remove the old rule
		if err != nil {
			return ruleChanged, err
		}
		err = d.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, newRules) // add the new rule
		if err != nil {
			return ruleChanged, err
		}
	}

	return true, nil
}

// Copyright 2019 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// file: vendor/github.com/casbin/casbin/v2/enforcer_interface.go (continued)

package casbin

import (
	"github.com/Knetic/govaluate"
	"github.com/casbin/casbin/v2/effector"
	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	"github.com/casbin/casbin/v2/rbac"
)

// Compile-time assertions that every concrete enforcer satisfies IEnforcer.
var _ IEnforcer = &Enforcer{}
var _ IEnforcer = &SyncedEnforcer{}
var _ IEnforcer = &CachedEnforcer{}

// IEnforcer is the API interface of Enforcer
type IEnforcer interface {
	/* Enforcer API */
	InitWithFile(modelPath string, policyPath string) error
	InitWithAdapter(modelPath string, adapter persist.Adapter) error
	InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error
	LoadModel() error
	GetModel() model.Model
	SetModel(m model.Model)
	GetAdapter() persist.Adapter
	SetAdapter(adapter persist.Adapter)
	SetWatcher(watcher persist.Watcher) error
	GetRoleManager() rbac.RoleManager
	SetRoleManager(rm rbac.RoleManager)
	SetEffector(eft effector.Effector)
	ClearPolicy()
	LoadPolicy() error
	LoadFilteredPolicy(filter interface{}) error
	LoadIncrementalFilteredPolicy(filter interface{}) error
	IsFiltered() bool
	SavePolicy() error
	EnableEnforce(enable bool)
	EnableLog(enable bool)
	EnableAutoNotifyWatcher(enable bool)
	EnableAutoSave(autoSave bool)
	EnableAutoBuildRoleLinks(autoBuildRoleLinks bool)
	BuildRoleLinks() error
	Enforce(rvals ...interface{}) (bool, error)
	EnforceWithMatcher(matcher string, rvals ...interface{}) (bool, error)
	EnforceEx(rvals ...interface{}) (bool, []string, error)
	EnforceExWithMatcher(matcher string, rvals ...interface{}) (bool, []string, error)
	BatchEnforce(requests [][]interface{}) ([]bool, error)
	BatchEnforceWithMatcher(matcher string, requests [][]interface{}) ([]bool, error)

	/* RBAC API */
	GetRolesForUser(name string, domain ...string) ([]string, error)
	GetUsersForRole(name string, domain ...string) ([]string, error)
	HasRoleForUser(name string, role string, domain ...string) (bool, error)
	AddRoleForUser(user string, role string, domain ...string) (bool, error)
	AddPermissionForUser(user string, permission ...string) (bool, error)
	AddPermissionsForUser(user string, permissions ...[]string) (bool, error)
	DeletePermissionForUser(user string, permission ...string) (bool, error)
	DeletePermissionsForUser(user string) (bool, error)
	GetPermissionsForUser(user string, domain ...string) [][]string
	HasPermissionForUser(user string, permission ...string) bool
	GetImplicitRolesForUser(name string, domain ...string) ([]string, error)
	GetImplicitPermissionsForUser(user string, domain ...string) ([][]string, error)
	GetImplicitUsersForPermission(permission ...string) ([]string, error)
	DeleteRoleForUser(user string, role string, domain ...string) (bool, error)
	DeleteRolesForUser(user string, domain ...string) (bool, error)
	DeleteUser(user string) (bool, error)
	DeleteRole(role string) (bool, error)
	DeletePermission(permission ...string) (bool, error)

	/* RBAC API with domains*/
	GetUsersForRoleInDomain(name string, domain string) []string
	GetRolesForUserInDomain(name string, domain string) []string
	GetPermissionsForUserInDomain(user string, domain string) [][]string
	AddRoleForUserInDomain(user string, role string, domain string) (bool, error)
	DeleteRoleForUserInDomain(user string, role string, domain string) (bool, error)

	/* Management API */
	GetAllSubjects() []string
	GetAllNamedSubjects(ptype string) []string
	GetAllObjects() []string
	GetAllNamedObjects(ptype string) []string
	GetAllActions() []string
	GetAllNamedActions(ptype string) []string
	GetAllRoles() []string
	GetAllNamedRoles(ptype string) []string
	GetPolicy() [][]string
	GetFilteredPolicy(fieldIndex int, fieldValues ...string) [][]string
	GetNamedPolicy(ptype string) [][]string
	GetFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string
	GetGroupingPolicy() [][]string
	GetFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) [][]string
	GetNamedGroupingPolicy(ptype string) [][]string
	GetFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string
	HasPolicy(params ...interface{}) bool
	HasNamedPolicy(ptype string, params ...interface{}) bool
	AddPolicy(params ...interface{}) (bool, error)
	AddPolicies(rules [][]string) (bool, error)
	AddNamedPolicy(ptype string, params ...interface{}) (bool, error)
	AddNamedPolicies(ptype string, rules [][]string) (bool, error)
	RemovePolicy(params ...interface{}) (bool, error)
	RemovePolicies(rules [][]string) (bool, error)
	RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) (bool, error)
	RemoveNamedPolicy(ptype string, params ...interface{}) (bool, error)
	RemoveNamedPolicies(ptype string, rules [][]string) (bool, error)
	RemoveFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error)
	HasGroupingPolicy(params ...interface{}) bool
	HasNamedGroupingPolicy(ptype string, params ...interface{}) bool
	AddGroupingPolicy(params ...interface{}) (bool, error)
	AddGroupingPolicies(rules [][]string) (bool, error)
	AddNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error)
	AddNamedGroupingPolicies(ptype string, rules [][]string) (bool, error)
	RemoveGroupingPolicy(params ...interface{}) (bool, error)
	RemoveGroupingPolicies(rules [][]string) (bool, error)
	RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) (bool, error)
	RemoveNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error)
	RemoveNamedGroupingPolicies(ptype string, rules [][]string) (bool, error)
	RemoveFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error)
	AddFunction(name string, function govaluate.ExpressionFunction)

	UpdatePolicy(oldPolicy []string, newPolicy []string) (bool, error)
	UpdatePolicies(oldPolicies [][]string, newPolicies [][]string) (bool, error)
	UpdateFilteredPolicies(newPolicies [][]string, fieldIndex int, fieldValues ...string) (bool, error)

	UpdateGroupingPolicy(oldRule []string, newRule []string) (bool, error)
	UpdateGroupingPolicies(oldRules [][]string, newRules [][]string) (bool, error)

	/* Management API with autoNotifyWatcher disabled */
	SelfAddPolicy(sec string, ptype string, rule []string) (bool, error)
	SelfAddPolicies(sec string, ptype string, rules [][]string) (bool, error)
	SelfRemovePolicy(sec string, ptype string, rule []string) (bool, error)
	SelfRemovePolicies(sec string, ptype string, rules [][]string) (bool, error)
	SelfRemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, error)
	SelfUpdatePolicy(sec string, ptype string, oldRule, newRule []string) (bool, error)
	SelfUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (bool, error)
}

// Compile-time assertion that DistributedEnforcer implements the dispatcher API.
var _ IDistributedEnforcer = &DistributedEnforcer{}

// IDistributedEnforcer defines dispatcher enforcer.
type IDistributedEnforcer interface {
	IEnforcer
	SetDispatcher(dispatcher persist.Dispatcher)
	/* Management API for DistributedEnforcer*/
	AddPoliciesSelf(shouldPersist func() bool, sec string, ptype string, rules [][]string) (affected [][]string, err error)
	RemovePoliciesSelf(shouldPersist func() bool, sec string, ptype string, rules [][]string) (affected [][]string, err error)
	RemoveFilteredPolicySelf(shouldPersist func() bool, sec string, ptype string, fieldIndex int, fieldValues ...string) (affected [][]string, err error)
	ClearPolicySelf(shouldPersist func() bool) error
	UpdatePolicySelf(shouldPersist func() bool, sec string, ptype string, oldRule, newRule []string) (affected bool, err error)
	UpdatePoliciesSelf(shouldPersist func() bool, sec string, ptype string, oldRules, newRules [][]string) (affected bool, err error)
	UpdateFilteredPoliciesSelf(shouldPersist func() bool, sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) (bool, error)
}
// file: vendor/github.com/casbin/casbin/v2/enforcer_synced.go

// Copyright 2017 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package casbin

import (
	"sync"
	"sync/atomic"
	"time"

	"github.com/Knetic/govaluate"

	"github.com/casbin/casbin/v2/persist"
	"github.com/casbin/casbin/v2/rbac"
	defaultrolemanager "github.com/casbin/casbin/v2/rbac/default-role-manager"
)

// SyncedEnforcer wraps Enforcer and provides synchronized access
type SyncedEnforcer struct {
	*Enforcer
	m               sync.RWMutex  // guards the embedded Enforcer's model/policy state
	stopAutoLoad    chan struct{} // signals the auto-load goroutine to exit
	autoLoadRunning int32         // 1 while the auto-load goroutine runs; accessed atomically
}

// NewSyncedEnforcer creates a synchronized enforcer via file or DB.
func NewSyncedEnforcer(params ...interface{}) (*SyncedEnforcer, error) {
	e := &SyncedEnforcer{}
	var err error
	e.Enforcer, err = NewEnforcer(params...)
	if err != nil {
		return nil, err
	}

	// Buffer of 1 lets StopAutoLoadPolicy signal without blocking.
	e.stopAutoLoad = make(chan struct{}, 1)
	e.autoLoadRunning = 0
	return e, nil
}

// GetLock return the private RWMutex lock
func (e *SyncedEnforcer) GetLock() *sync.RWMutex {
	return &e.m
}

// IsAutoLoadingRunning check if SyncedEnforcer is auto loading policies
func (e *SyncedEnforcer) IsAutoLoadingRunning() bool {
	return atomic.LoadInt32(&(e.autoLoadRunning)) != 0
}

// StartAutoLoadPolicy starts a go routine that will every specified duration call LoadPolicy
func (e *SyncedEnforcer) StartAutoLoadPolicy(d time.Duration) {
	// Don't start another goroutine if there is already one running
	if !atomic.CompareAndSwapInt32(&e.autoLoadRunning, 0, 1) {
		return
	}

	ticker := time.NewTicker(d)
	go func() {
		defer func() {
			ticker.Stop()
			atomic.StoreInt32(&(e.autoLoadRunning), int32(0))
		}()
		n := 1
		for {
			select {
			case <-ticker.C:
				// error intentionally ignored
				_ = e.LoadPolicy()
				// Uncomment this line to see when the policy is loaded.
				// log.Print("Load policy for time: ", n)
				n++
			case <-e.stopAutoLoad:
				return
			}
		}
	}()
}

// StopAutoLoadPolicy causes the go routine to exit.
func (e *SyncedEnforcer) StopAutoLoadPolicy() {
	if e.IsAutoLoadingRunning() {
		e.stopAutoLoad <- struct{}{}
	}
}

// SetWatcher sets the current watcher.
func (e *SyncedEnforcer) SetWatcher(watcher persist.Watcher) error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SetWatcher(watcher)
}

// LoadModel reloads the model from the model CONF file.
func (e *SyncedEnforcer) LoadModel() error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.LoadModel()
}

// ClearPolicy clears all policy.
func (e *SyncedEnforcer) ClearPolicy() {
	e.m.Lock()
	defer e.m.Unlock()
	e.Enforcer.ClearPolicy()
}

// LoadPolicy reloads the policy from file/database.
func (e *SyncedEnforcer) LoadPolicy() error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.LoadPolicy()
}

// LoadPolicyFast is not blocked when adapter calls LoadPolicy.
// It rebuilds a copy of the model outside the write lock and only takes the
// lock for the final swap, so concurrent Enforce calls are not stalled by a
// slow adapter.
func (e *SyncedEnforcer) LoadPolicyFast() error {
	e.m.RLock()
	newModel := e.model.Copy()
	e.m.RUnlock()

	newModel.ClearPolicy()
	newRmMap := map[string]rbac.RoleManager{}
	var err error

	// NOTE(review): matching the error by string mirrors upstream casbin; it
	// tolerates enforcers constructed without an adapter file path.
	if err = e.adapter.LoadPolicy(newModel); err != nil && err.Error() != "invalid file path, file path cannot be empty" {
		return err
	}

	if err = newModel.SortPoliciesBySubjectHierarchy(); err != nil {
		return err
	}

	if err = newModel.SortPoliciesByPriority(); err != nil {
		return err
	}

	if e.autoBuildRoleLinks {
		for ptype := range newModel["g"] {
			newRmMap[ptype] = defaultrolemanager.NewRoleManager(10)
		}
		err = newModel.BuildRoleLinks(newRmMap)
		if err != nil {
			return err
		}
	}

	// reduce the lock range
	e.m.Lock()
	defer e.m.Unlock()
	e.model = newModel
	e.rmMap = newRmMap
	return nil
}

// LoadFilteredPolicy reloads a filtered policy from file/database.
func (e *SyncedEnforcer) LoadFilteredPolicy(filter interface{}) error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.LoadFilteredPolicy(filter)
}

// LoadIncrementalFilteredPolicy reloads a filtered policy from file/database.
func (e *SyncedEnforcer) LoadIncrementalFilteredPolicy(filter interface{}) error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.LoadIncrementalFilteredPolicy(filter)
}

// SavePolicy saves the current policy (usually after changed with Casbin API) back to file/database.
func (e *SyncedEnforcer) SavePolicy() error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.SavePolicy()
}

// BuildRoleLinks manually rebuild the role inheritance relations.
func (e *SyncedEnforcer) BuildRoleLinks() error {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.BuildRoleLinks()
}

// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act).
// Read-only queries take the read lock so concurrent Enforce calls do not serialize.
func (e *SyncedEnforcer) Enforce(rvals ...interface{}) (bool, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.Enforce(rvals...)
}

// EnforceWithMatcher use a custom matcher to decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (matcher, sub, obj, act), use model matcher by default when matcher is "".
func (e *SyncedEnforcer) EnforceWithMatcher(matcher string, rvals ...interface{}) (bool, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.EnforceWithMatcher(matcher, rvals...)
}

// EnforceEx explain enforcement by informing matched rules
func (e *SyncedEnforcer) EnforceEx(rvals ...interface{}) (bool, []string, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.EnforceEx(rvals...)
}

// EnforceExWithMatcher use a custom matcher and explain enforcement by informing matched rules
func (e *SyncedEnforcer) EnforceExWithMatcher(matcher string, rvals ...interface{}) (bool, []string, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.EnforceExWithMatcher(matcher, rvals...)
}

// BatchEnforce enforce in batches
func (e *SyncedEnforcer) BatchEnforce(requests [][]interface{}) ([]bool, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.BatchEnforce(requests)
}

// BatchEnforceWithMatcher enforce with matcher in batches
func (e *SyncedEnforcer) BatchEnforceWithMatcher(matcher string, requests [][]interface{}) ([]bool, error) {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.BatchEnforceWithMatcher(matcher, requests)
}

// GetAllSubjects gets the list of subjects that show up in the current policy.
func (e *SyncedEnforcer) GetAllSubjects() []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllSubjects()
}

// GetAllNamedSubjects gets the list of subjects that show up in the current named policy.
func (e *SyncedEnforcer) GetAllNamedSubjects(ptype string) []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllNamedSubjects(ptype)
}

// GetAllObjects gets the list of objects that show up in the current policy.
func (e *SyncedEnforcer) GetAllObjects() []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllObjects()
}

// GetAllNamedObjects gets the list of objects that show up in the current named policy.
func (e *SyncedEnforcer) GetAllNamedObjects(ptype string) []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllNamedObjects(ptype)
}

// GetAllActions gets the list of actions that show up in the current policy.
func (e *SyncedEnforcer) GetAllActions() []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllActions()
}

// GetAllNamedActions gets the list of actions that show up in the current named policy.
func (e *SyncedEnforcer) GetAllNamedActions(ptype string) []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllNamedActions(ptype)
}

// GetAllRoles gets the list of roles that show up in the current policy.
func (e *SyncedEnforcer) GetAllRoles() []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllRoles()
}

// GetAllNamedRoles gets the list of roles that show up in the current named policy.
func (e *SyncedEnforcer) GetAllNamedRoles(ptype string) []string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetAllNamedRoles(ptype)
}

// GetPolicy gets all the authorization rules in the policy.
func (e *SyncedEnforcer) GetPolicy() [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetPolicy()
}

// GetFilteredPolicy gets all the authorization rules in the policy, field filters can be specified.
func (e *SyncedEnforcer) GetFilteredPolicy(fieldIndex int, fieldValues ...string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetFilteredPolicy(fieldIndex, fieldValues...)
}

// GetNamedPolicy gets all the authorization rules in the named policy.
func (e *SyncedEnforcer) GetNamedPolicy(ptype string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetNamedPolicy(ptype)
}

// GetFilteredNamedPolicy gets all the authorization rules in the named policy, field filters can be specified.
func (e *SyncedEnforcer) GetFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetFilteredNamedPolicy(ptype, fieldIndex, fieldValues...)
}

// GetGroupingPolicy gets all the role inheritance rules in the policy.
func (e *SyncedEnforcer) GetGroupingPolicy() [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetGroupingPolicy()
}

// GetFilteredGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified.
func (e *SyncedEnforcer) GetFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetFilteredGroupingPolicy(fieldIndex, fieldValues...)
}

// GetNamedGroupingPolicy gets all the role inheritance rules in the policy.
func (e *SyncedEnforcer) GetNamedGroupingPolicy(ptype string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetNamedGroupingPolicy(ptype)
}

// GetFilteredNamedGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified.
func (e *SyncedEnforcer) GetFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.GetFilteredNamedGroupingPolicy(ptype, fieldIndex, fieldValues...)
}

// HasPolicy determines whether an authorization rule exists.
func (e *SyncedEnforcer) HasPolicy(params ...interface{}) bool {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.HasPolicy(params...)
}

// HasNamedPolicy determines whether a named authorization rule exists.
func (e *SyncedEnforcer) HasNamedPolicy(ptype string, params ...interface{}) bool {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.HasNamedPolicy(ptype, params...)
}

// AddPolicy adds an authorization rule to the current policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
func (e *SyncedEnforcer) AddPolicy(params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddPolicy(params...)
}

// AddPolicies adds authorization rules to the current policy.
// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
// Otherwise the function returns true for the corresponding rule by adding the new rule.
func (e *SyncedEnforcer) AddPolicies(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddPolicies(rules)
}

// AddNamedPolicy adds an authorization rule to the current named policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
func (e *SyncedEnforcer) AddNamedPolicy(ptype string, params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedPolicy(ptype, params...)
}

// AddNamedPolicies adds authorization rules to the current named policy.
// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
// Otherwise the function returns true for the corresponding by adding the new rule.
func (e *SyncedEnforcer) AddNamedPolicies(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedPolicies(ptype, rules)
}

// RemovePolicy removes an authorization rule from the current policy.
func (e *SyncedEnforcer) RemovePolicy(params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemovePolicy(params...)
}

// UpdatePolicy updates an authorization rule from the current policy.
func (e *SyncedEnforcer) UpdatePolicy(oldPolicy []string, newPolicy []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdatePolicy(oldPolicy, newPolicy)
}

// UpdateNamedPolicy updates an authorization rule from the current named policy.
func (e *SyncedEnforcer) UpdateNamedPolicy(ptype string, p1 []string, p2 []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateNamedPolicy(ptype, p1, p2)
}

// UpdatePolicies updates authorization rules from the current policies.
func (e *SyncedEnforcer) UpdatePolicies(oldPolices [][]string, newPolicies [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdatePolicies(oldPolices, newPolicies)
}

// UpdateNamedPolicies updates authorization rules from the current named policies.
func (e *SyncedEnforcer) UpdateNamedPolicies(ptype string, p1 [][]string, p2 [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateNamedPolicies(ptype, p1, p2)
}

// UpdateFilteredPolicies replaces the rules matching the field filter with newPolicies.
func (e *SyncedEnforcer) UpdateFilteredPolicies(newPolicies [][]string, fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateFilteredPolicies(newPolicies, fieldIndex, fieldValues...)
}

// UpdateFilteredNamedPolicies replaces the named rules matching the field filter with newPolicies.
func (e *SyncedEnforcer) UpdateFilteredNamedPolicies(ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateFilteredNamedPolicies(ptype, newPolicies, fieldIndex, fieldValues...)
}

// RemovePolicies removes authorization rules from the current policy.
func (e *SyncedEnforcer) RemovePolicies(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemovePolicies(rules)
}

// RemoveFilteredPolicy removes an authorization rule from the current policy, field filters can be specified.
func (e *SyncedEnforcer) RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveFilteredPolicy(fieldIndex, fieldValues...)
}

// RemoveNamedPolicy removes an authorization rule from the current named policy.
func (e *SyncedEnforcer) RemoveNamedPolicy(ptype string, params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveNamedPolicy(ptype, params...)
}

// RemoveNamedPolicies removes authorization rules from the current named policy.
func (e *SyncedEnforcer) RemoveNamedPolicies(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveNamedPolicies(ptype, rules)
}

// RemoveFilteredNamedPolicy removes an authorization rule from the current named policy, field filters can be specified.
func (e *SyncedEnforcer) RemoveFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveFilteredNamedPolicy(ptype, fieldIndex, fieldValues...)
}

// HasGroupingPolicy determines whether a role inheritance rule exists.
func (e *SyncedEnforcer) HasGroupingPolicy(params ...interface{}) bool {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.HasGroupingPolicy(params...)
}

// HasNamedGroupingPolicy determines whether a named role inheritance rule exists.
func (e *SyncedEnforcer) HasNamedGroupingPolicy(ptype string, params ...interface{}) bool {
	e.m.RLock()
	defer e.m.RUnlock()
	return e.Enforcer.HasNamedGroupingPolicy(ptype, params...)
}

// AddGroupingPolicy adds a role inheritance rule to the current policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
func (e *SyncedEnforcer) AddGroupingPolicy(params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddGroupingPolicy(params...)
}

// AddGroupingPolicies adds role inheritance rulea to the current policy.
// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
// Otherwise the function returns true for the corresponding policy rule by adding the new rule.
func (e *SyncedEnforcer) AddGroupingPolicies(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddGroupingPolicies(rules)
}

// AddNamedGroupingPolicy adds a named role inheritance rule to the current policy.
// If the rule already exists, the function returns false and the rule will not be added.
// Otherwise the function returns true by adding the new rule.
func (e *SyncedEnforcer) AddNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedGroupingPolicy(ptype, params...)
}

// AddNamedGroupingPolicies adds named role inheritance rules to the current policy.
// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
// Otherwise the function returns true for the corresponding policy rule by adding the new rule.
func (e *SyncedEnforcer) AddNamedGroupingPolicies(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.AddNamedGroupingPolicies(ptype, rules)
}

// RemoveGroupingPolicy removes a role inheritance rule from the current policy.
func (e *SyncedEnforcer) RemoveGroupingPolicy(params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveGroupingPolicy(params...)
}

// RemoveGroupingPolicies removes role inheritance rules from the current policy.
func (e *SyncedEnforcer) RemoveGroupingPolicies(rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveGroupingPolicies(rules)
}

// RemoveFilteredGroupingPolicy removes a role inheritance rule from the current policy, field filters can be specified.
func (e *SyncedEnforcer) RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveFilteredGroupingPolicy(fieldIndex, fieldValues...)
}

// RemoveNamedGroupingPolicy removes a role inheritance rule from the current named policy.
func (e *SyncedEnforcer) RemoveNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveNamedGroupingPolicy(ptype, params...)
}

// RemoveNamedGroupingPolicies removes role inheritance rules from the current named policy.
func (e *SyncedEnforcer) RemoveNamedGroupingPolicies(ptype string, rules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveNamedGroupingPolicies(ptype, rules)
}

// UpdateGroupingPolicy updates a role inheritance rule in the current policy.
func (e *SyncedEnforcer) UpdateGroupingPolicy(oldRule []string, newRule []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateGroupingPolicy(oldRule, newRule)
}

// UpdateGroupingPolicies updates role inheritance rules in the current policy.
func (e *SyncedEnforcer) UpdateGroupingPolicies(oldRules [][]string, newRules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateGroupingPolicies(oldRules, newRules)
}

// UpdateNamedGroupingPolicy updates a named role inheritance rule in the current policy.
func (e *SyncedEnforcer) UpdateNamedGroupingPolicy(ptype string, oldRule []string, newRule []string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateNamedGroupingPolicy(ptype, oldRule, newRule)
}

// UpdateNamedGroupingPolicies updates named role inheritance rules in the current policy.
func (e *SyncedEnforcer) UpdateNamedGroupingPolicies(ptype string, oldRules [][]string, newRules [][]string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.UpdateNamedGroupingPolicies(ptype, oldRules, newRules)
}

// RemoveFilteredNamedGroupingPolicy removes a role inheritance rule from the current named policy, field filters can be specified.
func (e *SyncedEnforcer) RemoveFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) {
	e.m.Lock()
	defer e.m.Unlock()
	return e.Enforcer.RemoveFilteredNamedGroupingPolicy(ptype, fieldIndex, fieldValues...)
}

// AddFunction adds a customized function.
func (e *SyncedEnforcer) AddFunction(name string, function govaluate.ExpressionFunction) {
	e.m.Lock()
	defer e.m.Unlock()
	e.Enforcer.AddFunction(name, function)
}

// file: vendor/github.com/casbin/casbin/v2/errors/rbac_errors.go

// Copyright 2018 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errors

import "errors"

// Global errors for rbac defined here
// NOTE(review): the ALL_CAPS names and the "FIELDVAULES" typo are non-idiomatic
// Go but are part of upstream casbin's public API; they must be preserved as-is
// in vendored code.
var (
	ERR_NAME_NOT_FOUND            = errors.New("error: name does not exist")
	ERR_DOMAIN_PARAMETER          = errors.New("error: domain should be 1 parameter")
	ERR_LINK_NOT_FOUND            = errors.New("error: link between name1 and name2 does not exist")
	ERR_USE_DOMAIN_PARAMETER      = errors.New("error: useDomain should be 1 parameter")
	INVALID_FIELDVAULES_PARAMETER = errors.New("fieldValues requires at least one parameter")
)

// file: vendor/github.com/casbin/casbin/v2/frontend.go

// Copyright 2020 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +package casbin + +import ( + "bytes" + "encoding/json" +) + +func CasbinJsGetPermissionForUser(e IEnforcer, user string) (string, error) { + model := e.GetModel() + m := map[string]interface{}{} + m["m"] = model.ToText() + policies := make([][]string, 0) + for ptype := range model["p"] { + policy := model.GetPolicy("p", ptype) + for i := range policy { + policies = append(policies, append([]string{ptype}, policy[i]...)) + } + } + m["p"] = policies + result := bytes.NewBuffer([]byte{}) + encoder := json.NewEncoder(result) + encoder.SetEscapeHTML(false) + err := encoder.Encode(m) + return result.String(), err +} diff --git a/vendor/github.com/casbin/casbin/v2/frontend_old.go b/vendor/github.com/casbin/casbin/v2/frontend_old.go new file mode 100644 index 00000000..139b164f --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/frontend_old.go @@ -0,0 +1,30 @@ +// Copyright 2021 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package casbin + +import "encoding/json" + +func CasbinJsGetPermissionForUserOld(e IEnforcer, user string) ([]byte, error) { + policy, err := e.GetImplicitPermissionsForUser(user) + if err != nil { + return nil, err + } + permission := make(map[string][]string) + for i := 0; i < len(policy); i++ { + permission[policy[i][2]] = append(permission[policy[i][2]], policy[i][1]) + } + b, _ := json.Marshal(permission) + return b, nil +} diff --git a/vendor/github.com/casbin/casbin/v2/internal_api.go b/vendor/github.com/casbin/casbin/v2/internal_api.go new file mode 100644 index 00000000..8c3f97f9 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/internal_api.go @@ -0,0 +1,469 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "fmt" + + Err "github.com/casbin/casbin/v2/errors" + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" +) + +const ( + notImplemented = "not implemented" +) + +func (e *Enforcer) shouldPersist() bool { + return e.adapter != nil && e.autoSave +} + +func (e *Enforcer) shouldNotify() bool { + return e.watcher != nil && e.autoNotifyWatcher +} + +// addPolicy adds a rule to the current policy. 
+func (e *Enforcer) addPolicyWithoutNotify(sec string, ptype string, rule []string) (bool, error) { + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.AddPolicies(sec, ptype, [][]string{rule}) + } + + if e.model.HasPolicy(sec, ptype, rule) { + return false, nil + } + + if e.shouldPersist() { + if err := e.adapter.AddPolicy(sec, ptype, rule); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + e.model.AddPolicy(sec, ptype, rule) + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, [][]string{rule}) + if err != nil { + return true, err + } + } + + return true, nil +} + +// addPolicies adds rules to the current policy. +func (e *Enforcer) addPoliciesWithoutNotify(sec string, ptype string, rules [][]string) (bool, error) { + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.AddPolicies(sec, ptype, rules) + } + + if e.model.HasPolicies(sec, ptype, rules) { + return false, nil + } + + if e.shouldPersist() { + if err := e.adapter.(persist.BatchAdapter).AddPolicies(sec, ptype, rules); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + e.model.AddPolicies(sec, ptype, rules) + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, rules) + if err != nil { + return true, err + } + } + + return true, nil +} + +// removePolicy removes a rule from the current policy. 
+func (e *Enforcer) removePolicyWithoutNotify(sec string, ptype string, rule []string) (bool, error) { + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.RemovePolicies(sec, ptype, [][]string{rule}) + } + + if e.shouldPersist() { + if err := e.adapter.RemovePolicy(sec, ptype, rule); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + ruleRemoved := e.model.RemovePolicy(sec, ptype, rule) + if !ruleRemoved { + return ruleRemoved, nil + } + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, [][]string{rule}) + if err != nil { + return ruleRemoved, err + } + } + + return ruleRemoved, nil +} + +func (e *Enforcer) updatePolicyWithoutNotify(sec string, ptype string, oldRule []string, newRule []string) (bool, error) { + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.UpdatePolicy(sec, ptype, oldRule, newRule) + } + + if e.shouldPersist() { + if err := e.adapter.(persist.UpdatableAdapter).UpdatePolicy(sec, ptype, oldRule, newRule); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + ruleUpdated := e.model.UpdatePolicy(sec, ptype, oldRule, newRule) + if !ruleUpdated { + return ruleUpdated, nil + } + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, [][]string{oldRule}) // remove the old rule + if err != nil { + return ruleUpdated, err + } + err = e.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, [][]string{newRule}) // add the new rule + if err != nil { + return ruleUpdated, err + } + } + + return ruleUpdated, nil +} + +func (e *Enforcer) updatePoliciesWithoutNotify(sec string, ptype string, oldRules [][]string, newRules [][]string) (bool, error) { + if len(newRules) != len(oldRules) { + return false, fmt.Errorf("the length of oldRules should be equal to the length of newRules, but got the length of oldRules is %d, the length of newRules is %d", len(oldRules), len(newRules)) + } + + if 
e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.UpdatePolicies(sec, ptype, oldRules, newRules) + } + + if e.shouldPersist() { + if err := e.adapter.(persist.UpdatableAdapter).UpdatePolicies(sec, ptype, oldRules, newRules); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + ruleUpdated := e.model.UpdatePolicies(sec, ptype, oldRules, newRules) + if !ruleUpdated { + return ruleUpdated, nil + } + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, oldRules) // remove the old rules + if err != nil { + return ruleUpdated, err + } + err = e.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, newRules) // add the new rules + if err != nil { + return ruleUpdated, err + } + } + + return ruleUpdated, nil +} + +// removePolicies removes rules from the current policy. +func (e *Enforcer) removePoliciesWithoutNotify(sec string, ptype string, rules [][]string) (bool, error) { + if !e.model.HasPolicies(sec, ptype, rules) { + return false, nil + } + + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.RemovePolicies(sec, ptype, rules) + } + + if e.shouldPersist() { + if err := e.adapter.(persist.BatchAdapter).RemovePolicies(sec, ptype, rules); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + rulesRemoved := e.model.RemovePolicies(sec, ptype, rules) + if !rulesRemoved { + return rulesRemoved, nil + } + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, rules) + if err != nil { + return rulesRemoved, err + } + } + return rulesRemoved, nil +} + +// removeFilteredPolicy removes rules based on field filters from the current policy. 
+func (e *Enforcer) removeFilteredPolicyWithoutNotify(sec string, ptype string, fieldIndex int, fieldValues []string) (bool, error) { + if len(fieldValues) == 0 { + return false, Err.INVALID_FIELDVAULES_PARAMETER + } + + if e.dispatcher != nil && e.autoNotifyDispatcher { + return true, e.dispatcher.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...) + } + + if e.shouldPersist() { + if err := e.adapter.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...); err != nil { + if err.Error() != notImplemented { + return false, err + } + } + } + + ruleRemoved, effects := e.model.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...) + if !ruleRemoved { + return ruleRemoved, nil + } + + if sec == "g" { + err := e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, effects) + if err != nil { + return ruleRemoved, err + } + } + + return ruleRemoved, nil +} + +func (e *Enforcer) updateFilteredPoliciesWithoutNotify(sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { + var ( + oldRules [][]string + err error + ) + + if e.shouldPersist() { + if oldRules, err = e.adapter.(persist.UpdatableAdapter).UpdateFilteredPolicies(sec, ptype, newRules, fieldIndex, fieldValues...); err != nil { + if err.Error() != notImplemented { + return nil, err + } + } + // For compatibility, because some adapters return oldRules containing ptype, see https://github.com/casbin/xorm-adapter/issues/49 + for i, oldRule := range oldRules { + if len(oldRules[i]) == len(e.model[sec][ptype].Tokens)+1 { + oldRules[i] = oldRule[1:] + } + } + } + + if e.dispatcher != nil && e.autoNotifyDispatcher { + return oldRules, e.dispatcher.UpdateFilteredPolicies(sec, ptype, oldRules, newRules) + } + + ruleChanged := e.model.RemovePolicies(sec, ptype, oldRules) + e.model.AddPolicies(sec, ptype, newRules) + ruleChanged = ruleChanged && len(newRules) != 0 + if !ruleChanged { + return make([][]string, 0), nil + } + + if sec == "g" { + err := 
e.BuildIncrementalRoleLinks(model.PolicyRemove, ptype, oldRules) // remove the old rules + if err != nil { + return oldRules, err + } + err = e.BuildIncrementalRoleLinks(model.PolicyAdd, ptype, newRules) // add the new rules + if err != nil { + return oldRules, err + } + } + + return oldRules, nil +} + +// addPolicy adds a rule to the current policy. +func (e *Enforcer) addPolicy(sec string, ptype string, rule []string) (bool, error) { + ok, err := e.addPolicyWithoutNotify(sec, ptype, rule) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.WatcherEx); ok { + err = watcher.UpdateForAddPolicy(sec, ptype, rule...) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +// addPolicies adds rules to the current policy. +func (e *Enforcer) addPolicies(sec string, ptype string, rules [][]string) (bool, error) { + ok, err := e.addPoliciesWithoutNotify(sec, ptype, rules) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.WatcherEx); ok { + err = watcher.UpdateForAddPolicies(sec, ptype, rules...) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +// removePolicy removes a rule from the current policy. +func (e *Enforcer) removePolicy(sec string, ptype string, rule []string) (bool, error) { + ok, err := e.removePolicyWithoutNotify(sec, ptype, rule) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.WatcherEx); ok { + err = watcher.UpdateForRemovePolicy(sec, ptype, rule...) 
+ } else { + err = e.watcher.Update() + } + return true, err + + } + + return true, nil +} + +func (e *Enforcer) updatePolicy(sec string, ptype string, oldRule []string, newRule []string) (bool, error) { + ok, err := e.updatePolicyWithoutNotify(sec, ptype, oldRule, newRule) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.UpdatableWatcher); ok { + err = watcher.UpdateForUpdatePolicy(sec, ptype, oldRule, newRule) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +func (e *Enforcer) updatePolicies(sec string, ptype string, oldRules [][]string, newRules [][]string) (bool, error) { + ok, err := e.updatePoliciesWithoutNotify(sec, ptype, oldRules, newRules) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.UpdatableWatcher); ok { + err = watcher.UpdateForUpdatePolicies(sec, ptype, oldRules, newRules) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +// removePolicies removes rules from the current policy. +func (e *Enforcer) removePolicies(sec string, ptype string, rules [][]string) (bool, error) { + ok, err := e.removePoliciesWithoutNotify(sec, ptype, rules) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.WatcherEx); ok { + err = watcher.UpdateForRemovePolicies(sec, ptype, rules...) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +// removeFilteredPolicy removes rules based on field filters from the current policy. 
+func (e *Enforcer) removeFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues []string) (bool, error) { + ok, err := e.removeFilteredPolicyWithoutNotify(sec, ptype, fieldIndex, fieldValues) + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.WatcherEx); ok { + err = watcher.UpdateForRemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +func (e *Enforcer) updateFilteredPolicies(sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) (bool, error) { + oldRules, err := e.updateFilteredPoliciesWithoutNotify(sec, ptype, newRules, fieldIndex, fieldValues...) + ok := len(oldRules) != 0 + if !ok || err != nil { + return ok, err + } + + if e.shouldNotify() { + var err error + if watcher, ok := e.watcher.(persist.UpdatableWatcher); ok { + err = watcher.UpdateForUpdatePolicies(sec, ptype, oldRules, newRules) + } else { + err = e.watcher.Update() + } + return true, err + } + + return true, nil +} + +func (e *Enforcer) GetFieldIndex(ptype string, field string) (int, error) { + return e.model.GetFieldIndex(ptype, field) +} + +func (e *Enforcer) SetFieldIndex(ptype string, field string, index int) { + assertion := e.model["p"][ptype] + assertion.FieldIndexMap[field] = index +} diff --git a/vendor/github.com/casbin/casbin/v2/log/default_logger.go b/vendor/github.com/casbin/casbin/v2/log/default_logger.go new file mode 100644 index 00000000..1ad3a92b --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/log/default_logger.go @@ -0,0 +1,97 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "fmt" + "log" + "strings" +) + +// DefaultLogger is the implementation for a Logger using golang log. +type DefaultLogger struct { + enabled bool +} + +func (l *DefaultLogger) EnableLog(enable bool) { + l.enabled = enable +} + +func (l *DefaultLogger) IsEnabled() bool { + return l.enabled +} + +func (l *DefaultLogger) LogModel(model [][]string) { + if !l.enabled { + return + } + var str strings.Builder + str.WriteString("Model: ") + for _, v := range model { + str.WriteString(fmt.Sprintf("%v\n", v)) + } + + log.Println(str.String()) +} + +func (l *DefaultLogger) LogEnforce(matcher string, request []interface{}, result bool, explains [][]string) { + if !l.enabled { + return + } + + var reqStr strings.Builder + reqStr.WriteString("Request: ") + for i, rval := range request { + if i != len(request)-1 { + reqStr.WriteString(fmt.Sprintf("%v, ", rval)) + } else { + reqStr.WriteString(fmt.Sprintf("%v", rval)) + } + } + reqStr.WriteString(fmt.Sprintf(" ---> %t\n", result)) + + reqStr.WriteString("Hit Policy: ") + for i, pval := range explains { + if i != len(explains)-1 { + reqStr.WriteString(fmt.Sprintf("%v, ", pval)) + } else { + reqStr.WriteString(fmt.Sprintf("%v \n", pval)) + } + } + + log.Println(reqStr.String()) +} + +func (l *DefaultLogger) LogPolicy(policy map[string][][]string) { + if !l.enabled { + return + } + + var str strings.Builder + str.WriteString("Policy: ") + for k, v := range policy { + str.WriteString(fmt.Sprintf("%s : %v\n", k, v)) + } + + log.Println(str.String()) +} + +func (l *DefaultLogger) 
LogRole(roles []string) { + if !l.enabled { + return + } + + log.Println("Roles: ", strings.Join(roles, "\n")) +} diff --git a/vendor/github.com/casbin/casbin/v2/log/log_util.go b/vendor/github.com/casbin/casbin/v2/log/log_util.go new file mode 100644 index 00000000..191c8179 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/log/log_util.go @@ -0,0 +1,47 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +var logger Logger = &DefaultLogger{} + +// SetLogger sets the current logger. +func SetLogger(l Logger) { + logger = l +} + +// GetLogger returns the current logger. +func GetLogger() Logger { + return logger +} + +// LogModel logs the model information. +func LogModel(model [][]string) { + logger.LogModel(model) +} + +// LogEnforce logs the enforcer information. +func LogEnforce(matcher string, request []interface{}, result bool, explains [][]string) { + logger.LogEnforce(matcher, request, result, explains) +} + +// LogRole log info related to role. +func LogRole(roles []string) { + logger.LogRole(roles) +} + +// LogPolicy logs the policy information. 
+func LogPolicy(policy map[string][][]string) { + logger.LogPolicy(policy) +} diff --git a/vendor/github.com/casbin/casbin/v2/log/logger.go b/vendor/github.com/casbin/casbin/v2/log/logger.go new file mode 100644 index 00000000..c60e07aa --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/log/logger.go @@ -0,0 +1,38 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +//go:generate mockgen -destination=./mocks/mock_logger.go -package=mocks github.com/casbin/casbin/v2/log Logger + +// Logger is the logging interface implementation. +type Logger interface { + // EnableLog controls whether print the message. + EnableLog(bool) + + // IsEnabled returns if logger is enabled. + IsEnabled() bool + + // LogModel log info related to model. + LogModel(model [][]string) + + // LogEnforce log info related to enforce. + LogEnforce(matcher string, request []interface{}, result bool, explains [][]string) + + // LogRole log info related to role. + LogRole(roles []string) + + // LogPolicy log info related to policy. + LogPolicy(policy map[string][][]string) +} diff --git a/vendor/github.com/casbin/casbin/v2/management_api.go b/vendor/github.com/casbin/casbin/v2/management_api.go new file mode 100644 index 00000000..ef6e570e --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/management_api.go @@ -0,0 +1,448 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "errors" + "fmt" + "strings" + + "github.com/Knetic/govaluate" + "github.com/casbin/casbin/v2/util" +) + +// GetAllSubjects gets the list of subjects that show up in the current policy. +func (e *Enforcer) GetAllSubjects() []string { + return e.model.GetValuesForFieldInPolicyAllTypes("p", 0) +} + +// GetAllNamedSubjects gets the list of subjects that show up in the current named policy. +func (e *Enforcer) GetAllNamedSubjects(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 0) +} + +// GetAllObjects gets the list of objects that show up in the current policy. +func (e *Enforcer) GetAllObjects() []string { + return e.model.GetValuesForFieldInPolicyAllTypes("p", 1) +} + +// GetAllNamedObjects gets the list of objects that show up in the current named policy. +func (e *Enforcer) GetAllNamedObjects(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 1) +} + +// GetAllActions gets the list of actions that show up in the current policy. +func (e *Enforcer) GetAllActions() []string { + return e.model.GetValuesForFieldInPolicyAllTypes("p", 2) +} + +// GetAllNamedActions gets the list of actions that show up in the current named policy. 
+func (e *Enforcer) GetAllNamedActions(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 2) +} + +// GetAllRoles gets the list of roles that show up in the current policy. +func (e *Enforcer) GetAllRoles() []string { + return e.model.GetValuesForFieldInPolicyAllTypes("g", 1) +} + +// GetAllNamedRoles gets the list of roles that show up in the current named policy. +func (e *Enforcer) GetAllNamedRoles(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("g", ptype, 1) +} + +// GetPolicy gets all the authorization rules in the policy. +func (e *Enforcer) GetPolicy() [][]string { + return e.GetNamedPolicy("p") +} + +// GetFilteredPolicy gets all the authorization rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredPolicy(fieldIndex int, fieldValues ...string) [][]string { + return e.GetFilteredNamedPolicy("p", fieldIndex, fieldValues...) +} + +// GetNamedPolicy gets all the authorization rules in the named policy. +func (e *Enforcer) GetNamedPolicy(ptype string) [][]string { + return e.model.GetPolicy("p", ptype) +} + +// GetFilteredNamedPolicy gets all the authorization rules in the named policy, field filters can be specified. +func (e *Enforcer) GetFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string { + return e.model.GetFilteredPolicy("p", ptype, fieldIndex, fieldValues...) +} + +// GetGroupingPolicy gets all the role inheritance rules in the policy. +func (e *Enforcer) GetGroupingPolicy() [][]string { + return e.GetNamedGroupingPolicy("g") +} + +// GetFilteredGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) [][]string { + return e.GetFilteredNamedGroupingPolicy("g", fieldIndex, fieldValues...) +} + +// GetNamedGroupingPolicy gets all the role inheritance rules in the policy. 
+func (e *Enforcer) GetNamedGroupingPolicy(ptype string) [][]string { + return e.model.GetPolicy("g", ptype) +} + +// GetFilteredNamedGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string { + return e.model.GetFilteredPolicy("g", ptype, fieldIndex, fieldValues...) +} + +// GetFilteredNamedPolicyWithMatcher gets rules based on matcher from the policy. +func (e *Enforcer) GetFilteredNamedPolicyWithMatcher(ptype string, matcher string) ([][]string, error) { + var res [][]string + var err error + + functions := e.fm.GetFunctions() + if _, ok := e.model["g"]; ok { + for key, ast := range e.model["g"] { + rm := ast.RM + functions[key] = util.GenerateGFunction(rm) + } + } + + var expString string + if matcher == "" { + return res, fmt.Errorf("matcher is empty") + } else { + expString = util.RemoveComments(util.EscapeAssertion(matcher)) + } + + var expression *govaluate.EvaluableExpression + + expression, err = govaluate.NewEvaluableExpressionWithFunctions(expString, functions) + if err != nil { + return res, err + } + + pTokens := make(map[string]int, len(e.model["p"][ptype].Tokens)) + for i, token := range e.model["p"][ptype].Tokens { + pTokens[token] = i + } + + parameters := enforceParameters{ + pTokens: pTokens, + } + + if policyLen := len(e.model["p"][ptype].Policy); policyLen != 0 && strings.Contains(expString, ptype+"_") { + for _, pvals := range e.model["p"][ptype].Policy { + if len(e.model["p"][ptype].Tokens) != len(pvals) { + return res, fmt.Errorf( + "invalid policy size: expected %d, got %d, pvals: %v", + len(e.model["p"][ptype].Tokens), + len(pvals), + pvals) + } + + parameters.pVals = pvals + + result, err := expression.Eval(parameters) + + if err != nil { + return res, err + } + + switch result := result.(type) { + case bool: + if result { + res = append(res, pvals) + } + case float64: + if result != 
0 { + res = append(res, pvals) + } + default: + return res, errors.New("matcher result should be bool, int or float") + } + } + } + return res, nil +} + +// HasPolicy determines whether an authorization rule exists. +func (e *Enforcer) HasPolicy(params ...interface{}) bool { + return e.HasNamedPolicy("p", params...) +} + +// HasNamedPolicy determines whether a named authorization rule exists. +func (e *Enforcer) HasNamedPolicy(ptype string, params ...interface{}) bool { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + return e.model.HasPolicy("p", ptype, strSlice) + } + + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.model.HasPolicy("p", ptype, policy) +} + +// AddPolicy adds an authorization rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddPolicy(params ...interface{}) (bool, error) { + return e.AddNamedPolicy("p", params...) +} + +// AddPolicies adds authorization rules to the current policy. +// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. +// Otherwise the function returns true for the corresponding rule by adding the new rule. +func (e *Enforcer) AddPolicies(rules [][]string) (bool, error) { + return e.AddNamedPolicies("p", rules) +} + +// AddNamedPolicy adds an authorization rule to the current named policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddNamedPolicy(ptype string, params ...interface{}) (bool, error) { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + strSlice = append(make([]string, 0, len(strSlice)), strSlice...) 
+ return e.addPolicy("p", ptype, strSlice) + } + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.addPolicy("p", ptype, policy) +} + +// AddNamedPolicies adds authorization rules to the current named policy. +// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. +// Otherwise the function returns true for the corresponding by adding the new rule. +func (e *Enforcer) AddNamedPolicies(ptype string, rules [][]string) (bool, error) { + return e.addPolicies("p", ptype, rules) +} + +// RemovePolicy removes an authorization rule from the current policy. +func (e *Enforcer) RemovePolicy(params ...interface{}) (bool, error) { + return e.RemoveNamedPolicy("p", params...) +} + +// UpdatePolicy updates an authorization rule from the current policy. +func (e *Enforcer) UpdatePolicy(oldPolicy []string, newPolicy []string) (bool, error) { + return e.UpdateNamedPolicy("p", oldPolicy, newPolicy) +} + +func (e *Enforcer) UpdateNamedPolicy(ptype string, p1 []string, p2 []string) (bool, error) { + return e.updatePolicy("p", ptype, p1, p2) +} + +// UpdatePolicies updates authorization rules from the current policies. +func (e *Enforcer) UpdatePolicies(oldPolices [][]string, newPolicies [][]string) (bool, error) { + return e.UpdateNamedPolicies("p", oldPolices, newPolicies) +} + +func (e *Enforcer) UpdateNamedPolicies(ptype string, p1 [][]string, p2 [][]string) (bool, error) { + return e.updatePolicies("p", ptype, p1, p2) +} + +func (e *Enforcer) UpdateFilteredPolicies(newPolicies [][]string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.UpdateFilteredNamedPolicies("p", newPolicies, fieldIndex, fieldValues...) 
+} + +func (e *Enforcer) UpdateFilteredNamedPolicies(ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.updateFilteredPolicies("p", ptype, newPolicies, fieldIndex, fieldValues...) +} + +// RemovePolicies removes authorization rules from the current policy. +func (e *Enforcer) RemovePolicies(rules [][]string) (bool, error) { + return e.RemoveNamedPolicies("p", rules) +} + +// RemoveFilteredPolicy removes an authorization rule from the current policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) (bool, error) { + return e.RemoveFilteredNamedPolicy("p", fieldIndex, fieldValues...) +} + +// RemoveNamedPolicy removes an authorization rule from the current named policy. +func (e *Enforcer) RemoveNamedPolicy(ptype string, params ...interface{}) (bool, error) { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + return e.removePolicy("p", ptype, strSlice) + } + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.removePolicy("p", ptype, policy) +} + +// RemoveNamedPolicies removes authorization rules from the current named policy. +func (e *Enforcer) RemoveNamedPolicies(ptype string, rules [][]string) (bool, error) { + return e.removePolicies("p", ptype, rules) +} + +// RemoveFilteredNamedPolicy removes an authorization rule from the current named policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.removeFilteredPolicy("p", ptype, fieldIndex, fieldValues) +} + +// HasGroupingPolicy determines whether a role inheritance rule exists. +func (e *Enforcer) HasGroupingPolicy(params ...interface{}) bool { + return e.HasNamedGroupingPolicy("g", params...) +} + +// HasNamedGroupingPolicy determines whether a named role inheritance rule exists. 
+func (e *Enforcer) HasNamedGroupingPolicy(ptype string, params ...interface{}) bool { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + return e.model.HasPolicy("g", ptype, strSlice) + } + + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.model.HasPolicy("g", ptype, policy) +} + +// AddGroupingPolicy adds a role inheritance rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddGroupingPolicy(params ...interface{}) (bool, error) { + return e.AddNamedGroupingPolicy("g", params...) +} + +// AddGroupingPolicies adds role inheritance rules to the current policy. +// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. +// Otherwise the function returns true for the corresponding policy rule by adding the new rule. +func (e *Enforcer) AddGroupingPolicies(rules [][]string) (bool, error) { + return e.AddNamedGroupingPolicies("g", rules) +} + +// AddNamedGroupingPolicy adds a named role inheritance rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error) { + var ruleAdded bool + var err error + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleAdded, err = e.addPolicy("g", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleAdded, err = e.addPolicy("g", ptype, policy) + } + + return ruleAdded, err +} + +// AddNamedGroupingPolicies adds named role inheritance rules to the current policy. 
+// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. +// Otherwise the function returns true for the corresponding policy rule by adding the new rule. +func (e *Enforcer) AddNamedGroupingPolicies(ptype string, rules [][]string) (bool, error) { + return e.addPolicies("g", ptype, rules) +} + +// RemoveGroupingPolicy removes a role inheritance rule from the current policy. +func (e *Enforcer) RemoveGroupingPolicy(params ...interface{}) (bool, error) { + return e.RemoveNamedGroupingPolicy("g", params...) +} + +// RemoveGroupingPolicies removes role inheritance rules from the current policy. +func (e *Enforcer) RemoveGroupingPolicies(rules [][]string) (bool, error) { + return e.RemoveNamedGroupingPolicies("g", rules) +} + +// RemoveFilteredGroupingPolicy removes a role inheritance rule from the current policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) (bool, error) { + return e.RemoveFilteredNamedGroupingPolicy("g", fieldIndex, fieldValues...) +} + +// RemoveNamedGroupingPolicy removes a role inheritance rule from the current named policy. +func (e *Enforcer) RemoveNamedGroupingPolicy(ptype string, params ...interface{}) (bool, error) { + var ruleRemoved bool + var err error + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleRemoved, err = e.removePolicy("g", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleRemoved, err = e.removePolicy("g", ptype, policy) + } + + return ruleRemoved, err +} + +// RemoveNamedGroupingPolicies removes role inheritance rules from the current named policy. 
+func (e *Enforcer) RemoveNamedGroupingPolicies(ptype string, rules [][]string) (bool, error) { + return e.removePolicies("g", ptype, rules) +} + +func (e *Enforcer) UpdateGroupingPolicy(oldRule []string, newRule []string) (bool, error) { + return e.UpdateNamedGroupingPolicy("g", oldRule, newRule) +} + +// UpdateGroupingPolicies updates authorization rules from the current policies. +func (e *Enforcer) UpdateGroupingPolicies(oldRules [][]string, newRules [][]string) (bool, error) { + return e.UpdateNamedGroupingPolicies("g", oldRules, newRules) +} + +func (e *Enforcer) UpdateNamedGroupingPolicy(ptype string, oldRule []string, newRule []string) (bool, error) { + return e.updatePolicy("g", ptype, oldRule, newRule) +} + +func (e *Enforcer) UpdateNamedGroupingPolicies(ptype string, oldRules [][]string, newRules [][]string) (bool, error) { + return e.updatePolicies("g", ptype, oldRules, newRules) +} + +// RemoveFilteredNamedGroupingPolicy removes a role inheritance rule from the current named policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.removeFilteredPolicy("g", ptype, fieldIndex, fieldValues) +} + +// AddFunction adds a customized function. 
+func (e *Enforcer) AddFunction(name string, function govaluate.ExpressionFunction) { + e.fm.AddFunction(name, function) +} + +func (e *Enforcer) SelfAddPolicy(sec string, ptype string, rule []string) (bool, error) { + return e.addPolicyWithoutNotify(sec, ptype, rule) +} + +func (e *Enforcer) SelfAddPolicies(sec string, ptype string, rules [][]string) (bool, error) { + return e.addPoliciesWithoutNotify(sec, ptype, rules) +} + +func (e *Enforcer) SelfRemovePolicy(sec string, ptype string, rule []string) (bool, error) { + return e.removePolicyWithoutNotify(sec, ptype, rule) +} + +func (e *Enforcer) SelfRemovePolicies(sec string, ptype string, rules [][]string) (bool, error) { + return e.removePoliciesWithoutNotify(sec, ptype, rules) +} + +func (e *Enforcer) SelfRemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.removeFilteredPolicyWithoutNotify(sec, ptype, fieldIndex, fieldValues) +} + +func (e *Enforcer) SelfUpdatePolicy(sec string, ptype string, oldRule, newRule []string) (bool, error) { + return e.updatePolicyWithoutNotify(sec, ptype, oldRule, newRule) +} + +func (e *Enforcer) SelfUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (bool, error) { + return e.updatePoliciesWithoutNotify(sec, ptype, oldRules, newRules) +} diff --git a/vendor/github.com/casbin/casbin/v2/model/assertion.go b/vendor/github.com/casbin/casbin/v2/model/assertion.go new file mode 100644 index 00000000..054dd0f4 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/model/assertion.go @@ -0,0 +1,118 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "errors" + "strings" + + "github.com/casbin/casbin/v2/log" + "github.com/casbin/casbin/v2/rbac" +) + +// Assertion represents an expression in a section of the model. +// For example: r = sub, obj, act +type Assertion struct { + Key string + Value string + Tokens []string + Policy [][]string + PolicyMap map[string]int + RM rbac.RoleManager + FieldIndexMap map[string]int + + logger log.Logger +} + +func (ast *Assertion) buildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp, rules [][]string) error { + ast.RM = rm + count := strings.Count(ast.Value, "_") + if count < 2 { + return errors.New("the number of \"_\" in role definition should be at least 2") + } + + for _, rule := range rules { + if len(rule) < count { + return errors.New("grouping policy elements do not meet role definition") + } + if len(rule) > count { + rule = rule[:count] + } + switch op { + case PolicyAdd: + err := rm.AddLink(rule[0], rule[1], rule[2:]...) + if err != nil { + return err + } + case PolicyRemove: + err := rm.DeleteLink(rule[0], rule[1], rule[2:]...) 
+ if err != nil { + return err + } + } + } + + return nil +} + +func (ast *Assertion) buildRoleLinks(rm rbac.RoleManager) error { + ast.RM = rm + count := strings.Count(ast.Value, "_") + if count < 2 { + return errors.New("the number of \"_\" in role definition should be at least 2") + } + for _, rule := range ast.Policy { + if len(rule) < count { + return errors.New("grouping policy elements do not meet role definition") + } + if len(rule) > count { + rule = rule[:count] + } + err := ast.RM.AddLink(rule[0], rule[1], rule[2:]...) + if err != nil { + return err + } + } + + return nil +} + +func (ast *Assertion) setLogger(logger log.Logger) { + ast.logger = logger +} + +func (ast *Assertion) copy() *Assertion { + tokens := append([]string(nil), ast.Tokens...) + policy := make([][]string, len(ast.Policy)) + + for i, p := range ast.Policy { + policy[i] = append(policy[i], p...) + } + policyMap := make(map[string]int) + for k, v := range ast.PolicyMap { + policyMap[k] = v + } + + newAst := &Assertion{ + Key: ast.Key, + Value: ast.Value, + PolicyMap: policyMap, + Tokens: tokens, + Policy: policy, + FieldIndexMap: ast.FieldIndexMap, + } + + return newAst +} diff --git a/vendor/github.com/casbin/casbin/v2/model/function.go b/vendor/github.com/casbin/casbin/v2/model/function.go new file mode 100644 index 00000000..0ae12460 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/model/function.go @@ -0,0 +1,66 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "sync" + + "github.com/Knetic/govaluate" + "github.com/casbin/casbin/v2/util" +) + +// FunctionMap represents the collection of Function. +type FunctionMap struct { + fns *sync.Map +} + +// [string]govaluate.ExpressionFunction + +// AddFunction adds an expression function. +func (fm *FunctionMap) AddFunction(name string, function govaluate.ExpressionFunction) { + fm.fns.LoadOrStore(name, function) +} + +// LoadFunctionMap loads an initial function map. +func LoadFunctionMap() FunctionMap { + fm := &FunctionMap{} + fm.fns = &sync.Map{} + + fm.AddFunction("keyMatch", util.KeyMatchFunc) + fm.AddFunction("keyGet", util.KeyGetFunc) + fm.AddFunction("keyMatch2", util.KeyMatch2Func) + fm.AddFunction("keyGet2", util.KeyGet2Func) + fm.AddFunction("keyMatch3", util.KeyMatch3Func) + fm.AddFunction("keyGet3", util.KeyGet3Func) + fm.AddFunction("keyMatch4", util.KeyMatch4Func) + fm.AddFunction("keyMatch5", util.KeyMatch5Func) + fm.AddFunction("regexMatch", util.RegexMatchFunc) + fm.AddFunction("ipMatch", util.IPMatchFunc) + fm.AddFunction("globMatch", util.GlobMatchFunc) + + return *fm +} + +// GetFunctions return a map with all the functions +func (fm *FunctionMap) GetFunctions() map[string]govaluate.ExpressionFunction { + ret := make(map[string]govaluate.ExpressionFunction) + + fm.fns.Range(func(k interface{}, v interface{}) bool { + ret[k.(string)] = v.(govaluate.ExpressionFunction) + return true + }) + + return ret +} diff --git a/vendor/github.com/casbin/casbin/v2/model/model.go b/vendor/github.com/casbin/casbin/v2/model/model.go new file mode 100644 index 00000000..25ce8243 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/model/model.go @@ -0,0 +1,406 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "container/list" + "errors" + "fmt" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/casbin/casbin/v2/config" + "github.com/casbin/casbin/v2/constant" + "github.com/casbin/casbin/v2/log" + "github.com/casbin/casbin/v2/util" +) + +// Model represents the whole access control model. +type Model map[string]AssertionMap + +// AssertionMap is the collection of assertions, can be "r", "p", "g", "e", "m". +type AssertionMap map[string]*Assertion + +const defaultDomain string = "" +const defaultSeparator = "::" + +var sectionNameMap = map[string]string{ + "r": "request_definition", + "p": "policy_definition", + "g": "role_definition", + "e": "policy_effect", + "m": "matchers", +} + +// Minimal required sections for a model to be valid +var requiredSections = []string{"r", "p", "e", "m"} + +func loadAssertion(model Model, cfg config.ConfigInterface, sec string, key string) bool { + value := cfg.String(sectionNameMap[sec] + "::" + key) + return model.AddDef(sec, key, value) +} + +// AddDef adds an assertion to the model. 
+func (model Model) AddDef(sec string, key string, value string) bool { + if value == "" { + return false + } + + ast := Assertion{} + ast.Key = key + ast.Value = value + ast.PolicyMap = make(map[string]int) + ast.FieldIndexMap = make(map[string]int) + ast.setLogger(model.GetLogger()) + + if sec == "r" || sec == "p" { + ast.Tokens = strings.Split(ast.Value, ",") + for i := range ast.Tokens { + ast.Tokens[i] = key + "_" + strings.TrimSpace(ast.Tokens[i]) + } + } else if sec == "g" { + ast.Tokens = strings.Split(ast.Value, ",") + } else { + ast.Value = util.RemoveComments(util.EscapeAssertion(ast.Value)) + } + + if sec == "m" && strings.Contains(ast.Value, "in") { + ast.Value = strings.Replace(strings.Replace(ast.Value, "[", "(", -1), "]", ")", -1) + } + + _, ok := model[sec] + if !ok { + model[sec] = make(AssertionMap) + } + + model[sec][key] = &ast + return true +} + +func getKeySuffix(i int) string { + if i == 1 { + return "" + } + + return strconv.Itoa(i) +} + +func loadSection(model Model, cfg config.ConfigInterface, sec string) { + i := 1 + for { + if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) { + break + } else { + i++ + } + } +} + +// SetLogger sets the model's logger. +func (model Model) SetLogger(logger log.Logger) { + for _, astMap := range model { + for _, ast := range astMap { + ast.logger = logger + } + } + model["logger"] = AssertionMap{"logger": &Assertion{logger: logger}} +} + +// GetLogger returns the model's logger. +func (model Model) GetLogger() log.Logger { + return model["logger"]["logger"].logger +} + +// NewModel creates an empty model. +func NewModel() Model { + m := make(Model) + m.SetLogger(&log.DefaultLogger{}) + + return m +} + +// NewModelFromFile creates a model from a .CONF file. +func NewModelFromFile(path string) (Model, error) { + m := NewModel() + + err := m.LoadModel(path) + if err != nil { + return nil, err + } + + return m, nil +} + +// NewModelFromString creates a model from a string which contains model text. 
+func NewModelFromString(text string) (Model, error) { + m := NewModel() + + err := m.LoadModelFromText(text) + if err != nil { + return nil, err + } + + return m, nil +} + +// LoadModel loads the model from model CONF file. +func (model Model) LoadModel(path string) error { + cfg, err := config.NewConfig(path) + if err != nil { + return err + } + + return model.loadModelFromConfig(cfg) +} + +// LoadModelFromText loads the model from the text. +func (model Model) LoadModelFromText(text string) error { + cfg, err := config.NewConfigFromText(text) + if err != nil { + return err + } + + return model.loadModelFromConfig(cfg) +} + +func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error { + for s := range sectionNameMap { + loadSection(model, cfg, s) + } + ms := make([]string, 0) + for _, rs := range requiredSections { + if !model.hasSection(rs) { + ms = append(ms, sectionNameMap[rs]) + } + } + if len(ms) > 0 { + return fmt.Errorf("missing required sections: %s", strings.Join(ms, ",")) + } + return nil +} + +func (model Model) hasSection(sec string) bool { + section := model[sec] + return section != nil +} + +// PrintModel prints the model to the log. 
+func (model Model) PrintModel() { + if !model.GetLogger().IsEnabled() { + return + } + + var modelInfo [][]string + for k, v := range model { + if k == "logger" { + continue + } + + for i, j := range v { + modelInfo = append(modelInfo, []string{k, i, j.Value}) + } + } + + model.GetLogger().LogModel(modelInfo) +} + +func (model Model) SortPoliciesBySubjectHierarchy() error { + if model["e"]["e"].Value != constant.SubjectPriorityEffect { + return nil + } + subIndex := 0 + for ptype, assertion := range model["p"] { + domainIndex, err := model.GetFieldIndex(ptype, constant.DomainIndex) + if err != nil { + domainIndex = -1 + } + policies := assertion.Policy + subjectHierarchyMap, err := getSubjectHierarchyMap(model["g"]["g"].Policy) + if err != nil { + return err + } + sort.SliceStable(policies, func(i, j int) bool { + domain1, domain2 := defaultDomain, defaultDomain + if domainIndex != -1 { + domain1 = policies[i][domainIndex] + domain2 = policies[j][domainIndex] + } + name1, name2 := getNameWithDomain(domain1, policies[i][subIndex]), getNameWithDomain(domain2, policies[j][subIndex]) + p1 := subjectHierarchyMap[name1] + p2 := subjectHierarchyMap[name2] + return p1 > p2 + }) + for i, policy := range assertion.Policy { + assertion.PolicyMap[strings.Join(policy, ",")] = i + } + } + return nil +} + +func getSubjectHierarchyMap(policies [][]string) (map[string]int, error) { + subjectHierarchyMap := make(map[string]int) + // Tree structure of role + policyMap := make(map[string][]string) + for _, policy := range policies { + if len(policy) < 2 { + return nil, errors.New("policy g expect 2 more params") + } + domain := defaultDomain + if len(policy) != 2 { + domain = policy[2] + } + child := getNameWithDomain(domain, policy[0]) + parent := getNameWithDomain(domain, policy[1]) + policyMap[parent] = append(policyMap[parent], child) + if _, ok := subjectHierarchyMap[child]; !ok { + subjectHierarchyMap[child] = 0 + } + if _, ok := subjectHierarchyMap[parent]; !ok { + 
subjectHierarchyMap[parent] = 0 + } + subjectHierarchyMap[child] = 1 + } + // Use queues for levelOrder + queue := list.New() + for k, v := range subjectHierarchyMap { + root := k + if v != 0 { + continue + } + lv := 0 + queue.PushBack(root) + for queue.Len() != 0 { + sz := queue.Len() + for i := 0; i < sz; i++ { + node := queue.Front() + queue.Remove(node) + nodeValue := node.Value.(string) + subjectHierarchyMap[nodeValue] = lv + if _, ok := policyMap[nodeValue]; ok { + for _, child := range policyMap[nodeValue] { + queue.PushBack(child) + } + } + } + lv++ + } + } + return subjectHierarchyMap, nil +} + +func getNameWithDomain(domain string, name string) string { + return domain + defaultSeparator + name +} + +func (model Model) SortPoliciesByPriority() error { + for ptype, assertion := range model["p"] { + priorityIndex, err := model.GetFieldIndex(ptype, constant.PriorityIndex) + if err != nil { + continue + } + policies := assertion.Policy + sort.SliceStable(policies, func(i, j int) bool { + p1, err := strconv.Atoi(policies[i][priorityIndex]) + if err != nil { + return true + } + p2, err := strconv.Atoi(policies[j][priorityIndex]) + if err != nil { + return true + } + return p1 < p2 + }) + for i, policy := range assertion.Policy { + assertion.PolicyMap[strings.Join(policy, ",")] = i + } + } + return nil +} + +func (model Model) ToText() string { + tokenPatterns := make(map[string]string) + + pPattern, rPattern := regexp.MustCompile("^p_"), regexp.MustCompile("^r_") + for _, ptype := range []string{"r", "p"} { + for _, token := range model[ptype][ptype].Tokens { + tokenPatterns[token] = rPattern.ReplaceAllString(pPattern.ReplaceAllString(token, "p."), "r.") + } + } + if strings.Contains(model["e"]["e"].Value, "p_eft") { + tokenPatterns["p_eft"] = "p.eft" + } + s := strings.Builder{} + writeString := func(sec string) { + for ptype := range model[sec] { + value := model[sec][ptype].Value + for tokenPattern, newToken := range tokenPatterns { + value = 
strings.Replace(value, tokenPattern, newToken, -1) + } + s.WriteString(fmt.Sprintf("%s = %s\n", sec, value)) + } + } + s.WriteString("[request_definition]\n") + writeString("r") + s.WriteString("[policy_definition]\n") + writeString("p") + if _, ok := model["g"]; ok { + s.WriteString("[role_definition]\n") + for ptype := range model["g"] { + s.WriteString(fmt.Sprintf("%s = %s\n", ptype, model["g"][ptype].Value)) + } + } + s.WriteString("[policy_effect]\n") + writeString("e") + s.WriteString("[matchers]\n") + writeString("m") + return s.String() +} + +func (model Model) Copy() Model { + newModel := NewModel() + + for sec, m := range model { + newAstMap := make(AssertionMap) + for ptype, ast := range m { + newAstMap[ptype] = ast.copy() + } + newModel[sec] = newAstMap + } + + newModel.SetLogger(model.GetLogger()) + return newModel +} + +func (model Model) GetFieldIndex(ptype string, field string) (int, error) { + assertion := model["p"][ptype] + if index, ok := assertion.FieldIndexMap[field]; ok { + return index, nil + } + pattern := fmt.Sprintf("%s_"+field, ptype) + index := -1 + for i, token := range assertion.Tokens { + if token == pattern { + index = i + break + } + } + if index == -1 { + return index, fmt.Errorf(field + " index is not set, please use enforcer.SetFieldIndex() to set index") + } + assertion.FieldIndexMap[field] = index + return index, nil +} diff --git a/vendor/github.com/casbin/casbin/v2/model/policy.go b/vendor/github.com/casbin/casbin/v2/model/policy.go new file mode 100644 index 00000000..d4b4a7a5 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/model/policy.go @@ -0,0 +1,372 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "fmt" + "strconv" + "strings" + + "github.com/casbin/casbin/v2/constant" + "github.com/casbin/casbin/v2/rbac" + "github.com/casbin/casbin/v2/util" +) + +type ( + PolicyOp int +) + +const ( + PolicyAdd PolicyOp = iota + PolicyRemove +) + +const DefaultSep = "," + +// BuildIncrementalRoleLinks provides incremental build the role inheritance relations. +func (model Model) BuildIncrementalRoleLinks(rmMap map[string]rbac.RoleManager, op PolicyOp, sec string, ptype string, rules [][]string) error { + if sec == "g" { + return model[sec][ptype].buildIncrementalRoleLinks(rmMap[ptype], op, rules) + } + return nil +} + +// BuildRoleLinks initializes the roles in RBAC. +func (model Model) BuildRoleLinks(rmMap map[string]rbac.RoleManager) error { + model.PrintPolicy() + for ptype, ast := range model["g"] { + rm := rmMap[ptype] + err := ast.buildRoleLinks(rm) + if err != nil { + return err + } + } + + return nil +} + +// PrintPolicy prints the policy to log. +func (model Model) PrintPolicy() { + if !model.GetLogger().IsEnabled() { + return + } + + policy := make(map[string][][]string) + + for key, ast := range model["p"] { + value, found := policy[key] + if found { + value = append(value, ast.Policy...) + policy[key] = value + } else { + policy[key] = ast.Policy + } + } + + for key, ast := range model["g"] { + value, found := policy[key] + if found { + value = append(value, ast.Policy...) 
+ policy[key] = value + } else { + policy[key] = ast.Policy + } + } + + model.GetLogger().LogPolicy(policy) +} + +// ClearPolicy clears all current policy. +func (model Model) ClearPolicy() { + for _, ast := range model["p"] { + ast.Policy = nil + ast.PolicyMap = map[string]int{} + } + + for _, ast := range model["g"] { + ast.Policy = nil + ast.PolicyMap = map[string]int{} + } +} + +// GetPolicy gets all rules in a policy. +func (model Model) GetPolicy(sec string, ptype string) [][]string { + return model[sec][ptype].Policy +} + +// GetFilteredPolicy gets rules based on field filters from a policy. +func (model Model) GetFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) [][]string { + res := [][]string{} + + for _, rule := range model[sec][ptype].Policy { + matched := true + for i, fieldValue := range fieldValues { + if fieldValue != "" && rule[fieldIndex+i] != fieldValue { + matched = false + break + } + } + + if matched { + res = append(res, rule) + } + } + + return res +} + +// HasPolicyEx determines whether a model has the specified policy rule with error. +func (model Model) HasPolicyEx(sec string, ptype string, rule []string) (bool, error) { + assertion := model[sec][ptype] + switch sec { + case "p": + if len(rule) != len(assertion.Tokens) { + return false, fmt.Errorf( + "invalid policy rule size: expected %d, got %d, rule: %v", + len(model["p"][ptype].Tokens), + len(rule), + rule) + } + case "g": + if len(rule) < len(assertion.Tokens) { + return false, fmt.Errorf( + "invalid policy rule size: expected %d, got %d, rule: %v", + len(model["g"][ptype].Tokens), + len(rule), + rule) + } + } + return model.HasPolicy(sec, ptype, rule), nil +} + +// HasPolicy determines whether a model has the specified policy rule. 
+func (model Model) HasPolicy(sec string, ptype string, rule []string) bool { + _, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] + return ok +} + +// HasPolicies determines whether a model has any of the specified policies. If one is found we return true. +func (model Model) HasPolicies(sec string, ptype string, rules [][]string) bool { + for i := 0; i < len(rules); i++ { + if model.HasPolicy(sec, ptype, rules[i]) { + return true + } + } + + return false +} + +// AddPolicy adds a policy rule to the model. +func (model Model) AddPolicy(sec string, ptype string, rule []string) { + assertion := model[sec][ptype] + assertion.Policy = append(assertion.Policy, rule) + assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1 + + hasPriority := false + if _, ok := assertion.FieldIndexMap[constant.PriorityIndex]; ok { + hasPriority = true + } + if sec == "p" && hasPriority { + if idxInsert, err := strconv.Atoi(rule[assertion.FieldIndexMap[constant.PriorityIndex]]); err == nil { + i := len(assertion.Policy) - 1 + for ; i > 0; i-- { + idx, err := strconv.Atoi(assertion.Policy[i-1][assertion.FieldIndexMap[constant.PriorityIndex]]) + if err != nil { + break + } + if idx > idxInsert { + assertion.Policy[i] = assertion.Policy[i-1] + assertion.PolicyMap[strings.Join(assertion.Policy[i-1], DefaultSep)]++ + } else { + break + } + } + assertion.Policy[i] = rule + assertion.PolicyMap[strings.Join(rule, DefaultSep)] = i + } + } +} + +// AddPolicies adds policy rules to the model. +func (model Model) AddPolicies(sec string, ptype string, rules [][]string) { + _ = model.AddPoliciesWithAffected(sec, ptype, rules) +} + +// AddPoliciesWithAffected adds policy rules to the model, and returns affected rules. 
+func (model Model) AddPoliciesWithAffected(sec string, ptype string, rules [][]string) [][]string { + var affected [][]string + for _, rule := range rules { + hashKey := strings.Join(rule, DefaultSep) + _, ok := model[sec][ptype].PolicyMap[hashKey] + if ok { + continue + } + affected = append(affected, rule) + model.AddPolicy(sec, ptype, rule) + } + return affected +} + +// RemovePolicy removes a policy rule from the model. +// Deprecated: Using AddPoliciesWithAffected instead. +func (model Model) RemovePolicy(sec string, ptype string, rule []string) bool { + index, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] + if !ok { + return false + } + + model[sec][ptype].Policy = append(model[sec][ptype].Policy[:index], model[sec][ptype].Policy[index+1:]...) + delete(model[sec][ptype].PolicyMap, strings.Join(rule, DefaultSep)) + for i := index; i < len(model[sec][ptype].Policy); i++ { + model[sec][ptype].PolicyMap[strings.Join(model[sec][ptype].Policy[i], DefaultSep)] = i + } + + return true +} + +// UpdatePolicy updates a policy rule from the model. +func (model Model) UpdatePolicy(sec string, ptype string, oldRule []string, newRule []string) bool { + oldPolicy := strings.Join(oldRule, DefaultSep) + index, ok := model[sec][ptype].PolicyMap[oldPolicy] + if !ok { + return false + } + + model[sec][ptype].Policy[index] = newRule + delete(model[sec][ptype].PolicyMap, oldPolicy) + model[sec][ptype].PolicyMap[strings.Join(newRule, DefaultSep)] = index + + return true +} + +// UpdatePolicies updates a policy rule from the model. 
+func (model Model) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) bool { + rollbackFlag := false + // index -> []{oldIndex, newIndex} + modifiedRuleIndex := make(map[int][]int) + // rollback + defer func() { + if rollbackFlag { + for index, oldNewIndex := range modifiedRuleIndex { + model[sec][ptype].Policy[index] = oldRules[oldNewIndex[0]] + oldPolicy := strings.Join(oldRules[oldNewIndex[0]], DefaultSep) + newPolicy := strings.Join(newRules[oldNewIndex[1]], DefaultSep) + delete(model[sec][ptype].PolicyMap, newPolicy) + model[sec][ptype].PolicyMap[oldPolicy] = index + } + } + }() + + newIndex := 0 + for oldIndex, oldRule := range oldRules { + oldPolicy := strings.Join(oldRule, DefaultSep) + index, ok := model[sec][ptype].PolicyMap[oldPolicy] + if !ok { + rollbackFlag = true + return false + } + + model[sec][ptype].Policy[index] = newRules[newIndex] + delete(model[sec][ptype].PolicyMap, oldPolicy) + model[sec][ptype].PolicyMap[strings.Join(newRules[newIndex], DefaultSep)] = index + modifiedRuleIndex[index] = []int{oldIndex, newIndex} + newIndex++ + } + + return true +} + +// RemovePolicies removes policy rules from the model. +func (model Model) RemovePolicies(sec string, ptype string, rules [][]string) bool { + affected := model.RemovePoliciesWithAffected(sec, ptype, rules) + return len(affected) != 0 +} + +// RemovePoliciesWithAffected removes policy rules from the model, and returns affected rules. +func (model Model) RemovePoliciesWithAffected(sec string, ptype string, rules [][]string) [][]string { + var affected [][]string + for _, rule := range rules { + index, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] + if !ok { + continue + } + + affected = append(affected, rule) + model[sec][ptype].Policy = append(model[sec][ptype].Policy[:index], model[sec][ptype].Policy[index+1:]...) 
+ delete(model[sec][ptype].PolicyMap, strings.Join(rule, DefaultSep)) + for i := index; i < len(model[sec][ptype].Policy); i++ { + model[sec][ptype].PolicyMap[strings.Join(model[sec][ptype].Policy[i], DefaultSep)] = i + } + } + return affected +} + +// RemoveFilteredPolicy removes policy rules based on field filters from the model. +func (model Model) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, [][]string) { + var tmp [][]string + var effects [][]string + res := false + model[sec][ptype].PolicyMap = map[string]int{} + + for _, rule := range model[sec][ptype].Policy { + matched := true + for i, fieldValue := range fieldValues { + if fieldValue != "" && rule[fieldIndex+i] != fieldValue { + matched = false + break + } + } + + if matched { + effects = append(effects, rule) + } else { + tmp = append(tmp, rule) + model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] = len(tmp) - 1 + } + } + + if len(tmp) != len(model[sec][ptype].Policy) { + model[sec][ptype].Policy = tmp + res = true + } + + return res, effects +} + +// GetValuesForFieldInPolicy gets all values for a field for all rules in a policy, duplicated values are removed. +func (model Model) GetValuesForFieldInPolicy(sec string, ptype string, fieldIndex int) []string { + values := []string{} + + for _, rule := range model[sec][ptype].Policy { + values = append(values, rule[fieldIndex]) + } + + util.ArrayRemoveDuplicates(&values) + + return values +} + +// GetValuesForFieldInPolicyAllTypes gets all values for a field for all rules in a policy of all ptypes, duplicated values are removed. +func (model Model) GetValuesForFieldInPolicyAllTypes(sec string, fieldIndex int) []string { + values := []string{} + + for ptype := range model[sec] { + values = append(values, model.GetValuesForFieldInPolicy(sec, ptype, fieldIndex)...) 
+ } + + util.ArrayRemoveDuplicates(&values) + + return values +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/adapter.go b/vendor/github.com/casbin/casbin/v2/persist/adapter.go new file mode 100644 index 00000000..0525657a --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/adapter.go @@ -0,0 +1,74 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "encoding/csv" + "strings" + + "github.com/casbin/casbin/v2/model" +) + +// LoadPolicyLine loads a text line as a policy rule to model. +func LoadPolicyLine(line string, m model.Model) error { + if line == "" || strings.HasPrefix(line, "#") { + return nil + } + + r := csv.NewReader(strings.NewReader(line)) + r.Comma = ',' + r.Comment = '#' + r.TrimLeadingSpace = true + + tokens, err := r.Read() + if err != nil { + return err + } + + return LoadPolicyArray(tokens, m) +} + +// LoadPolicyArray loads a policy rule to model. +func LoadPolicyArray(rule []string, m model.Model) error { + key := rule[0] + sec := key[:1] + ok, err := m.HasPolicyEx(sec, key, rule[1:]) + if err != nil { + return err + } + if ok { + return nil // skip duplicated policy + } + m.AddPolicy(sec, key, rule[1:]) + return nil +} + +// Adapter is the interface for Casbin adapters. +type Adapter interface { + // LoadPolicy loads all policy rules from the storage. 
+ LoadPolicy(model model.Model) error + // SavePolicy saves all policy rules to the storage. + SavePolicy(model model.Model) error + + // AddPolicy adds a policy rule to the storage. + // This is part of the Auto-Save feature. + AddPolicy(sec string, ptype string, rule []string) error + // RemovePolicy removes a policy rule from the storage. + // This is part of the Auto-Save feature. + RemovePolicy(sec string, ptype string, rule []string) error + // RemoveFilteredPolicy removes policy rules that match the filter from the storage. + // This is part of the Auto-Save feature. + RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/adapter_filtered.go b/vendor/github.com/casbin/casbin/v2/persist/adapter_filtered.go new file mode 100644 index 00000000..82c9a0e7 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/adapter_filtered.go @@ -0,0 +1,29 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "github.com/casbin/casbin/v2/model" +) + +// FilteredAdapter is the interface for Casbin adapters supporting filtered policies. +type FilteredAdapter interface { + Adapter + + // LoadFilteredPolicy loads only policy rules that match the filter. 
+ LoadFilteredPolicy(model model.Model, filter interface{}) error + // IsFiltered returns true if the loaded policy has been filtered. + IsFiltered() bool +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/batch_adapter.go b/vendor/github.com/casbin/casbin/v2/persist/batch_adapter.go new file mode 100644 index 00000000..56ec415f --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/batch_adapter.go @@ -0,0 +1,26 @@ +// Copyright 2020 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +// BatchAdapter is the interface for Casbin adapters with multiple add and remove policy functions. +type BatchAdapter interface { + Adapter + // AddPolicies adds policy rules to the storage. + // This is part of the Auto-Save feature. + AddPolicies(sec string, ptype string, rules [][]string) error + // RemovePolicies removes policy rules from the storage. + // This is part of the Auto-Save feature. + RemovePolicies(sec string, ptype string, rules [][]string) error +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/cache/cache.go b/vendor/github.com/casbin/casbin/v2/persist/cache/cache.go new file mode 100644 index 00000000..d08ed225 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/cache/cache.go @@ -0,0 +1,39 @@ +// Copyright 2021 The casbin Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import "errors" + +var ErrNoSuchKey = errors.New("there's no such key existing in cache") + +type Cache interface { + // Set puts key and value into cache. + // First parameter for extra should be uint denoting expected survival time. + // If survival time equals 0 or less, the key will always be survival. + Set(key string, value bool, extra ...interface{}) error + + // Get returns result for key, + // If there's no such key existing in cache, + // ErrNoSuchKey will be returned. + Get(key string) (bool, error) + + // Delete will remove the specific key in cache. + // If there's no such key existing in cache, + // ErrNoSuchKey will be returned. + Delete(key string) error + + // Clear deletes all the items stored in cache. + Clear() error +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/cache/default-cache.go b/vendor/github.com/casbin/casbin/v2/persist/cache/default-cache.go new file mode 100644 index 00000000..0d6e2478 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/cache/default-cache.go @@ -0,0 +1,44 @@ +// Copyright 2021 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +type DefaultCache map[string]bool + +func (c *DefaultCache) Set(key string, value bool, extra ...interface{}) error { + (*c)[key] = value + return nil +} + +func (c *DefaultCache) Get(key string) (bool, error) { + if res, ok := (*c)[key]; !ok { + return false, ErrNoSuchKey + } else { + return res, nil + } +} + +func (c *DefaultCache) Delete(key string) error { + if _, ok := (*c)[key]; !ok { + return ErrNoSuchKey + } else { + delete(*c, key) + return nil + } +} + +func (c *DefaultCache) Clear() error { + *c = make(DefaultCache) + return nil +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/dispatcher.go b/vendor/github.com/casbin/casbin/v2/persist/dispatcher.go new file mode 100644 index 00000000..3eb4605b --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/dispatcher.go @@ -0,0 +1,33 @@ +// Copyright 2020 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package persist + +// Dispatcher is the interface for Casbin dispatcher +type Dispatcher interface { + // AddPolicies adds policies rule to all instance. + AddPolicies(sec string, ptype string, rules [][]string) error + // RemovePolicies removes policies rule from all instance. + RemovePolicies(sec string, ptype string, rules [][]string) error + // RemoveFilteredPolicy removes policy rules that match the filter from all instance. + RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error + // ClearPolicy clears all current policy in all instances + ClearPolicy() error + // UpdatePolicy updates policy rule from all instance. + UpdatePolicy(sec string, ptype string, oldRule, newRule []string) error + // UpdatePolicies updates some policy rules from all instance + UpdatePolicies(sec string, ptype string, oldrules, newRules [][]string) error + // UpdateFilteredPolicies deletes old rules and adds new rules. + UpdateFilteredPolicies(sec string, ptype string, oldRules [][]string, newRules [][]string) error +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter.go b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter.go new file mode 100644 index 00000000..c68f0eaa --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter.go @@ -0,0 +1,149 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileadapter + +import ( + "bufio" + "bytes" + "errors" + "os" + "strings" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" + "github.com/casbin/casbin/v2/util" +) + +// Adapter is the file adapter for Casbin. +// It can load policy from file or save policy to file. +type Adapter struct { + filePath string +} + +func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []string) error { + return errors.New("not implemented") +} + +func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error { + return errors.New("not implemented") +} + +func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { + return nil, errors.New("not implemented") +} + +// NewAdapter is the constructor for Adapter. +func NewAdapter(filePath string) *Adapter { + return &Adapter{filePath: filePath} +} + +// LoadPolicy loads all policy rules from the storage. +func (a *Adapter) LoadPolicy(model model.Model) error { + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + return a.loadPolicyFile(model, persist.LoadPolicyLine) +} + +// SavePolicy saves all policy rules to the storage. 
+func (a *Adapter) SavePolicy(model model.Model) error { + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + var tmp bytes.Buffer + + for ptype, ast := range model["p"] { + for _, rule := range ast.Policy { + tmp.WriteString(ptype + ", ") + tmp.WriteString(util.ArrayToString(rule)) + tmp.WriteString("\n") + } + } + + for ptype, ast := range model["g"] { + for _, rule := range ast.Policy { + tmp.WriteString(ptype + ", ") + tmp.WriteString(util.ArrayToString(rule)) + tmp.WriteString("\n") + } + } + + return a.savePolicyFile(strings.TrimRight(tmp.String(), "\n")) +} + +func (a *Adapter) loadPolicyFile(model model.Model, handler func(string, model.Model) error) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + err = handler(line, model) + if err != nil { + return err + } + } + return scanner.Err() +} + +func (a *Adapter) savePolicyFile(text string) error { + f, err := os.Create(a.filePath) + if err != nil { + return err + } + w := bufio.NewWriter(f) + + _, err = w.WriteString(text) + if err != nil { + return err + } + + err = w.Flush() + if err != nil { + return err + } + + return f.Close() +} + +// AddPolicy adds a policy rule to the storage. +func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { + return errors.New("not implemented") +} + +// AddPolicies adds policy rules to the storage. +func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error { + return errors.New("not implemented") +} + +// RemovePolicy removes a policy rule from the storage. +func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { + return errors.New("not implemented") +} + +// RemovePolicies removes policy rules from the storage. 
+func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { + return errors.New("not implemented") +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return errors.New("not implemented") +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_filtered.go b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_filtered.go new file mode 100644 index 00000000..1a074c9a --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_filtered.go @@ -0,0 +1,156 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileadapter + +import ( + "bufio" + "errors" + "os" + "strings" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" +) + +// FilteredAdapter is the filtered file adapter for Casbin. It can load policy +// from file or save policy to file and supports loading of filtered policies. +type FilteredAdapter struct { + *Adapter + filtered bool +} + +// Filter defines the filtering rules for a FilteredAdapter's policy. Empty values +// are ignored, but all others must match the filter. 
+type Filter struct { + P []string + G []string + G1 []string + G2 []string + G3 []string + G4 []string + G5 []string +} + +// NewFilteredAdapter is the constructor for FilteredAdapter. +func NewFilteredAdapter(filePath string) *FilteredAdapter { + a := FilteredAdapter{} + a.filtered = true + a.Adapter = NewAdapter(filePath) + return &a +} + +// LoadPolicy loads all policy rules from the storage. +func (a *FilteredAdapter) LoadPolicy(model model.Model) error { + a.filtered = false + return a.Adapter.LoadPolicy(model) +} + +// LoadFilteredPolicy loads only policy rules that match the filter. +func (a *FilteredAdapter) LoadFilteredPolicy(model model.Model, filter interface{}) error { + if filter == nil { + return a.LoadPolicy(model) + } + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + filterValue, ok := filter.(*Filter) + if !ok { + return errors.New("invalid filter type") + } + err := a.loadFilteredPolicyFile(model, filterValue, persist.LoadPolicyLine) + if err == nil { + a.filtered = true + } + return err +} + +func (a *FilteredAdapter) loadFilteredPolicyFile(model model.Model, filter *Filter, handler func(string, model.Model) error) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if filterLine(line, filter) { + continue + } + + err = handler(line, model) + if err != nil { + return err + } + } + return scanner.Err() +} + +// IsFiltered returns true if the loaded policy has been filtered. +func (a *FilteredAdapter) IsFiltered() bool { + return a.filtered +} + +// SavePolicy saves all policy rules to the storage. 
+func (a *FilteredAdapter) SavePolicy(model model.Model) error { + if a.filtered { + return errors.New("cannot save a filtered policy") + } + return a.Adapter.SavePolicy(model) +} + +func filterLine(line string, filter *Filter) bool { + if filter == nil { + return false + } + p := strings.Split(line, ",") + if len(p) == 0 { + return true + } + var filterSlice []string + switch strings.TrimSpace(p[0]) { + case "p": + filterSlice = filter.P + case "g": + filterSlice = filter.G + case "g1": + filterSlice = filter.G1 + case "g2": + filterSlice = filter.G2 + case "g3": + filterSlice = filter.G3 + case "g4": + filterSlice = filter.G4 + case "g5": + filterSlice = filter.G5 + } + return filterWords(p, filterSlice) +} + +func filterWords(line []string, filter []string) bool { + if len(line) < len(filter)+1 { + return true + } + var skipLine bool + for i, v := range filter { + if len(v) > 0 && strings.TrimSpace(v) != strings.TrimSpace(line[i+1]) { + skipLine = true + break + } + } + return skipLine +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_mock.go b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_mock.go new file mode 100644 index 00000000..8f8632b8 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/file-adapter/adapter_mock.go @@ -0,0 +1,122 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileadapter + +import ( + "bufio" + "errors" + "io" + "os" + "strings" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" +) + +// AdapterMock is the file adapter for Casbin. +// It can load policy from file or save policy to file. +type AdapterMock struct { + filePath string + errorValue string +} + +// NewAdapterMock is the constructor for AdapterMock. +func NewAdapterMock(filePath string) *AdapterMock { + a := AdapterMock{} + a.filePath = filePath + return &a +} + +// LoadPolicy loads all policy rules from the storage. +func (a *AdapterMock) LoadPolicy(model model.Model) error { + err := a.loadPolicyFile(model, persist.LoadPolicyLine) + return err +} + +// SavePolicy saves all policy rules to the storage. +func (a *AdapterMock) SavePolicy(model model.Model) error { + return nil +} + +func (a *AdapterMock) loadPolicyFile(model model.Model, handler func(string, model.Model) error) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + buf := bufio.NewReader(f) + for { + line, err := buf.ReadString('\n') + line = strings.TrimSpace(line) + if err2 := handler(line, model); err2 != nil { + return err2 + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +// SetMockErr sets string to be returned by of the mock during testing +func (a *AdapterMock) SetMockErr(errorToSet string) { + a.errorValue = errorToSet +} + +// GetMockErr returns a mock error or nil +func (a *AdapterMock) GetMockErr() error { + var returnError error + if a.errorValue != "" { + returnError = errors.New(a.errorValue) + } + return returnError +} + +// AddPolicy adds a policy rule to the storage. +func (a *AdapterMock) AddPolicy(sec string, ptype string, rule []string) error { + return a.GetMockErr() +} + +// AddPolicies removes policy rules from the storage. 
+func (a *AdapterMock) AddPolicies(sec string, ptype string, rules [][]string) error { + return a.GetMockErr() +} + +// RemovePolicy removes a policy rule from the storage. +func (a *AdapterMock) RemovePolicy(sec string, ptype string, rule []string) error { + return a.GetMockErr() +} + +// RemovePolicies removes policy rules from the storage. +func (a *AdapterMock) RemovePolicies(sec string, ptype string, rules [][]string) error { + return a.GetMockErr() +} + +// UpdatePolicy removes a policy rule from the storage. +func (a *AdapterMock) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error { + return a.GetMockErr() +} + +func (a *AdapterMock) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error { + return a.GetMockErr() +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *AdapterMock) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return a.GetMockErr() +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/update_adapter.go b/vendor/github.com/casbin/casbin/v2/persist/update_adapter.go new file mode 100644 index 00000000..fe9204af --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/update_adapter.go @@ -0,0 +1,27 @@ +// Copyright 2020 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package persist + +// UpdatableAdapter is the interface for Casbin adapters with add update policy function. +type UpdatableAdapter interface { + Adapter + // UpdatePolicy updates a policy rule from storage. + // This is part of the Auto-Save feature. + UpdatePolicy(sec string, ptype string, oldRule, newRule []string) error + // UpdatePolicies updates some policy rules to storage, like db, redis. + UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error + // UpdateFilteredPolicies deletes old rules and adds new rules. + UpdateFilteredPolicies(sec string, ptype string, newRules [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/watcher.go b/vendor/github.com/casbin/casbin/v2/persist/watcher.go new file mode 100644 index 00000000..0d843606 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/watcher.go @@ -0,0 +1,29 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +// Watcher is the interface for Casbin watchers. +type Watcher interface { + // SetUpdateCallback sets the callback function that the watcher will call + // when the policy in DB has been changed by other instances. + // A classic callback is Enforcer.LoadPolicy(). + SetUpdateCallback(func(string)) error + // Update calls the update callback of other instances to synchronize their policy. 
+ // It is usually called after changing the policy in DB, like Enforcer.SavePolicy(), + // Enforcer.AddPolicy(), Enforcer.RemovePolicy(), etc. + Update() error + // Close stops and releases the watcher, the callback function will not be called any more. + Close() +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/watcher_ex.go b/vendor/github.com/casbin/casbin/v2/persist/watcher_ex.go new file mode 100644 index 00000000..1c6f4299 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/watcher_ex.go @@ -0,0 +1,40 @@ +// Copyright 2020 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import "github.com/casbin/casbin/v2/model" + +// WatcherEx is the strengthened Casbin watchers. +type WatcherEx interface { + Watcher + // UpdateForAddPolicy calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.AddPolicy() + UpdateForAddPolicy(sec, ptype string, params ...string) error + // UpdateForRemovePolicy calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.RemovePolicy() + UpdateForRemovePolicy(sec, ptype string, params ...string) error + // UpdateForRemoveFilteredPolicy calls the update callback of other instances to synchronize their policy. 
+ // It is called after Enforcer.RemoveFilteredNamedGroupingPolicy() + UpdateForRemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error + // UpdateForSavePolicy calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.RemoveFilteredNamedGroupingPolicy() + UpdateForSavePolicy(model model.Model) error + // UpdateForAddPolicies calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.AddPolicies() + UpdateForAddPolicies(sec string, ptype string, rules ...[]string) error + // UpdateForRemovePolicies calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.RemovePolicies() + UpdateForRemovePolicies(sec string, ptype string, rules ...[]string) error +} diff --git a/vendor/github.com/casbin/casbin/v2/persist/watcher_update.go b/vendor/github.com/casbin/casbin/v2/persist/watcher_update.go new file mode 100644 index 00000000..694123c4 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/persist/watcher_update.go @@ -0,0 +1,26 @@ +// Copyright 2020 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +// UpdatableWatcher is strengthened for Casbin watchers. +type UpdatableWatcher interface { + Watcher + // UpdateForUpdatePolicy calls the update callback of other instances to synchronize their policy. 
+ // It is called after Enforcer.UpdatePolicy() + UpdateForUpdatePolicy(sec string, ptype string, oldRule, newRule []string) error + // UpdateForUpdatePolicies calls the update callback of other instances to synchronize their policy. + // It is called after Enforcer.UpdatePolicies() + UpdateForUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac/default-role-manager/role_manager.go b/vendor/github.com/casbin/casbin/v2/rbac/default-role-manager/role_manager.go new file mode 100644 index 00000000..8939a486 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/rbac/default-role-manager/role_manager.go @@ -0,0 +1,701 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package defaultrolemanager + +import ( + "fmt" + "strings" + "sync" + + "github.com/casbin/casbin/v2/errors" + "github.com/casbin/casbin/v2/log" + "github.com/casbin/casbin/v2/rbac" + "github.com/casbin/casbin/v2/util" +) + +const defaultDomain string = "" + +// Role represents the data structure for a role in RBAC. 
+type Role struct { + name string + roles *sync.Map + users *sync.Map + matched *sync.Map + matchedBy *sync.Map +} + +func newRole(name string) *Role { + r := Role{} + r.name = name + r.roles = &sync.Map{} + r.users = &sync.Map{} + r.matched = &sync.Map{} + r.matchedBy = &sync.Map{} + return &r +} + +func (r *Role) addRole(role *Role) { + r.roles.Store(role.name, role) + role.addUser(r) +} + +func (r *Role) removeRole(role *Role) { + r.roles.Delete(role.name) + role.removeUser(r) +} + +//should only be called inside addRole +func (r *Role) addUser(user *Role) { + r.users.Store(user.name, user) +} + +//should only be called inside removeRole +func (r *Role) removeUser(user *Role) { + r.users.Delete(user.name) +} + +func (r *Role) addMatch(role *Role) { + r.matched.Store(role.name, role) + role.matchedBy.Store(r.name, r) +} + +func (r *Role) removeMatch(role *Role) { + r.matched.Delete(role.name) + role.matchedBy.Delete(r.name) +} + +func (r *Role) removeMatches() { + r.matched.Range(func(key, value interface{}) bool { + r.removeMatch(value.(*Role)) + return true + }) + r.matchedBy.Range(func(key, value interface{}) bool { + value.(*Role).removeMatch(r) + return true + }) +} + +func (r *Role) rangeRoles(fn func(key, value interface{}) bool) { + r.roles.Range(fn) + r.roles.Range(func(key, value interface{}) bool { + role := value.(*Role) + role.matched.Range(fn) + return true + }) + r.matchedBy.Range(func(key, value interface{}) bool { + role := value.(*Role) + role.roles.Range(fn) + return true + }) +} + +func (r *Role) rangeUsers(fn func(key, value interface{}) bool) { + r.users.Range(fn) + r.users.Range(func(key, value interface{}) bool { + role := value.(*Role) + role.matched.Range(fn) + return true + }) + r.matchedBy.Range(func(key, value interface{}) bool { + role := value.(*Role) + role.users.Range(fn) + return true + }) +} + +func (r *Role) toString() string { + roles := r.getRoles() + + if len(roles) == 0 { + return "" + } + + var sb strings.Builder + 
sb.WriteString(r.name) + sb.WriteString(" < ") + if len(roles) != 1 { + sb.WriteString("(") + } + + for i, role := range roles { + if i == 0 { + sb.WriteString(role) + } else { + sb.WriteString(", ") + sb.WriteString(role) + } + } + + if len(roles) != 1 { + sb.WriteString(")") + } + + return sb.String() +} + +func (r *Role) getRoles() []string { + var names []string + r.rangeRoles(func(key, value interface{}) bool { + names = append(names, key.(string)) + return true + }) + return util.RemoveDuplicateElement(names) +} + +func (r *Role) getUsers() []string { + var names []string + r.rangeUsers(func(key, value interface{}) bool { + names = append(names, key.(string)) + return true + }) + return names +} + +// RoleManagerImpl provides a default implementation for the RoleManager interface +type RoleManagerImpl struct { + allRoles *sync.Map + maxHierarchyLevel int + matchingFunc rbac.MatchingFunc + domainMatchingFunc rbac.MatchingFunc + logger log.Logger + matchingFuncCache *util.SyncLRUCache +} + +// NewRoleManagerImpl is the constructor for creating an instance of the +// default RoleManager implementation. +func NewRoleManagerImpl(maxHierarchyLevel int) *RoleManagerImpl { + rm := RoleManagerImpl{} + _ = rm.Clear() //init allRoles and matchingFuncCache + rm.maxHierarchyLevel = maxHierarchyLevel + rm.SetLogger(&log.DefaultLogger{}) + return &rm +} + +// use this constructor to avoid rebuild of AddMatchingFunc +func newRoleManagerWithMatchingFunc(maxHierarchyLevel int, fn rbac.MatchingFunc) *RoleManagerImpl { + rm := NewRoleManagerImpl(maxHierarchyLevel) + rm.matchingFunc = fn + return rm +} + +// rebuilds role cache +func (rm *RoleManagerImpl) rebuild() { + roles := rm.allRoles + _ = rm.Clear() + rangeLinks(roles, func(name1, name2 string, domain ...string) bool { + _ = rm.AddLink(name1, name2, domain...) 
+ return true + }) +} + +func (rm *RoleManagerImpl) Match(str string, pattern string) bool { + cacheKey := strings.Join([]string{str, pattern}, "$$") + if v, has := rm.matchingFuncCache.Get(cacheKey); has { + return v.(bool) + } else { + var matched bool + if rm.matchingFunc != nil { + matched = rm.matchingFunc(str, pattern) + } else { + matched = str == pattern + } + rm.matchingFuncCache.Put(cacheKey, matched) + return matched + } +} + +func (rm *RoleManagerImpl) rangeMatchingRoles(name string, isPattern bool, fn func(role *Role) bool) { + rm.allRoles.Range(func(key, value interface{}) bool { + name2 := key.(string) + if isPattern && name != name2 && rm.Match(name2, name) { + fn(value.(*Role)) + } else if !isPattern && name != name2 && rm.Match(name, name2) { + fn(value.(*Role)) + } + return true + }) +} + +func (rm *RoleManagerImpl) load(name interface{}) (value *Role, ok bool) { + if r, ok := rm.allRoles.Load(name); ok { + return r.(*Role), true + } + return nil, false +} + +// loads or creates a role +func (rm *RoleManagerImpl) getRole(name string) (r *Role, created bool) { + var role *Role + var ok bool + + if role, ok = rm.load(name); !ok { + role = newRole(name) + rm.allRoles.Store(name, role) + + if rm.matchingFunc != nil { + rm.rangeMatchingRoles(name, false, func(r *Role) bool { + r.addMatch(role) + return true + }) + + rm.rangeMatchingRoles(name, true, func(r *Role) bool { + role.addMatch(r) + return true + }) + } + } + + return role, !ok +} + +func loadAndDelete(m *sync.Map, name string) (value interface{}, loaded bool) { + value, loaded = m.Load(name) + if loaded { + m.Delete(name) + } + return value, loaded +} + +func (rm *RoleManagerImpl) removeRole(name string) { + if role, ok := loadAndDelete(rm.allRoles, name); ok { + role.(*Role).removeMatches() + } +} + +// AddMatchingFunc support use pattern in g +func (rm *RoleManagerImpl) AddMatchingFunc(name string, fn rbac.MatchingFunc) { + rm.matchingFunc = fn + rm.rebuild() +} + +// AddDomainMatchingFunc 
support use domain pattern in g +func (rm *RoleManagerImpl) AddDomainMatchingFunc(name string, fn rbac.MatchingFunc) { + rm.domainMatchingFunc = fn +} + +// SetLogger sets role manager's logger. +func (rm *RoleManagerImpl) SetLogger(logger log.Logger) { + rm.logger = logger +} + +// Clear clears all stored data and resets the role manager to the initial state. +func (rm *RoleManagerImpl) Clear() error { + rm.matchingFuncCache = util.NewSyncLRUCache(100) + rm.allRoles = &sync.Map{} + return nil +} + +// AddLink adds the inheritance link between role: name1 and role: name2. +// aka role: name1 inherits role: name2. +func (rm *RoleManagerImpl) AddLink(name1 string, name2 string, domains ...string) error { + user, _ := rm.getRole(name1) + role, _ := rm.getRole(name2) + user.addRole(role) + return nil +} + +// DeleteLink deletes the inheritance link between role: name1 and role: name2. +// aka role: name1 does not inherit role: name2 any more. +func (rm *RoleManagerImpl) DeleteLink(name1 string, name2 string, domains ...string) error { + user, _ := rm.getRole(name1) + role, _ := rm.getRole(name2) + user.removeRole(role) + return nil +} + +// HasLink determines whether role: name1 inherits role: name2. 
+func (rm *RoleManagerImpl) HasLink(name1 string, name2 string, domains ...string) (bool, error) { + if name1 == name2 || (rm.matchingFunc != nil && rm.Match(name1, name2)) { + return true, nil + } + + user, userCreated := rm.getRole(name1) + role, roleCreated := rm.getRole(name2) + + if userCreated { + defer rm.removeRole(user.name) + } + if roleCreated { + defer rm.removeRole(role.name) + } + + return rm.hasLinkHelper(role.name, map[string]*Role{user.name: user}, rm.maxHierarchyLevel), nil +} + +func (rm *RoleManagerImpl) hasLinkHelper(targetName string, roles map[string]*Role, level int) bool { + if level < 0 || len(roles) == 0 { + return false + } + + nextRoles := map[string]*Role{} + for _, role := range roles { + if targetName == role.name || (rm.matchingFunc != nil && rm.Match(role.name, targetName)) { + return true + } + role.rangeRoles(func(key, value interface{}) bool { + nextRoles[key.(string)] = value.(*Role) + return true + }) + } + + return rm.hasLinkHelper(targetName, nextRoles, level-1) +} + +// GetRoles gets the roles that a user inherits. +func (rm *RoleManagerImpl) GetRoles(name string, domains ...string) ([]string, error) { + user, created := rm.getRole(name) + if created { + defer rm.removeRole(user.name) + } + return user.getRoles(), nil +} + +// GetUsers gets the users of a role. +// domain is an unreferenced parameter here, may be used in other implementations. +func (rm *RoleManagerImpl) GetUsers(name string, domain ...string) ([]string, error) { + role, created := rm.getRole(name) + if created { + defer rm.removeRole(role.name) + } + return role.getUsers(), nil +} + +func (rm *RoleManagerImpl) toString() []string { + var roles []string + + rm.allRoles.Range(func(key, value interface{}) bool { + role := value.(*Role) + if text := role.toString(); text != "" { + roles = append(roles, text) + } + return true + }) + + return roles +} + +// PrintRoles prints all the roles to log. 
+func (rm *RoleManagerImpl) PrintRoles() error { + if !(rm.logger).IsEnabled() { + return nil + } + roles := rm.toString() + rm.logger.LogRole(roles) + return nil +} + +// GetDomains gets domains that a user has +func (rm *RoleManagerImpl) GetDomains(name string) ([]string, error) { + domains := []string{defaultDomain} + return domains, nil +} + +// GetAllDomains gets all domains +func (rm *RoleManagerImpl) GetAllDomains() ([]string, error) { + domains := []string{defaultDomain} + return domains, nil +} + +func (rm *RoleManagerImpl) copyFrom(other *RoleManagerImpl) { + other.Range(func(name1, name2 string, domain ...string) bool { + _ = rm.AddLink(name1, name2, domain...) + return true + }) +} + +func rangeLinks(users *sync.Map, fn func(name1, name2 string, domain ...string) bool) { + users.Range(func(_, value interface{}) bool { + user := value.(*Role) + user.roles.Range(func(key, _ interface{}) bool { + roleName := key.(string) + return fn(user.name, roleName, defaultDomain) + }) + return true + }) +} + +func (rm *RoleManagerImpl) Range(fn func(name1, name2 string, domain ...string) bool) { + rangeLinks(rm.allRoles, fn) +} + +// Deprecated: BuildRelationship is no longer required +func (rm *RoleManagerImpl) BuildRelationship(name1 string, name2 string, domain ...string) error { + return nil +} + +type DomainManager struct { + rmMap *sync.Map + maxHierarchyLevel int + matchingFunc rbac.MatchingFunc + domainMatchingFunc rbac.MatchingFunc + logger log.Logger + matchingFuncCache *util.SyncLRUCache +} + +// NewDomainManager is the constructor for creating an instance of the +// default DomainManager implementation. +func NewDomainManager(maxHierarchyLevel int) *DomainManager { + dm := &DomainManager{} + _ = dm.Clear() // init rmMap and rmCache + dm.maxHierarchyLevel = maxHierarchyLevel + return dm +} + +// SetLogger sets role manager's logger. 
+func (dm *DomainManager) SetLogger(logger log.Logger) { + dm.logger = logger +} + +// AddMatchingFunc support use pattern in g +func (dm *DomainManager) AddMatchingFunc(name string, fn rbac.MatchingFunc) { + dm.matchingFunc = fn + dm.rmMap.Range(func(key, value interface{}) bool { + value.(*RoleManagerImpl).AddMatchingFunc(name, fn) + return true + }) +} + +// AddDomainMatchingFunc support use domain pattern in g +func (dm *DomainManager) AddDomainMatchingFunc(name string, fn rbac.MatchingFunc) { + dm.domainMatchingFunc = fn + dm.rmMap.Range(func(key, value interface{}) bool { + value.(*RoleManagerImpl).AddDomainMatchingFunc(name, fn) + return true + }) + dm.rebuild() +} + +// clears the map of RoleManagers +func (dm *DomainManager) rebuild() { + rmMap := dm.rmMap + _ = dm.Clear() + rmMap.Range(func(key, value interface{}) bool { + domain := key.(string) + rm := value.(*RoleManagerImpl) + + rm.Range(func(name1, name2 string, _ ...string) bool { + _ = dm.AddLink(name1, name2, domain) + return true + }) + return true + }) +} + +//Clear clears all stored data and resets the role manager to the initial state. 
+func (dm *DomainManager) Clear() error { + dm.rmMap = &sync.Map{} + dm.matchingFuncCache = util.NewSyncLRUCache(100) + return nil +} + +func (dm *DomainManager) getDomain(domains ...string) (domain string, err error) { + switch len(domains) { + case 0: + return defaultDomain, nil + case 1: + return domains[0], nil + default: + return "", errors.ERR_DOMAIN_PARAMETER + } +} + +func (dm *DomainManager) Match(str string, pattern string) bool { + cacheKey := strings.Join([]string{str, pattern}, "$$") + if v, has := dm.matchingFuncCache.Get(cacheKey); has { + return v.(bool) + } else { + var matched bool + if dm.domainMatchingFunc != nil { + matched = dm.domainMatchingFunc(str, pattern) + } else { + matched = str == pattern + } + dm.matchingFuncCache.Put(cacheKey, matched) + return matched + } +} + +func (dm *DomainManager) rangeAffectedRoleManagers(domain string, fn func(rm *RoleManagerImpl)) { + if dm.domainMatchingFunc != nil { + dm.rmMap.Range(func(key, value interface{}) bool { + domain2 := key.(string) + if domain != domain2 && dm.Match(domain2, domain) { + fn(value.(*RoleManagerImpl)) + } + return true + }) + } +} + +func (dm *DomainManager) load(name interface{}) (value *RoleManagerImpl, ok bool) { + if r, ok := dm.rmMap.Load(name); ok { + return r.(*RoleManagerImpl), true + } + return nil, false +} + +// load or create a RoleManager instance of domain +func (dm *DomainManager) getRoleManager(domain string, store bool) *RoleManagerImpl { + var rm *RoleManagerImpl + var ok bool + + if rm, ok = dm.load(domain); !ok { + rm = newRoleManagerWithMatchingFunc(dm.maxHierarchyLevel, dm.matchingFunc) + if store { + dm.rmMap.Store(domain, rm) + } + if dm.domainMatchingFunc != nil { + dm.rmMap.Range(func(key, value interface{}) bool { + domain2 := key.(string) + rm2 := value.(*RoleManagerImpl) + if domain != domain2 && dm.Match(domain, domain2) { + rm.copyFrom(rm2) + } + return true + }) + } + } + return rm +} + +// AddLink adds the inheritance link between role: name1 and 
role: name2. +// aka role: name1 inherits role: name2. +func (dm *DomainManager) AddLink(name1 string, name2 string, domains ...string) error { + domain, err := dm.getDomain(domains...) + if err != nil { + return err + } + roleManager := dm.getRoleManager(domain, true) //create role manager if it does not exist + _ = roleManager.AddLink(name1, name2, domains...) + + dm.rangeAffectedRoleManagers(domain, func(rm *RoleManagerImpl) { + _ = rm.AddLink(name1, name2, domains...) + }) + return nil +} + +// DeleteLink deletes the inheritance link between role: name1 and role: name2. +// aka role: name1 does not inherit role: name2 any more. +func (dm *DomainManager) DeleteLink(name1 string, name2 string, domains ...string) error { + domain, err := dm.getDomain(domains...) + if err != nil { + return err + } + roleManager := dm.getRoleManager(domain, true) //create role manager if it does not exist + _ = roleManager.DeleteLink(name1, name2, domains...) + + dm.rangeAffectedRoleManagers(domain, func(rm *RoleManagerImpl) { + _ = rm.DeleteLink(name1, name2, domains...) + }) + return nil +} + +// HasLink determines whether role: name1 inherits role: name2. +func (dm *DomainManager) HasLink(name1 string, name2 string, domains ...string) (bool, error) { + domain, err := dm.getDomain(domains...) + if err != nil { + return false, err + } + rm := dm.getRoleManager(domain, false) + return rm.HasLink(name1, name2, domains...) +} + +// GetRoles gets the roles that a subject inherits. +func (dm *DomainManager) GetRoles(name string, domains ...string) ([]string, error) { + domain, err := dm.getDomain(domains...) + if err != nil { + return nil, err + } + rm := dm.getRoleManager(domain, false) + return rm.GetRoles(name, domains...) +} + +// GetUsers gets the users of a role. +func (dm *DomainManager) GetUsers(name string, domains ...string) ([]string, error) { + domain, err := dm.getDomain(domains...) 
+ if err != nil { + return nil, err + } + rm := dm.getRoleManager(domain, false) + return rm.GetUsers(name, domains...) +} + +func (dm *DomainManager) toString() []string { + var roles []string + + dm.rmMap.Range(func(key, value interface{}) bool { + domain := key.(string) + rm := value.(*RoleManagerImpl) + domainRoles := rm.toString() + roles = append(roles, fmt.Sprintf("%s: %s", domain, strings.Join(domainRoles, ", "))) + return true + }) + + return roles +} + +// PrintRoles prints all the roles to log. +func (dm *DomainManager) PrintRoles() error { + if !(dm.logger).IsEnabled() { + return nil + } + + roles := dm.toString() + dm.logger.LogRole(roles) + return nil +} + +// GetDomains gets domains that a user has +func (dm *DomainManager) GetDomains(name string) ([]string, error) { + var domains []string + dm.rmMap.Range(func(key, value interface{}) bool { + domain := key.(string) + rm := value.(*RoleManagerImpl) + role, created := rm.getRole(name) + if created { + defer rm.removeRole(role.name) + } + if len(role.getUsers()) > 0 || len(role.getRoles()) > 0 { + domains = append(domains, domain) + } + return true + }) + return domains, nil +} + +// GetAllDomains gets all domains +func (rm *DomainManager) GetAllDomains() ([]string, error) { + var domains []string + rm.rmMap.Range(func(key, value interface{}) bool { + domains = append(domains, key.(string)) + return true + }) + return domains, nil +} + +// Deprecated: BuildRelationship is no longer required +func (rm *DomainManager) BuildRelationship(name1 string, name2 string, domain ...string) error { + return nil +} + +type RoleManager struct { + *DomainManager +} + +func NewRoleManager(maxHierarchyLevel int) *RoleManager { + rm := &RoleManager{} + rm.DomainManager = NewDomainManager(maxHierarchyLevel) + return rm +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac/role_manager.go b/vendor/github.com/casbin/casbin/v2/rbac/role_manager.go new file mode 100644 index 00000000..18b4ca8f --- /dev/null +++ 
b/vendor/github.com/casbin/casbin/v2/rbac/role_manager.go @@ -0,0 +1,91 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rbac + +import ( + "context" + + "github.com/casbin/casbin/v2/log" +) + +type MatchingFunc func(arg1 string, arg2 string) bool + +// RoleManager provides interface to define the operations for managing roles. +type RoleManager interface { + // Clear clears all stored data and resets the role manager to the initial state. + Clear() error + // AddLink adds the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + AddLink(name1 string, name2 string, domain ...string) error + // Deprecated: BuildRelationship is no longer required + BuildRelationship(name1 string, name2 string, domain ...string) error + // DeleteLink deletes the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + DeleteLink(name1 string, name2 string, domain ...string) error + // HasLink determines whether a link exists between two roles. role: name1 inherits role: name2. + // domain is a prefix to the roles (can be used for other purposes). + HasLink(name1 string, name2 string, domain ...string) (bool, error) + // GetRoles gets the roles that a user inherits. + // domain is a prefix to the roles (can be used for other purposes). 
+ GetRoles(name string, domain ...string) ([]string, error) + // GetUsers gets the users that inherits a role. + // domain is a prefix to the users (can be used for other purposes). + GetUsers(name string, domain ...string) ([]string, error) + // GetDomains gets domains that a user has + GetDomains(name string) ([]string, error) + // GetAllDomains gets all domains + GetAllDomains() ([]string, error) + // PrintRoles prints all the roles to log. + PrintRoles() error + // SetLogger sets role manager's logger. + SetLogger(logger log.Logger) + // Match matches the domain with the pattern + Match(str string, pattern string) bool + // AddMatchingFunc adds the matching function + AddMatchingFunc(name string, fn MatchingFunc) + // AddDomainMatchingFunc adds the domain matching function + AddDomainMatchingFunc(name string, fn MatchingFunc) +} + +// RoleManagerWithContext provides a context-aware interface to define the operations for managing roles. +// Prefer this over RoleManager interface for context propagation, which is useful for things like handling +// request timeouts. +type RoleManagerWithContext interface { + // Clear clears all stored data and resets the role manager to the initial state. + Clear(ctx context.Context) error + // AddLink adds the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + AddLink(ctx context.Context, name1 string, name2 string, domain ...string) error + // DeleteLink deletes the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + DeleteLink(ctx context.Context, name1 string, name2 string, domain ...string) error + // HasLink determines whether a link exists between two roles. role: name1 inherits role: name2. + // domain is a prefix to the roles (can be used for other purposes). 
+ HasLink(ctx context.Context, name1 string, name2 string, domain ...string) (bool, error) + // GetRoles gets the roles that a user inherits. + // domain is a prefix to the roles (can be used for other purposes). + GetRoles(ctx context.Context, name string, domain ...string) ([]string, error) + // GetUsers gets the users that inherits a role. + // domain is a prefix to the users (can be used for other purposes). + GetUsers(ctx context.Context, name string, domain ...string) ([]string, error) + // GetDomains gets domains that a user has + GetDomains(ctx context.Context, name string) ([]string, error) + // GetAllDomains gets all domains + GetAllDomains(ctx context.Context) ([]string, error) + // PrintRoles prints all the roles to log. + PrintRoles() error + // SetLogger sets role manager's logger. + SetLogger(logger log.Logger) +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac_api.go b/vendor/github.com/casbin/casbin/v2/rbac_api.go new file mode 100644 index 00000000..7a44311a --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/rbac_api.go @@ -0,0 +1,416 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "github.com/casbin/casbin/v2/constant" + "github.com/casbin/casbin/v2/errors" + "github.com/casbin/casbin/v2/util" +) + +// GetRolesForUser gets the roles that a user has. 
+func (e *Enforcer) GetRolesForUser(name string, domain ...string) ([]string, error) { + res, err := e.model["g"]["g"].RM.GetRoles(name, domain...) + return res, err +} + +// GetUsersForRole gets the users that has a role. +func (e *Enforcer) GetUsersForRole(name string, domain ...string) ([]string, error) { + res, err := e.model["g"]["g"].RM.GetUsers(name, domain...) + return res, err +} + +// HasRoleForUser determines whether a user has a role. +func (e *Enforcer) HasRoleForUser(name string, role string, domain ...string) (bool, error) { + roles, err := e.GetRolesForUser(name, domain...) + if err != nil { + return false, err + } + hasRole := false + for _, r := range roles { + if r == role { + hasRole = true + break + } + } + + return hasRole, nil +} + +// AddRoleForUser adds a role for a user. +// Returns false if the user already has the role (aka not affected). +func (e *Enforcer) AddRoleForUser(user string, role string, domain ...string) (bool, error) { + args := []string{user, role} + args = append(args, domain...) + return e.AddGroupingPolicy(args) +} + +// AddRolesForUser adds roles for a user. +// Returns false if the user already has the roles (aka not affected). +func (e *Enforcer) AddRolesForUser(user string, roles []string, domain ...string) (bool, error) { + var rules [][]string + for _, role := range roles { + rule := []string{user, role} + rule = append(rule, domain...) + rules = append(rules, rule) + } + return e.AddGroupingPolicies(rules) +} + +// DeleteRoleForUser deletes a role for a user. +// Returns false if the user does not have the role (aka not affected). +func (e *Enforcer) DeleteRoleForUser(user string, role string, domain ...string) (bool, error) { + args := []string{user, role} + args = append(args, domain...) + return e.RemoveGroupingPolicy(args) +} + +// DeleteRolesForUser deletes all roles for a user. +// Returns false if the user does not have any roles (aka not affected). 
+func (e *Enforcer) DeleteRolesForUser(user string, domain ...string) (bool, error) { + var args []string + if len(domain) == 0 { + args = []string{user} + } else if len(domain) > 1 { + return false, errors.ERR_DOMAIN_PARAMETER + } else { + args = []string{user, "", domain[0]} + } + return e.RemoveFilteredGroupingPolicy(0, args...) +} + +// DeleteUser deletes a user. +// Returns false if the user does not exist (aka not affected). +func (e *Enforcer) DeleteUser(user string) (bool, error) { + var err error + res1, err := e.RemoveFilteredGroupingPolicy(0, user) + if err != nil { + return res1, err + } + + subIndex, err := e.GetFieldIndex("p", constant.SubjectIndex) + if err != nil { + return false, err + } + res2, err := e.RemoveFilteredPolicy(subIndex, user) + return res1 || res2, err +} + +// DeleteRole deletes a role. +// Returns false if the role does not exist (aka not affected). +func (e *Enforcer) DeleteRole(role string) (bool, error) { + var err error + res1, err := e.RemoveFilteredGroupingPolicy(1, role) + if err != nil { + return res1, err + } + + subIndex, err := e.GetFieldIndex("p", constant.SubjectIndex) + if err != nil { + return false, err + } + res2, err := e.RemoveFilteredPolicy(subIndex, role) + return res1 || res2, err +} + +// DeletePermission deletes a permission. +// Returns false if the permission does not exist (aka not affected). +func (e *Enforcer) DeletePermission(permission ...string) (bool, error) { + return e.RemoveFilteredPolicy(1, permission...) +} + +// AddPermissionForUser adds a permission for a user or role. +// Returns false if the user or role already has the permission (aka not affected). +func (e *Enforcer) AddPermissionForUser(user string, permission ...string) (bool, error) { + return e.AddPolicy(util.JoinSlice(user, permission...)) +} + +// AddPermissionsForUser adds multiple permissions for a user or role. +// Returns false if the user or role already has one of the permissions (aka not affected). 
+func (e *Enforcer) AddPermissionsForUser(user string, permissions ...[]string) (bool, error) { + var rules [][]string + for _, permission := range permissions { + rules = append(rules, util.JoinSlice(user, permission...)) + } + return e.AddPolicies(rules) +} + +// DeletePermissionForUser deletes a permission for a user or role. +// Returns false if the user or role does not have the permission (aka not affected). +func (e *Enforcer) DeletePermissionForUser(user string, permission ...string) (bool, error) { + return e.RemovePolicy(util.JoinSlice(user, permission...)) +} + +// DeletePermissionsForUser deletes permissions for a user or role. +// Returns false if the user or role does not have any permissions (aka not affected). +func (e *Enforcer) DeletePermissionsForUser(user string) (bool, error) { + subIndex, err := e.GetFieldIndex("p", constant.SubjectIndex) + if err != nil { + return false, err + } + return e.RemoveFilteredPolicy(subIndex, user) +} + +// GetPermissionsForUser gets permissions for a user or role. +func (e *Enforcer) GetPermissionsForUser(user string, domain ...string) [][]string { + return e.GetNamedPermissionsForUser("p", user, domain...) +} + +// GetNamedPermissionsForUser gets permissions for a user or role by named policy. +func (e *Enforcer) GetNamedPermissionsForUser(ptype string, user string, domain ...string) [][]string { + permission := make([][]string, 0) + for pType, assertion := range e.model["p"] { + if pType != ptype { + continue + } + args := make([]string, len(assertion.Tokens)) + subIndex, err := e.GetFieldIndex("p", constant.SubjectIndex) + if err != nil { + subIndex = 0 + } + args[subIndex] = user + + if len(domain) > 0 { + index, err := e.GetFieldIndex(ptype, constant.DomainIndex) + if err != nil { + return permission + } + args[index] = domain[0] + } + perm := e.GetFilteredNamedPolicy(ptype, 0, args...) + permission = append(permission, perm...) 
+ } + return permission +} + +// HasPermissionForUser determines whether a user has a permission. +func (e *Enforcer) HasPermissionForUser(user string, permission ...string) bool { + return e.HasPolicy(util.JoinSlice(user, permission...)) +} + +// GetImplicitRolesForUser gets implicit roles that a user has. +// Compared to GetRolesForUser(), this function retrieves indirect roles besides direct roles. +// For example: +// g, alice, role:admin +// g, role:admin, role:user +// +// GetRolesForUser("alice") can only get: ["role:admin"]. +// But GetImplicitRolesForUser("alice") will get: ["role:admin", "role:user"]. +func (e *Enforcer) GetImplicitRolesForUser(name string, domain ...string) ([]string, error) { + res := []string{} + + for _, rm := range e.rmMap { + + roleSet := make(map[string]bool) + roleSet[name] = true + q := make([]string, 0) + q = append(q, name) + + for len(q) > 0 { + name := q[0] + q = q[1:] + + roles, err := rm.GetRoles(name, domain...) + if err != nil { + return nil, err + } + for _, r := range roles { + if _, ok := roleSet[r]; !ok { + res = append(res, r) + q = append(q, r) + roleSet[r] = true + } + } + } + } + + return res, nil +} + +// GetImplicitUsersForRole gets implicit users for a role. +func (e *Enforcer) GetImplicitUsersForRole(name string, domain ...string) ([]string, error) { + res := []string{} + + for _, rm := range e.rmMap { + + roleSet := make(map[string]bool) + roleSet[name] = true + q := make([]string, 0) + q = append(q, name) + + for len(q) > 0 { + name := q[0] + q = q[1:] + + roles, err := rm.GetUsers(name, domain...) + if err != nil && err.Error() != "error: name does not exist" { + return nil, err + } + for _, r := range roles { + if _, ok := roleSet[r]; !ok { + res = append(res, r) + q = append(q, r) + roleSet[r] = true + } + } + } + } + + return res, nil +} + +// GetImplicitPermissionsForUser gets implicit permissions for a user or role. 
+// Compared to GetPermissionsForUser(), this function retrieves permissions for inherited roles. +// For example: +// p, admin, data1, read +// p, alice, data2, read +// g, alice, admin +// +// GetPermissionsForUser("alice") can only get: [["alice", "data2", "read"]]. +// But GetImplicitPermissionsForUser("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]]. +func (e *Enforcer) GetImplicitPermissionsForUser(user string, domain ...string) ([][]string, error) { + return e.GetNamedImplicitPermissionsForUser("p", user, domain...) +} + +// GetNamedImplicitPermissionsForUser gets implicit permissions for a user or role by named policy. +// Compared to GetNamedPermissionsForUser(), this function retrieves permissions for inherited roles. +// For example: +// p, admin, data1, read +// p2, admin, create +// g, alice, admin +// +// GetImplicitPermissionsForUser("alice") can only get: [["admin", "data1", "read"]], whose policy is default policy "p" +// But you can specify the named policy "p2" to get: [["admin", "create"]] by GetNamedImplicitPermissionsForUser("p2","alice") +func (e *Enforcer) GetNamedImplicitPermissionsForUser(ptype string, user string, domain ...string) ([][]string, error) { + permission := make([][]string, 0) + rm := e.GetRoleManager() + domainIndex, _ := e.GetFieldIndex(ptype, constant.DomainIndex) + for _, rule := range e.model["p"][ptype].Policy { + if len(domain) == 0 { + matched, _ := rm.HasLink(user, rule[0]) + if matched { + permission = append(permission, deepCopyPolicy(rule)) + } + } else if len(domain) > 1 { + return nil, errors.ERR_DOMAIN_PARAMETER + } else { + d := domain[0] + matched := rm.Match(d, rule[domainIndex]) + if !matched { + continue + } + matched, _ = rm.HasLink(user, rule[0], d) + if matched { + newRule := deepCopyPolicy(rule) + newRule[domainIndex] = d + permission = append(permission, newRule) + } + } + } + return permission, nil +} + +// GetImplicitUsersForPermission gets implicit users for a permission. 
+// For example: +// p, admin, data1, read +// p, bob, data1, read +// g, alice, admin +// +// GetImplicitUsersForPermission("data1", "read") will get: ["alice", "bob"]. +// Note: only users will be returned, roles (2nd arg in "g") will be excluded. +func (e *Enforcer) GetImplicitUsersForPermission(permission ...string) ([]string, error) { + pSubjects := e.GetAllSubjects() + gInherit := e.model.GetValuesForFieldInPolicyAllTypes("g", 1) + gSubjects := e.model.GetValuesForFieldInPolicyAllTypes("g", 0) + + subjects := append(pSubjects, gSubjects...) + util.ArrayRemoveDuplicates(&subjects) + + subjects = util.SetSubtract(subjects, gInherit) + + res := []string{} + for _, user := range subjects { + req := util.JoinSliceAny(user, permission...) + allowed, err := e.Enforce(req...) + if err != nil { + return nil, err + } + + if allowed { + res = append(res, user) + } + } + + return res, nil +} + +// GetDomainsForUser gets all domains +func (e *Enforcer) GetDomainsForUser(user string) ([]string, error) { + var domains []string + for _, rm := range e.rmMap { + domain, err := rm.GetDomains(user) + if err != nil { + return nil, err + } + domains = append(domains, domain...) + } + return domains, nil +} + +// GetImplicitResourcesForUser returns all policies that user obtaining in domain +func (e *Enforcer) GetImplicitResourcesForUser(user string, domain ...string) ([][]string, error) { + permissions, err := e.GetImplicitPermissionsForUser(user, domain...) + if err != nil { + return nil, err + } + res := make([][]string, 0) + for _, permission := range permissions { + if permission[0] == user { + res = append(res, permission) + continue + } + resLocal := [][]string{{user}} + tokensLength := len(permission) + t := make([][]string, 1, tokensLength) + for _, token := range permission[1:] { + tokens, err := e.GetImplicitUsersForRole(token, domain...) 
+ if err != nil { + return nil, err + } + tokens = append(tokens, token) + t = append(t, tokens) + } + for i := 1; i < tokensLength; i++ { + n := make([][]string, 0) + for _, tokens := range t[i] { + for _, policy := range resLocal { + t := append([]string(nil), policy...) + t = append(t, tokens) + n = append(n, t) + } + } + resLocal = n + } + res = append(res, resLocal...) + } + return res, nil +} + +// deepCopyPolicy returns a deepcopy version of the policy to prevent changing policies through returned slice +func deepCopyPolicy(src []string) []string { + newRule := make([]string, len(src)) + copy(newRule, src) + return newRule +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac_api_synced.go b/vendor/github.com/casbin/casbin/v2/rbac_api_synced.go new file mode 100644 index 00000000..22df8bf6 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/rbac_api_synced.go @@ -0,0 +1,195 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetRolesForUser gets the roles that a user has. +func (e *SyncedEnforcer) GetRolesForUser(name string, domain ...string) ([]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetRolesForUser(name, domain...) +} + +// GetUsersForRole gets the users that has a role. 
+func (e *SyncedEnforcer) GetUsersForRole(name string, domain ...string) ([]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetUsersForRole(name, domain...) +} + +// HasRoleForUser determines whether a user has a role. +func (e *SyncedEnforcer) HasRoleForUser(name string, role string, domain ...string) (bool, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.HasRoleForUser(name, role, domain...) +} + +// AddRoleForUser adds a role for a user. +// Returns false if the user already has the role (aka not affected). +func (e *SyncedEnforcer) AddRoleForUser(user string, role string, domain ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddRoleForUser(user, role, domain...) +} + +// AddRolesForUser adds roles for a user. +// Returns false if the user already has the roles (aka not affected). +func (e *SyncedEnforcer) AddRolesForUser(user string, roles []string, domain ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddRolesForUser(user, roles, domain...) +} + +// DeleteRoleForUser deletes a role for a user. +// Returns false if the user does not have the role (aka not affected). +func (e *SyncedEnforcer) DeleteRoleForUser(user string, role string, domain ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRoleForUser(user, role, domain...) +} + +// DeleteRolesForUser deletes all roles for a user. +// Returns false if the user does not have any roles (aka not affected). +func (e *SyncedEnforcer) DeleteRolesForUser(user string, domain ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRolesForUser(user, domain...) +} + +// DeleteUser deletes a user. +// Returns false if the user does not exist (aka not affected). +func (e *SyncedEnforcer) DeleteUser(user string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteUser(user) +} + +// DeleteRole deletes a role. 
+// Returns false if the role does not exist (aka not affected). +func (e *SyncedEnforcer) DeleteRole(role string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRole(role) +} + +// DeletePermission deletes a permission. +// Returns false if the permission does not exist (aka not affected). +func (e *SyncedEnforcer) DeletePermission(permission ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermission(permission...) +} + +// AddPermissionForUser adds a permission for a user or role. +// Returns false if the user or role already has the permission (aka not affected). +func (e *SyncedEnforcer) AddPermissionForUser(user string, permission ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddPermissionForUser(user, permission...) +} + +// DeletePermissionForUser deletes a permission for a user or role. +// Returns false if the user or role does not have the permission (aka not affected). +func (e *SyncedEnforcer) DeletePermissionForUser(user string, permission ...string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermissionForUser(user, permission...) +} + +// DeletePermissionsForUser deletes permissions for a user or role. +// Returns false if the user or role does not have any permissions (aka not affected). +func (e *SyncedEnforcer) DeletePermissionsForUser(user string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermissionsForUser(user) +} + +// GetPermissionsForUser gets permissions for a user or role. +func (e *SyncedEnforcer) GetPermissionsForUser(user string, domain ...string) [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetPermissionsForUser(user, domain...) +} + +// GetNamedPermissionsForUser gets permissions for a user or role by named policy. 
+func (e *SyncedEnforcer) GetNamedPermissionsForUser(ptype string, user string, domain ...string) [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetNamedPermissionsForUser(ptype, user, domain...) +} + +// HasPermissionForUser determines whether a user has a permission. +func (e *SyncedEnforcer) HasPermissionForUser(user string, permission ...string) bool { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.HasPermissionForUser(user, permission...) +} + +// GetImplicitRolesForUser gets implicit roles that a user has. +// Compared to GetRolesForUser(), this function retrieves indirect roles besides direct roles. +// For example: +// g, alice, role:admin +// g, role:admin, role:user +// +// GetRolesForUser("alice") can only get: ["role:admin"]. +// But GetImplicitRolesForUser("alice") will get: ["role:admin", "role:user"]. +func (e *SyncedEnforcer) GetImplicitRolesForUser(name string, domain ...string) ([]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetImplicitRolesForUser(name, domain...) +} + +// GetImplicitPermissionsForUser gets implicit permissions for a user or role. +// Compared to GetPermissionsForUser(), this function retrieves permissions for inherited roles. +// For example: +// p, admin, data1, read +// p, alice, data2, read +// g, alice, admin +// +// GetPermissionsForUser("alice") can only get: [["alice", "data2", "read"]]. +// But GetImplicitPermissionsForUser("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]]. +func (e *SyncedEnforcer) GetImplicitPermissionsForUser(user string, domain ...string) ([][]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetImplicitPermissionsForUser(user, domain...) +} + +// GetNamedImplicitPermissionsForUser gets implicit permissions for a user or role by named policy. +// Compared to GetNamedPermissionsForUser(), this function retrieves permissions for inherited roles. 
+// For example: +// p, admin, data1, read +// p2, admin, create +// g, alice, admin +// +// GetImplicitPermissionsForUser("alice") can only get: [["admin", "data1", "read"]], whose policy is default policy "p" +// But you can specify the named policy "p2" to get: [["admin", "create"]] by GetNamedImplicitPermissionsForUser("p2","alice") +func (e *SyncedEnforcer) GetNamedImplicitPermissionsForUser(ptype string, user string, domain ...string) ([][]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetNamedImplicitPermissionsForUser(ptype, user, domain...) +} + +// GetImplicitUsersForPermission gets implicit users for a permission. +// For example: +// p, admin, data1, read +// p, bob, data1, read +// g, alice, admin +// +// GetImplicitUsersForPermission("data1", "read") will get: ["alice", "bob"]. +// Note: only users will be returned, roles (2nd arg in "g") will be excluded. +func (e *SyncedEnforcer) GetImplicitUsersForPermission(permission ...string) ([]string, error) { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetImplicitUsersForPermission(permission...) +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains.go b/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains.go new file mode 100644 index 00000000..04944f3b --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains.go @@ -0,0 +1,146 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import "github.com/casbin/casbin/v2/constant" + +// GetUsersForRoleInDomain gets the users that has a role inside a domain. Add by Gordon +func (e *Enforcer) GetUsersForRoleInDomain(name string, domain string) []string { + res, _ := e.model["g"]["g"].RM.GetUsers(name, domain) + return res +} + +// GetRolesForUserInDomain gets the roles that a user has inside a domain. +func (e *Enforcer) GetRolesForUserInDomain(name string, domain string) []string { + res, _ := e.model["g"]["g"].RM.GetRoles(name, domain) + return res +} + +// GetPermissionsForUserInDomain gets permissions for a user or role inside a domain. +func (e *Enforcer) GetPermissionsForUserInDomain(user string, domain string) [][]string { + res, _ := e.GetImplicitPermissionsForUser(user, domain) + return res +} + +// AddRoleForUserInDomain adds a role for a user inside a domain. +// Returns false if the user already has the role (aka not affected). +func (e *Enforcer) AddRoleForUserInDomain(user string, role string, domain string) (bool, error) { + return e.AddGroupingPolicy(user, role, domain) +} + +// DeleteRoleForUserInDomain deletes a role for a user inside a domain. +// Returns false if the user does not have the role (aka not affected). +func (e *Enforcer) DeleteRoleForUserInDomain(user string, role string, domain string) (bool, error) { + return e.RemoveGroupingPolicy(user, role, domain) +} + +// DeleteRolesForUserInDomain deletes all roles for a user inside a domain. +// Returns false if the user does not have any roles (aka not affected). 
+func (e *Enforcer) DeleteRolesForUserInDomain(user string, domain string) (bool, error) { + roles, err := e.model["g"]["g"].RM.GetRoles(user, domain) + if err != nil { + return false, err + } + + var rules [][]string + for _, role := range roles { + rules = append(rules, []string{user, role, domain}) + } + + return e.RemoveGroupingPolicies(rules) +} + +// GetAllUsersByDomain would get all users associated with the domain. +func (e *Enforcer) GetAllUsersByDomain(domain string) []string { + m := make(map[string]struct{}) + g := e.model["g"]["g"] + p := e.model["p"]["p"] + users := make([]string, 0) + index, err := e.GetFieldIndex("p", constant.DomainIndex) + if err != nil { + return []string{} + } + + getUser := func(index int, policies [][]string, domain string, m map[string]struct{}) []string { + if len(policies) == 0 || len(policies[0]) <= index { + return []string{} + } + res := make([]string, 0) + for _, policy := range policies { + if _, ok := m[policy[0]]; policy[index] == domain && !ok { + res = append(res, policy[0]) + m[policy[0]] = struct{}{} + } + } + return res + } + + users = append(users, getUser(2, g.Policy, domain, m)...) + users = append(users, getUser(index, p.Policy, domain, m)...) + return users +} + +// DeleteAllUsersByDomain would delete all users associated with the domain. 
+func (e *Enforcer) DeleteAllUsersByDomain(domain string) (bool, error) { + g := e.model["g"]["g"] + p := e.model["p"]["p"] + index, err := e.GetFieldIndex("p", constant.DomainIndex) + if err != nil { + return false, err + } + + getUser := func(index int, policies [][]string, domain string) [][]string { + if len(policies) == 0 || len(policies[0]) <= index { + return [][]string{} + } + res := make([][]string, 0) + for _, policy := range policies { + if policy[index] == domain { + res = append(res, policy) + } + } + return res + } + + users := getUser(2, g.Policy, domain) + if _, err := e.RemoveGroupingPolicies(users); err != nil { + return false, err + } + users = getUser(index, p.Policy, domain) + if _, err := e.RemovePolicies(users); err != nil { + return false, err + } + return true, nil +} + +// DeleteDomains would delete all associated users and roles. +// It would delete all domains if parameter is not provided. +func (e *Enforcer) DeleteDomains(domains ...string) (bool, error) { + if len(domains) == 0 { + e.ClearPolicy() + return true, nil + } + for _, domain := range domains { + if _, err := e.DeleteAllUsersByDomain(domain); err != nil { + return false, err + } + } + return true, nil +} + +// GetAllDomains would get all domains. +func (e *Enforcer) GetAllDomains() ([]string, error) { + return e.model["g"]["g"].RM.GetAllDomains() +} diff --git a/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains_synced.go b/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains_synced.go new file mode 100644 index 00000000..bf194e05 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/rbac_api_with_domains_synced.go @@ -0,0 +1,60 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetUsersForRoleInDomain gets the users that has a role inside a domain. Add by Gordon +func (e *SyncedEnforcer) GetUsersForRoleInDomain(name string, domain string) []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetUsersForRoleInDomain(name, domain) +} + +// GetRolesForUserInDomain gets the roles that a user has inside a domain. +func (e *SyncedEnforcer) GetRolesForUserInDomain(name string, domain string) []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetRolesForUserInDomain(name, domain) +} + +// GetPermissionsForUserInDomain gets permissions for a user or role inside a domain. +func (e *SyncedEnforcer) GetPermissionsForUserInDomain(user string, domain string) [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetPermissionsForUserInDomain(user, domain) +} + +// AddRoleForUserInDomain adds a role for a user inside a domain. +// Returns false if the user already has the role (aka not affected). +func (e *SyncedEnforcer) AddRoleForUserInDomain(user string, role string, domain string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddRoleForUserInDomain(user, role, domain) +} + +// DeleteRoleForUserInDomain deletes a role for a user inside a domain. +// Returns false if the user does not have the role (aka not affected). 
+func (e *SyncedEnforcer) DeleteRoleForUserInDomain(user string, role string, domain string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRoleForUserInDomain(user, role, domain) +} + +// DeleteRolesForUserInDomain deletes all roles for a user inside a domain. +// Returns false if the user does not have any roles (aka not affected). +func (e *SyncedEnforcer) DeleteRolesForUserInDomain(user string, domain string) (bool, error) { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRolesForUserInDomain(user, domain) +} diff --git a/vendor/github.com/casbin/casbin/v2/util/builtin_operators.go b/vendor/github.com/casbin/casbin/v2/util/builtin_operators.go new file mode 100644 index 00000000..ee091ce0 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/util/builtin_operators.go @@ -0,0 +1,410 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package util + +import ( + "errors" + "fmt" + "net" + "path" + "regexp" + "strings" + "sync" + + "github.com/Knetic/govaluate" + "github.com/casbin/casbin/v2/rbac" +) + +var ( + keyMatch4Re *regexp.Regexp = regexp.MustCompile(`{([^/]+)}`) +) + +// validate the variadic parameter size and type as string +func validateVariadicArgs(expectedLen int, args ...interface{}) error { + if len(args) != expectedLen { + return fmt.Errorf("Expected %d arguments, but got %d", expectedLen, len(args)) + } + + for _, p := range args { + _, ok := p.(string) + if !ok { + return errors.New("Argument must be a string") + } + } + + return nil +} + +// KeyMatch determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// For example, "/foo/bar" matches "/foo/*" +func KeyMatch(key1 string, key2 string) bool { + i := strings.Index(key2, "*") + if i == -1 { + return key1 == key2 + } + + if len(key1) > i { + return key1[:i] == key2[:i] + } + return key1 == key2[:i] +} + +// KeyMatchFunc is the wrapper for KeyMatch. 
+func KeyMatchFunc(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyMatch", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(KeyMatch(name1, name2)), nil +} + +// KeyGet returns the matched part +// For example, "/foo/bar/foo" matches "/foo/*" +// "bar/foo" will been returned +func KeyGet(key1, key2 string) string { + i := strings.Index(key2, "*") + if i == -1 { + return "" + } + if len(key1) > i { + if key1[:i] == key2[:i] { + return key1[i:] + } + } + return "" +} + +// KeyGetFunc is the wrapper for KeyGet +func KeyGetFunc(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyGet", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return KeyGet(name1, name2), nil +} + +// KeyMatch2 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource" +func KeyMatch2(key1 string, key2 string) bool { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`:[^/]+`) + key2 = re.ReplaceAllString(key2, "$1[^/]+$2") + + return RegexMatch(key1, "^"+key2+"$") +} + +// KeyMatch2Func is the wrapper for KeyMatch2. 
+func KeyMatch2Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyMatch2", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(KeyMatch2(name1, name2)), nil +} + +// KeyGet2 returns value matched pattern +// For example, "/resource1" matches "/:resource" +// if the pathVar == "resource", then "resource1" will be returned +func KeyGet2(key1, key2 string, pathVar string) string { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`:[^/]+`) + keys := re.FindAllString(key2, -1) + key2 = re.ReplaceAllString(key2, "$1([^/]+)$2") + key2 = "^" + key2 + "$" + re2 := regexp.MustCompile(key2) + values := re2.FindAllStringSubmatch(key1, -1) + if len(values) == 0 { + return "" + } + for i, key := range keys { + if pathVar == key[1:] { + return values[0][i+1] + } + } + return "" +} + +// KeyGet2Func is the wrapper for KeyGet2 +func KeyGet2Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(3, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyGet2", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + key := args[2].(string) + + return KeyGet2(name1, name2, key), nil +} + +// KeyMatch3 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}" +func KeyMatch3(key1 string, key2 string) bool { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`\{[^/]+\}`) + key2 = re.ReplaceAllString(key2, "$1[^/]+$2") + + return RegexMatch(key1, "^"+key2+"$") +} + +// KeyMatch3Func is the wrapper for KeyMatch3. 
+func KeyMatch3Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyMatch3", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(KeyMatch3(name1, name2)), nil +} + +// KeyGet3 returns value matched pattern +// For example, "project/proj_project1_admin/" matches "project/proj_{project}_admin/" +// if the pathVar == "project", then "project1" will be returned +func KeyGet3(key1, key2 string, pathVar string) string { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`\{[^/]+?\}`) // non-greedy match of `{...}` to support multiple {} in `/.../` + keys := re.FindAllString(key2, -1) + key2 = re.ReplaceAllString(key2, "$1([^/]+?)$2") + key2 = "^" + key2 + "$" + re2 := regexp.MustCompile(key2) + values := re2.FindAllStringSubmatch(key1, -1) + if len(values) == 0 { + return "" + } + for i, key := range keys { + if pathVar == key[1:len(key)-1] { + return values[0][i+1] + } + } + return "" +} + +// KeyGet3Func is the wrapper for KeyGet3 +func KeyGet3Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(3, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyGet3", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + key := args[2].(string) + + return KeyGet3(name1, name2, key), nil +} + +// KeyMatch4 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// Besides what KeyMatch3 does, KeyMatch4 can also match repeated patterns: +// "/parent/123/child/123" matches "/parent/{id}/child/{id}" +// "/parent/123/child/456" does not match "/parent/{id}/child/{id}" +// But KeyMatch3 will match both. 
+func KeyMatch4(key1 string, key2 string) bool { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + tokens := []string{} + + re := keyMatch4Re + key2 = re.ReplaceAllStringFunc(key2, func(s string) string { + tokens = append(tokens, s[1:len(s)-1]) + return "([^/]+)" + }) + + re = regexp.MustCompile("^" + key2 + "$") + matches := re.FindStringSubmatch(key1) + if matches == nil { + return false + } + matches = matches[1:] + + if len(tokens) != len(matches) { + panic(errors.New("KeyMatch4: number of tokens is not equal to number of values")) + } + + values := map[string]string{} + + for key, token := range tokens { + if _, ok := values[token]; !ok { + values[token] = matches[key] + } + if values[token] != matches[key] { + return false + } + } + + return true +} + +// KeyMatch4Func is the wrapper for KeyMatch4. +func KeyMatch4Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyMatch4", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(KeyMatch4(name1, name2)), nil +} + +// KeyMatch determines whether key1 matches the pattern of key2 and ignores the parameters in key2. +// For example, "/foo/bar?status=1&type=2" matches "/foo/bar" +func KeyMatch5(key1 string, key2 string) bool { + i := strings.Index(key1, "?") + if i == -1 { + return key1 == key2 + } + + return key1[:i] == key2 +} + +// KeyMatch5Func is the wrapper for KeyMatch5. +func KeyMatch5Func(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "keyMatch5", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(KeyMatch5(name1, name2)), nil +} + +// RegexMatch determines whether key1 matches the pattern of key2 in regular expression. 
+func RegexMatch(key1 string, key2 string) bool { + res, err := regexp.MatchString(key2, key1) + if err != nil { + panic(err) + } + return res +} + +// RegexMatchFunc is the wrapper for RegexMatch. +func RegexMatchFunc(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "regexMatch", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return bool(RegexMatch(name1, name2)), nil +} + +// IPMatch determines whether IP address ip1 matches the pattern of IP address ip2, ip2 can be an IP address or a CIDR pattern. +// For example, "192.168.2.123" matches "192.168.2.0/24" +func IPMatch(ip1 string, ip2 string) bool { + objIP1 := net.ParseIP(ip1) + if objIP1 == nil { + panic("invalid argument: ip1 in IPMatch() function is not an IP address.") + } + + _, cidr, err := net.ParseCIDR(ip2) + if err != nil { + objIP2 := net.ParseIP(ip2) + if objIP2 == nil { + panic("invalid argument: ip2 in IPMatch() function is neither an IP address nor a CIDR.") + } + + return objIP1.Equal(objIP2) + } + + return cidr.Contains(objIP1) +} + +// IPMatchFunc is the wrapper for IPMatch. +func IPMatchFunc(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "ipMatch", err) + } + + ip1 := args[0].(string) + ip2 := args[1].(string) + + return bool(IPMatch(ip1, ip2)), nil +} + +// GlobMatch determines whether key1 matches the pattern of key2 using glob pattern +func GlobMatch(key1 string, key2 string) (bool, error) { + return path.Match(key2, key1) +} + +// GlobMatchFunc is the wrapper for GlobMatch. 
+func GlobMatchFunc(args ...interface{}) (interface{}, error) { + if err := validateVariadicArgs(2, args...); err != nil { + return false, fmt.Errorf("%s: %s", "globMatch", err) + } + + name1 := args[0].(string) + name2 := args[1].(string) + + return GlobMatch(name1, name2) +} + +// GenerateGFunction is the factory method of the g(_, _[, _]) function. +func GenerateGFunction(rm rbac.RoleManager) govaluate.ExpressionFunction { + memorized := sync.Map{} + return func(args ...interface{}) (interface{}, error) { + // Like all our other govaluate functions, all args are strings. + + // Allocate and generate a cache key from the arguments... + total := len(args) + for _, a := range args { + aStr := a.(string) + total += len(aStr) + } + builder := strings.Builder{} + builder.Grow(total) + for _, arg := range args { + builder.WriteByte(0) + builder.WriteString(arg.(string)) + } + key := builder.String() + + // ...and see if we've already calculated this. + v, found := memorized.Load(key) + if found { + return v, nil + } + + // If not, do the calculation. + // There are guaranteed to be exactly 2 or 3 arguments. + name1, name2 := args[0].(string), args[1].(string) + if rm == nil { + v = name1 == name2 + } else if len(args) == 2 { + v, _ = rm.HasLink(name1, name2) + } else { + domain := args[2].(string) + v, _ = rm.HasLink(name1, name2, domain) + } + + memorized.Store(key, v) + return v, nil + } +} diff --git a/vendor/github.com/casbin/casbin/v2/util/util.go b/vendor/github.com/casbin/casbin/v2/util/util.go new file mode 100644 index 00000000..1b8d4bf9 --- /dev/null +++ b/vendor/github.com/casbin/casbin/v2/util/util.go @@ -0,0 +1,335 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "regexp" + "sort" + "strings" + "sync" +) + +var evalReg = regexp.MustCompile(`\beval\((?P[^)]*)\)`) + +var escapeAssertionRegex = regexp.MustCompile(`\b((r|p)[0-9]*)\.`) + +// EscapeAssertion escapes the dots in the assertion, because the expression evaluation doesn't support such variable names. +func EscapeAssertion(s string) string { + s = escapeAssertionRegex.ReplaceAllStringFunc(s, func(m string) string { + return strings.Replace(m, ".", "_", 1) + }) + return s +} + +// RemoveComments removes the comments starting with # in the text. +func RemoveComments(s string) string { + pos := strings.Index(s, "#") + if pos == -1 { + return s + } + return strings.TrimSpace(s[0:pos]) +} + +// ArrayEquals determines whether two string arrays are identical. +func ArrayEquals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +// Array2DEquals determines whether two 2-dimensional string arrays are identical. +func Array2DEquals(a [][]string, b [][]string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if !ArrayEquals(v, b[i]) { + return false + } + } + return true +} + +// ArrayRemoveDuplicates removes any duplicated elements in a string array. 
+func ArrayRemoveDuplicates(s *[]string) { + found := make(map[string]bool) + j := 0 + for i, x := range *s { + if !found[x] { + found[x] = true + (*s)[j] = (*s)[i] + j++ + } + } + *s = (*s)[:j] +} + +// ArrayToString gets a printable string for a string array. +func ArrayToString(s []string) string { + return strings.Join(s, ", ") +} + +// ParamsToString gets a printable string for variable number of parameters. +func ParamsToString(s ...string) string { + return strings.Join(s, ", ") +} + +// SetEquals determines whether two string sets are identical. +func SetEquals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + sort.Strings(a) + sort.Strings(b) + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +// SetEquals determines whether two string sets are identical. +func SetEqualsInt(a []int, b []int) bool { + if len(a) != len(b) { + return false + } + + sort.Ints(a) + sort.Ints(b) + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +// SetEquals determines whether two string sets are identical. +func Set2DEquals(a [][]string, b [][]string) bool { + if len(a) != len(b) { + return false + } + + var aa []string + for _, v := range a { + sort.Strings(v) + aa = append(aa, strings.Join(v, ", ")) + } + var bb []string + for _, v := range b { + sort.Strings(v) + bb = append(bb, strings.Join(v, ", ")) + } + + return SetEquals(aa, bb) +} + +// JoinSlice joins a string and a slice into a new slice. +func JoinSlice(a string, b ...string) []string { + res := make([]string, 0, len(b)+1) + + res = append(res, a) + res = append(res, b...) + + return res +} + +// JoinSliceAny joins a string and a slice into a new interface{} slice. +func JoinSliceAny(a string, b ...string) []interface{} { + res := make([]interface{}, 0, len(b)+1) + + res = append(res, a) + for _, s := range b { + res = append(res, s) + } + + return res +} + +// SetSubtract returns the elements in `a` that aren't in `b`. 
+func SetSubtract(a []string, b []string) []string {
+	mb := make(map[string]struct{}, len(b))
+	for _, x := range b {
+		mb[x] = struct{}{}
+	}
+	var diff []string
+	for _, x := range a {
+		if _, found := mb[x]; !found {
+			diff = append(diff, x)
+		}
+	}
+	return diff
+}
+
+// HasEval determines whether the matcher contains the eval function.
+func HasEval(s string) bool {
+	return evalReg.MatchString(s)
+}
+
+// ReplaceEval replaces the eval function with the value of its parameters.
+func ReplaceEval(s string, rule string) string {
+	return evalReg.ReplaceAllString(s, "("+rule+")")
+}
+
+// ReplaceEvalWithMap replaces the eval function with the value of its parameters via the given sets.
+func ReplaceEvalWithMap(src string, sets map[string]string) string {
+	return evalReg.ReplaceAllStringFunc(src, func(s string) string {
+		subs := evalReg.FindStringSubmatch(s)
+		if subs == nil {
+			return s
+		}
+		key := subs[1]
+		value, found := sets[key]
+		if !found {
+			return s
+		}
+		return evalReg.ReplaceAllString(s, value)
+	})
+}
+
+// GetEvalValue returns the parameters of the eval function.
+func GetEvalValue(s string) []string {
+	subMatch := evalReg.FindAllStringSubmatch(s, -1)
+	var rules []string
+	for _, rule := range subMatch {
+		rules = append(rules, rule[1])
+	}
+	return rules
+}
+
+func RemoveDuplicateElement(s []string) []string {
+	result := make([]string, 0, len(s))
+	temp := map[string]struct{}{}
+	for _, item := range s {
+		if _, ok := temp[item]; !ok {
+			temp[item] = struct{}{}
+			result = append(result, item)
+		}
+	}
+	return result
+}
+
+type node struct {
+	key   interface{}
+	value interface{}
+	prev  *node
+	next  *node
+}
+
+type LRUCache struct {
+	capacity int
+	m        map[interface{}]*node
+	head     *node
+	tail     *node
+}
+
+func NewLRUCache(capacity int) *LRUCache {
+	cache := &LRUCache{}
+	cache.capacity = capacity
+	cache.m = map[interface{}]*node{}
+
+	head := &node{}
+	tail := &node{}
+
+	head.next = tail
+	tail.prev = head
+
+	cache.head = head
+	cache.tail = tail
+
+	return cache
+}
+
+func (cache *LRUCache) remove(n *node, listOnly bool) { // unlinks n from the list; also deletes the map entry unless listOnly
+	if !listOnly {
+		delete(cache.m, n.key)
+	}
+	n.prev.next = n.next
+	n.next.prev = n.prev
+}
+
+func (cache *LRUCache) add(n *node, listOnly bool) { // inserts n right after head (most-recent position); also registers the map entry unless listOnly
+	if !listOnly {
+		cache.m[n.key] = n
+	}
+	headNext := cache.head.next
+	cache.head.next = n
+	headNext.prev = n
+	n.next = headNext
+	n.prev = cache.head
+}
+
+func (cache *LRUCache) moveToHead(n *node) { // marks n as most recently used
+	cache.remove(n, true)
+	cache.add(n, true)
+}
+
+func (cache *LRUCache) Get(key interface{}) (value interface{}, ok bool) { // looks up key and refreshes its recency on a hit
+	n, ok := cache.m[key]
+	if ok {
+		cache.moveToHead(n)
+		return n.value, ok
+	} else {
+		return nil, ok
+	}
+}
+
+func (cache *LRUCache) Put(key interface{}, value interface{}) {
+	n, ok := cache.m[key]
+	if ok {
+		cache.remove(n, false) // NOTE(review): n.value is not reassigned here, so putting an existing key keeps the old value — matches the pinned upstream; verify intended
+	} else {
+		n = &node{key, value, nil, nil}
+		if len(cache.m) >= cache.capacity {
+			cache.remove(cache.tail.prev, false) // at capacity: evict the least recently used entry (tail.prev)
+		}
+	}
+	cache.add(n, false)
+}
+
+type SyncLRUCache struct {
+	rwm sync.RWMutex // guards the embedded LRUCache against concurrent access
+	*LRUCache
+}
+
+func NewSyncLRUCache(capacity int) *SyncLRUCache {
+	cache := &SyncLRUCache{}
+	cache.LRUCache = NewLRUCache(capacity)
+	return cache
+}
+
+func (cache *SyncLRUCache) Get(key interface{}) (value interface{}, ok bool) {
+	cache.rwm.RLock()
+	defer cache.rwm.RUnlock()
+	return cache.LRUCache.Get(key)
+}
+
+func (cache *SyncLRUCache) Put(key interface{}, value interface{}) {
+	cache.rwm.Lock()
+	defer cache.rwm.Unlock()
+	cache.LRUCache.Put(key, value)
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 1d9708ad..b11063c8 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -24,6 +24,9 @@ github.com/99designs/gqlgen/plugin/federation/fieldset
 github.com/99designs/gqlgen/plugin/modelgen
 github.com/99designs/gqlgen/plugin/resolvergen
 github.com/99designs/gqlgen/plugin/servergen
+# github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible
+## explicit
+github.com/Knetic/govaluate
 # github.com/KyleBanks/depth v1.2.1
 ## explicit
github.com/KyleBanks/depth @@ -51,6 +54,21 @@ github.com/boltdb/bolt # github.com/caddyserver/certmagic v0.17.2 ## explicit; go 1.18 github.com/caddyserver/certmagic +# github.com/casbin/casbin/v2 v2.60.0 +## explicit; go 1.13 +github.com/casbin/casbin/v2 +github.com/casbin/casbin/v2/config +github.com/casbin/casbin/v2/constant +github.com/casbin/casbin/v2/effector +github.com/casbin/casbin/v2/errors +github.com/casbin/casbin/v2/log +github.com/casbin/casbin/v2/model +github.com/casbin/casbin/v2/persist +github.com/casbin/casbin/v2/persist/cache +github.com/casbin/casbin/v2/persist/file-adapter +github.com/casbin/casbin/v2/rbac +github.com/casbin/casbin/v2/rbac/default-role-manager +github.com/casbin/casbin/v2/util # github.com/cespare/xxhash/v2 v2.2.0 ## explicit; go 1.11 github.com/cespare/xxhash/v2