Files
streamctl/pkg/streamcontrol/twitch/auth/client_code.go
Dmitrii Okunev 1004082fe4
Some checks failed
rolling-release / build (push) Has been cancelled
rolling-release / rolling-release (push) Has been cancelled
Multiple updates
2025-07-12 23:11:42 +01:00

182 lines
4.2 KiB
Go

package auth
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/facebookincubator/go-belt/tool/experimental/errmon"
"github.com/facebookincubator/go-belt/tool/logger"
"github.com/hashicorp/go-multierror"
"github.com/nicklaw5/helix/v2"
"github.com/xaionaro-go/observability"
"github.com/xaionaro-go/streamctl/pkg/oauthhandler"
)
type OAuthHandler func(context.Context, oauthhandler.OAuthHandlerArgument) error
func NewClientCode(
ctx context.Context,
clientID string,
oauthHandler OAuthHandler,
getOAuthListenPortsFn func() []uint16,
onNewClientCode func(string),
) (_err error) {
logger.Debugf(ctx, "getNewClientCode")
defer func() { logger.Debugf(ctx, "/getNewClientCode: %v", _err) }()
if oauthHandler == nil {
oauthHandler = oauthhandler.OAuth2HandlerViaCLI
}
ctx, ctxCancelFunc := context.WithCancel(ctx)
cancelFunc := func() {
logger.Debugf(ctx, "cancelling the context")
ctxCancelFunc()
}
var errWg sync.WaitGroup
var resultErr error
errCh := make(chan error)
errWg.Add(1)
observability.Go(ctx, func(ctx context.Context) {
errWg.Done()
for err := range errCh {
errmon.ObserveErrorCtx(ctx, err)
resultErr = multierror.Append(resultErr, err)
}
})
alreadyListening := map[uint16]struct{}{}
var wg sync.WaitGroup
success := false
startHandlerForPort := func(listenPort uint16) {
if _, ok := alreadyListening[listenPort]; ok {
return
}
alreadyListening[listenPort] = struct{}{}
logger.Debugf(ctx, "starting the oauth handler at port %d", listenPort)
wg.Add(1)
{
listenPort := listenPort
observability.Go(ctx, func(ctx context.Context) {
defer func() { logger.Debugf(ctx, "ended the oauth handler at port %d", listenPort) }()
defer wg.Done()
authURL := GetAuthorizationURL(
&helix.AuthorizationURLParams{
ResponseType: "code", // or "token"
Scopes: []string{
"user:read:chat",
"chat:read",
"chat:edit",
"channel:manage:broadcast",
"moderator:manage:chat_messages",
"moderator:manage:banned_users",
},
},
clientID,
RedirectURI(listenPort),
)
arg := oauthhandler.OAuthHandlerArgument{
AuthURL: authURL,
ListenPort: listenPort,
ExchangeFn: func(ctx context.Context, code string) (_err error) {
logger.Debugf(ctx, "ExchangeFn()")
defer func() { logger.Debugf(ctx, "/ExchangeFn(): %v", _err) }()
if code == "" {
return fmt.Errorf("code is empty")
}
onNewClientCode(code)
return nil
},
}
err := oauthHandler(ctx, arg)
if err != nil {
errCh <- fmt.Errorf("unable to get or exchange the oauth code to a token: %w", err)
return
}
cancelFunc()
success = true
})
}
}
// TODO: either support only one port as in New, or support multiple
// ports as we do below
getPortsFn := getOAuthListenPortsFn
if getPortsFn == nil {
return fmt.Errorf("the function GetOAuthListenPorts is not set")
}
for _, listenPort := range getPortsFn() {
startHandlerForPort(listenPort)
}
wg.Add(1)
observability.Go(ctx, func(ctx context.Context) {
defer wg.Done()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
ports := getPortsFn()
logger.Tracef(ctx, "oauth listener ports: %#+v", ports)
for _, listenPort := range ports {
startHandlerForPort(listenPort)
}
}
})
observability.Go(ctx, func(ctx context.Context) {
wg.Wait()
close(errCh)
})
<-ctx.Done()
logger.Debugf(ctx, "did successfully took a new client code? -- %v", success)
if !success {
errWg.Wait()
return resultErr
}
return nil
}
func RedirectURI(listenPort uint16) string {
return fmt.Sprintf("http://localhost:%d/", listenPort)
}
func GetAuthorizationURL(
params *helix.AuthorizationURLParams,
clientID string,
redirectURI string,
) string {
url := helix.AuthBaseURL + "/authorize"
url += "?response_type=" + params.ResponseType
url += "&client_id=" + clientID
url += "&redirect_uri=" + redirectURI
if params.State != "" {
url += "&state=" + params.State
}
if params.ForceVerify {
url += "&force_verify=true"
}
if len(params.Scopes) != 0 {
url += "&scope=" + strings.Join(params.Scopes, "%20")
}
return url
}