mirror of
https://github.com/xaionaro-go/streamctl.git
synced 2025-11-02 11:44:07 +08:00
Multiple updates
This commit is contained in:
181
pkg/streamcontrol/twitch/auth/client_code.go
Normal file
181
pkg/streamcontrol/twitch/auth/client_code.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/facebookincubator/go-belt/tool/experimental/errmon"
|
||||
"github.com/facebookincubator/go-belt/tool/logger"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nicklaw5/helix/v2"
|
||||
"github.com/xaionaro-go/observability"
|
||||
"github.com/xaionaro-go/streamctl/pkg/oauthhandler"
|
||||
)
|
||||
|
||||
type OAuthHandler func(context.Context, oauthhandler.OAuthHandlerArgument) error
|
||||
|
||||
func NewClientCode(
|
||||
ctx context.Context,
|
||||
clientID string,
|
||||
oauthHandler OAuthHandler,
|
||||
getOAuthListenPortsFn func() []uint16,
|
||||
onNewClientCode func(string),
|
||||
) (_err error) {
|
||||
logger.Debugf(ctx, "getNewClientCode")
|
||||
defer func() { logger.Debugf(ctx, "/getNewClientCode: %v", _err) }()
|
||||
|
||||
if oauthHandler == nil {
|
||||
oauthHandler = oauthhandler.OAuth2HandlerViaCLI
|
||||
}
|
||||
|
||||
ctx, ctxCancelFunc := context.WithCancel(ctx)
|
||||
cancelFunc := func() {
|
||||
logger.Debugf(ctx, "cancelling the context")
|
||||
ctxCancelFunc()
|
||||
}
|
||||
|
||||
var errWg sync.WaitGroup
|
||||
var resultErr error
|
||||
errCh := make(chan error)
|
||||
errWg.Add(1)
|
||||
observability.Go(ctx, func(ctx context.Context) {
|
||||
errWg.Done()
|
||||
for err := range errCh {
|
||||
errmon.ObserveErrorCtx(ctx, err)
|
||||
resultErr = multierror.Append(resultErr, err)
|
||||
}
|
||||
})
|
||||
|
||||
alreadyListening := map[uint16]struct{}{}
|
||||
var wg sync.WaitGroup
|
||||
success := false
|
||||
|
||||
startHandlerForPort := func(listenPort uint16) {
|
||||
if _, ok := alreadyListening[listenPort]; ok {
|
||||
return
|
||||
}
|
||||
alreadyListening[listenPort] = struct{}{}
|
||||
|
||||
logger.Debugf(ctx, "starting the oauth handler at port %d", listenPort)
|
||||
wg.Add(1)
|
||||
{
|
||||
listenPort := listenPort
|
||||
observability.Go(ctx, func(ctx context.Context) {
|
||||
defer func() { logger.Debugf(ctx, "ended the oauth handler at port %d", listenPort) }()
|
||||
defer wg.Done()
|
||||
authURL := GetAuthorizationURL(
|
||||
&helix.AuthorizationURLParams{
|
||||
ResponseType: "code", // or "token"
|
||||
Scopes: []string{
|
||||
"user:read:chat",
|
||||
"chat:read",
|
||||
"chat:edit",
|
||||
"channel:manage:broadcast",
|
||||
"moderator:manage:chat_messages",
|
||||
"moderator:manage:banned_users",
|
||||
},
|
||||
},
|
||||
clientID,
|
||||
RedirectURI(listenPort),
|
||||
)
|
||||
|
||||
arg := oauthhandler.OAuthHandlerArgument{
|
||||
AuthURL: authURL,
|
||||
ListenPort: listenPort,
|
||||
ExchangeFn: func(ctx context.Context, code string) (_err error) {
|
||||
logger.Debugf(ctx, "ExchangeFn()")
|
||||
defer func() { logger.Debugf(ctx, "/ExchangeFn(): %v", _err) }()
|
||||
if code == "" {
|
||||
return fmt.Errorf("code is empty")
|
||||
}
|
||||
onNewClientCode(code)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err := oauthHandler(ctx, arg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("unable to get or exchange the oauth code to a token: %w", err)
|
||||
return
|
||||
}
|
||||
cancelFunc()
|
||||
success = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either support only one port as in New, or support multiple
|
||||
// ports as we do below
|
||||
getPortsFn := getOAuthListenPortsFn
|
||||
if getPortsFn == nil {
|
||||
return fmt.Errorf("the function GetOAuthListenPorts is not set")
|
||||
}
|
||||
|
||||
for _, listenPort := range getPortsFn() {
|
||||
startHandlerForPort(listenPort)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
observability.Go(ctx, func(ctx context.Context) {
|
||||
defer wg.Done()
|
||||
t := time.NewTicker(time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
}
|
||||
ports := getPortsFn()
|
||||
logger.Tracef(ctx, "oauth listener ports: %#+v", ports)
|
||||
|
||||
for _, listenPort := range ports {
|
||||
startHandlerForPort(listenPort)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
observability.Go(ctx, func(ctx context.Context) {
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
})
|
||||
<-ctx.Done()
|
||||
logger.Debugf(ctx, "did successfully took a new client code? -- %v", success)
|
||||
if !success {
|
||||
errWg.Wait()
|
||||
return resultErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedirectURI(listenPort uint16) string {
|
||||
return fmt.Sprintf("http://localhost:%d/", listenPort)
|
||||
}
|
||||
|
||||
func GetAuthorizationURL(
|
||||
params *helix.AuthorizationURLParams,
|
||||
clientID string,
|
||||
redirectURI string,
|
||||
) string {
|
||||
url := helix.AuthBaseURL + "/authorize"
|
||||
url += "?response_type=" + params.ResponseType
|
||||
url += "&client_id=" + clientID
|
||||
url += "&redirect_uri=" + redirectURI
|
||||
|
||||
if params.State != "" {
|
||||
url += "&state=" + params.State
|
||||
}
|
||||
|
||||
if params.ForceVerify {
|
||||
url += "&force_verify=true"
|
||||
}
|
||||
|
||||
if len(params.Scopes) != 0 {
|
||||
url += "&scope=" + strings.Join(params.Scopes, "%20")
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
Reference in New Issue
Block a user