diff --git a/app/api/api.go b/app/api/api.go index 82200e56..f6a6b12e 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -979,7 +979,7 @@ func (a *api) start(ctx context.Context) error { } if identity != nil { - u.User = url.UserPassword(identity.Name(), identity.GetServiceBasicAuth()) + u.User = identity.GetServiceBasicAuth() } else { u.User = url.User(config.Owner) } @@ -1007,7 +1007,7 @@ func (a *api) start(ctx context.Context) error { } if identity != nil { - u.User = url.UserPassword(identity.Name(), identity.GetServiceBasicAuth()) + u.User = identity.GetServiceBasicAuth() } else { u.User = url.User(config.Owner) } @@ -1037,7 +1037,9 @@ func (a *api) start(ctx context.Context) error { } if identity != nil { - u.User = url.UserPassword(identity.Name(), identity.GetServiceBasicAuth()) + u.User = identity.GetServiceBasicAuth() + } else { + u.User = url.User(config.Owner) } if len(config.Domain) != 0 { diff --git a/iam/identity/identity.go b/iam/identity/identity.go index 1a587a7d..9df452ea 100644 --- a/iam/identity/identity.go +++ b/iam/identity/identity.go @@ -2,6 +2,7 @@ package identity import ( "fmt" + "net/url" "strings" "sync" "time" @@ -105,7 +106,7 @@ type Verifier interface { VerifyServiceToken(token string) (bool, error) VerifyServiceSession(jwt string) (bool, interface{}, error) - GetServiceBasicAuth() string + GetServiceBasicAuth() *url.Userinfo GetServiceToken() string GetServiceSession(interface{}, time.Duration) string @@ -319,12 +320,17 @@ func (i *identity) VerifyServiceBasicAuth(password string) (bool, error) { return false, nil } -func (i *identity) GetServiceBasicAuth() string { +func (i *identity) GetServiceBasicAuth() *url.Userinfo { i.lock.RLock() defer i.lock.RUnlock() if !i.isValid() { - return "" + return nil + } + + name := i.Alias() + if len(name) == 0 { + name = i.Name() } for _, password := range i.user.Auth.Services.Basic { @@ -332,10 +338,10 @@ func (i *identity) GetServiceBasicAuth() string { continue } - return password + return url.UserPassword(name, password) } - return "" + return url.User(name) } func (i *identity) VerifyServiceToken(token string) (bool, error) { @@ -368,7 +374,12 @@ func (i *identity) GetServiceToken() string { continue } - return enctoken.Marshal(i.Name(), token) + name := i.Alias() + if len(name) == 0 { + name = i.Name() + } + + return enctoken.Marshal(name, token) } return "" diff --git a/iam/identity/identity_test.go b/iam/identity/identity_test.go index 8ddc3553..ed6c717b 100644 --- a/iam/identity/identity_test.go +++ b/iam/identity/identity_test.go @@ -182,7 +182,8 @@ func TestIdentityServiceBasicAuth(t *testing.T) { require.False(t, ok) require.NoError(t, err) - password := identity.GetServiceBasicAuth() + userinfo := identity.GetServiceBasicAuth() + password, _ := userinfo.Password() require.Equal(t, "terces", password) } diff --git a/restream/rewrite/rewrite.go b/restream/rewrite/rewrite.go index a40cabf3..055ca6c9 100644 --- a/restream/rewrite/rewrite.go +++ b/restream/rewrite/rewrite.go @@ -115,13 +115,7 @@ func (g *rewrite) isLocal(u *url.URL) bool { } func (g *rewrite) httpURL(u *url.URL, mode Access, identity iamidentity.Verifier) string { - password := identity.GetServiceBasicAuth() - - if len(password) == 0 { - u.User = nil - } else { - u.User = url.UserPassword(identity.Name(), password) - } + u.User = identity.GetServiceBasicAuth() return u.String() } diff --git a/restream/rewrite/rewrite_test.go b/restream/rewrite/rewrite_test.go index 59f26358..10a3cac8 100644 --- a/restream/rewrite/rewrite_test.go +++ b/restream/rewrite/rewrite_test.go @@ -74,10 +74,10 @@ func TestRewriteHTTP(t *testing.T) { {"http://example.com/live/stream.m3u8", "write", "http://example.com/live/stream.m3u8"}, {"http://localhost:8181/live/stream.m3u8", "read", "http://localhost:8181/live/stream.m3u8"}, {"http://localhost:8181/live/stream.m3u8", "write", "http://localhost:8181/live/stream.m3u8"}, - {"http://localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, - {"http://localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, - {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://localhost:8080/live/stream.m3u8"}, - {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "read", "http://foobar@localhost:8080/live/stream.m3u8"}, + {"http://localhost:8080/live/stream.m3u8", "write", "http://foobar@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "read", "http://foobar@localhost:8080/live/stream.m3u8"}, + {"http://admin:pass@localhost:8080/live/stream.m3u8", "write", "http://foobar@localhost:8080/live/stream.m3u8"}, } for _, e := range samples {