mirror of
https://github.com/xaionaro-go/streamctl.git
synced 2025-09-27 11:52:11 +08:00
272 lines
7.1 KiB
Go
272 lines
7.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"maps"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/facebookincubator/go-belt/tool/experimental/errmon"
|
|
"github.com/facebookincubator/go-belt/tool/logger"
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/nicklaw5/helix/v2"
|
|
"github.com/xaionaro-go/observability"
|
|
"github.com/xaionaro-go/streamctl/pkg/oauthhandler"
|
|
twitch "github.com/xaionaro-go/streamctl/pkg/streamcontrol/twitch/types"
|
|
)
|
|
|
|
// OAuthHandler is an alias of twitch.OAuthHandler; NewClientCode accepts a
// value of this type as its oauthHandler argument.
type OAuthHandler = twitch.OAuthHandler
|
|
|
|
func getScopes() []string {
|
|
scopes := map[string]struct{}{
|
|
"chat:read": {},
|
|
"chat:edit": {},
|
|
|
|
"analytics:read:extensions": {},
|
|
"analytics:read:games": {},
|
|
|
|
"bits:read": {},
|
|
|
|
"channel:bot": {},
|
|
"channel:manage:ads": {},
|
|
"channel:read:ads": {},
|
|
"channel:manage:broadcast": {},
|
|
"channel:read:charity": {},
|
|
"channel:edit:commercial": {},
|
|
"channel:read:editors": {},
|
|
"channel:manage:extensions": {},
|
|
"channel:read:goals": {},
|
|
"channel:read:guest_star": {},
|
|
"channel:manage:guest_star": {},
|
|
"channel:read:hype_train": {},
|
|
"channel:manage:moderators": {},
|
|
"channel:read:polls": {},
|
|
"channel:manage:polls": {},
|
|
"channel:read:predictions": {},
|
|
"channel:manage:predictions": {},
|
|
"channel:manage:raids": {},
|
|
"channel:read:redemptions": {},
|
|
"channel:manage:redemptions": {},
|
|
"channel:manage:schedule": {},
|
|
"channel:read:stream_key": {},
|
|
"channel:read:subscriptions": {},
|
|
"channel:manage:videos": {},
|
|
"channel:read:vips": {},
|
|
"channel:manage:vips": {},
|
|
"channel:moderate": {},
|
|
|
|
"clips:edit": {},
|
|
|
|
"moderation:read": {},
|
|
|
|
"moderator:manage:announcements": {},
|
|
"moderator:manage:automod": {},
|
|
"moderator:read:automod_settings": {},
|
|
"moderator:manage:automod_settings": {},
|
|
"moderator:read:banned_users": {},
|
|
"moderator:manage:banned_users": {},
|
|
"moderator:read:blocked_terms": {},
|
|
"moderator:read:chat_messages": {},
|
|
"moderator:manage:blocked_terms": {},
|
|
"moderator:manage:chat_messages": {},
|
|
"moderator:read:chat_settings": {},
|
|
"moderator:manage:chat_settings": {},
|
|
"moderator:read:chatters": {},
|
|
"moderator:read:followers": {},
|
|
"moderator:read:guest_star": {},
|
|
"moderator:manage:guest_star": {},
|
|
"moderator:read:moderators": {},
|
|
"moderator:read:shield_mode": {},
|
|
"moderator:manage:shield_mode": {},
|
|
"moderator:read:shoutouts": {},
|
|
"moderator:manage:shoutouts": {},
|
|
"moderator:read:suspicious_users": {},
|
|
"moderator:read:unban_requests": {},
|
|
"moderator:manage:unban_requests": {},
|
|
"moderator:read:vips": {},
|
|
"moderator:read:warnings": {},
|
|
"moderator:manage:warnings": {},
|
|
|
|
"user:bot": {},
|
|
"user:edit": {},
|
|
"user:edit:broadcast": {},
|
|
"user:read:blocked_users": {},
|
|
"user:manage:blocked_users": {},
|
|
"user:read:broadcast": {},
|
|
"user:read:chat": {},
|
|
"user:manage:chat_color": {},
|
|
"user:read:email": {},
|
|
"user:read:emotes": {},
|
|
"user:read:follows": {},
|
|
"user:read:moderated_channels": {},
|
|
"user:read:subscriptions": {},
|
|
"user:read:whispers": {},
|
|
"user:manage:whispers": {},
|
|
"user:write:chat": {},
|
|
}
|
|
|
|
scopesStrings := slices.Collect(maps.Keys(scopes))
|
|
sort.Strings(scopesStrings)
|
|
return scopesStrings
|
|
}
|
|
|
|
func NewClientCode(
|
|
ctx context.Context,
|
|
clientID string,
|
|
oauthHandler OAuthHandler,
|
|
getOAuthListenPortsFn func() []uint16,
|
|
onNewClientCode func(string),
|
|
) (_err error) {
|
|
logger.Debugf(ctx, "getNewClientCode")
|
|
defer func() { logger.Debugf(ctx, "/getNewClientCode: %v", _err) }()
|
|
|
|
if oauthHandler == nil {
|
|
oauthHandler = oauthhandler.OAuth2HandlerViaCLI
|
|
}
|
|
|
|
ctx, ctxCancelFunc := context.WithCancel(ctx)
|
|
cancelFunc := func() {
|
|
logger.Debugf(ctx, "cancelling the context")
|
|
ctxCancelFunc()
|
|
}
|
|
|
|
var errWg sync.WaitGroup
|
|
var resultErr error
|
|
errCh := make(chan error)
|
|
errWg.Add(1)
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
errWg.Done()
|
|
for err := range errCh {
|
|
errmon.ObserveErrorCtx(ctx, err)
|
|
resultErr = multierror.Append(resultErr, err)
|
|
}
|
|
})
|
|
|
|
alreadyListening := map[uint16]struct{}{}
|
|
var wg sync.WaitGroup
|
|
success := false
|
|
|
|
startHandlerForPort := func(listenPort uint16) {
|
|
if _, ok := alreadyListening[listenPort]; ok {
|
|
return
|
|
}
|
|
alreadyListening[listenPort] = struct{}{}
|
|
|
|
logger.Debugf(ctx, "starting the oauth handler at port %d", listenPort)
|
|
wg.Add(1)
|
|
{
|
|
listenPort := listenPort
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
defer func() { logger.Debugf(ctx, "ended the oauth handler at port %d", listenPort) }()
|
|
defer wg.Done()
|
|
authURL := GetAuthorizationURL(
|
|
&helix.AuthorizationURLParams{
|
|
ResponseType: "code", // or "token"
|
|
Scopes: getScopes(),
|
|
},
|
|
clientID,
|
|
RedirectURI(listenPort),
|
|
)
|
|
|
|
arg := oauthhandler.OAuthHandlerArgument{
|
|
AuthURL: authURL,
|
|
ListenPort: listenPort,
|
|
ExchangeFn: func(ctx context.Context, code string) (_err error) {
|
|
logger.Debugf(ctx, "ExchangeFn()")
|
|
defer func() { logger.Debugf(ctx, "/ExchangeFn(): %v", _err) }()
|
|
if code == "" {
|
|
return fmt.Errorf("code is empty")
|
|
}
|
|
onNewClientCode(code)
|
|
return nil
|
|
},
|
|
}
|
|
|
|
err := oauthHandler(ctx, arg)
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("unable to get or exchange the oauth code to a token: %w", err)
|
|
return
|
|
}
|
|
cancelFunc()
|
|
success = true
|
|
})
|
|
}
|
|
}
|
|
|
|
// TODO: either support only one port as in New, or support multiple
|
|
// ports as we do below
|
|
getPortsFn := getOAuthListenPortsFn
|
|
if getPortsFn == nil {
|
|
return fmt.Errorf("the function GetOAuthListenPorts is not set")
|
|
}
|
|
|
|
for _, listenPort := range getPortsFn() {
|
|
startHandlerForPort(listenPort)
|
|
}
|
|
|
|
wg.Add(1)
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
defer wg.Done()
|
|
t := time.NewTicker(time.Second)
|
|
defer t.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-t.C:
|
|
}
|
|
ports := getPortsFn()
|
|
logger.Tracef(ctx, "oauth listener ports: %#+v", ports)
|
|
|
|
for _, listenPort := range ports {
|
|
startHandlerForPort(listenPort)
|
|
}
|
|
}
|
|
})
|
|
|
|
observability.Go(ctx, func(ctx context.Context) {
|
|
wg.Wait()
|
|
close(errCh)
|
|
})
|
|
<-ctx.Done()
|
|
logger.Debugf(ctx, "did successfully took a new client code? -- %v", success)
|
|
if !success {
|
|
errWg.Wait()
|
|
return resultErr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RedirectURI returns the OAuth redirect URI that points to the given port on
// localhost, e.g. "http://localhost:8091/" for port 8091.
func RedirectURI(listenPort uint16) string {
	const redirectURIFormat = "http://localhost:%d/"

	uri := fmt.Sprintf(redirectURIFormat, listenPort)
	return uri
}
|
|
|
|
func GetAuthorizationURL(
|
|
params *helix.AuthorizationURLParams,
|
|
clientID string,
|
|
redirectURI string,
|
|
) string {
|
|
url := helix.AuthBaseURL + "/authorize"
|
|
url += "?response_type=" + params.ResponseType
|
|
url += "&client_id=" + clientID
|
|
url += "&redirect_uri=" + redirectURI
|
|
|
|
if params.State != "" {
|
|
url += "&state=" + params.State
|
|
}
|
|
|
|
if params.ForceVerify {
|
|
url += "&force_verify=true"
|
|
}
|
|
|
|
if len(params.Scopes) != 0 {
|
|
url += "&scope=" + strings.Join(params.Scopes, "%20")
|
|
}
|
|
|
|
return url
|
|
}
|