Files
core/http/jwt/jwks/jwks.go
2022-08-18 10:27:33 +03:00

367 lines
8.9 KiB
Go

package jwks
import (
"bytes"
"context"
"crypto"
"encoding/json"
"errors"
"io"
"net/http"
"sync"
"time"
)
var (
	// ErrKIDNotFound indicates that the given key ID was not found in the JWKs.
	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKs")
	// ErrMissingAssets indicates there are required assets missing to create a public key.
	ErrMissingAssets = errors.New("required assets are missing to create a public key")
	// ErrUnknownKeyType indicates that a key type is not implemented.
	ErrUnknownKeyType = errors.New("the key has an unknown type")
)
// ErrorHandler receives errors that cannot be returned to a caller directly, such as failures of refreshes that
// happen in the background.
type ErrorHandler func(error)
// JWK is a single JSON Web Key taken from a key set.
type JWK interface {
	// ID returns the key ID ("kid") of the key.
	ID() string
	// Alg returns the algorithm ("alg") the key is intended for; it may be empty.
	Alg() string
	// Type returns the key type ("kty") of the key.
	Type() string
	// PublicKey builds the public key described by the JWK, or returns an error if it cannot.
	PublicKey() (crypto.PublicKey, error)
}
// JWKS is a JSON Web Key Set that resolves keys by their key ID.
type JWKS interface {
	// Key returns the key whose key ID is kid, or ErrKIDNotFound if the set has no such key.
	Key(kid string) (jwk JWK, err error)
	// Cancel stops any background refreshing of the key set. It is safe to call on a key set that has none.
	Cancel()
}
// jwksImpl is the JWKS implementation returned by NewFromJSON and NewFromURL. It holds the current key set and, for
// instances created from a URL, the settings and state used to refresh it.
type jwksImpl struct {
	// keys maps key ID ("kid") to key. Guarded by mux; replaced wholesale on every update.
	keys map[string]*jwkImpl

	// cancel stops the background refresh goroutine; nil when no goroutine was started.
	cancel context.CancelFunc

	// client performs the HTTP requests made by refresh.
	client *http.Client

	// ctx is the background goroutine's lifetime, cancelled by cancel; nil when no goroutine was started.
	ctx context.Context

	// jwksURL is where refresh fetches the key set from.
	jwksURL string

	// mux guards keys.
	mux sync.RWMutex

	// refreshErrorHandler, when non-nil, receives errors from background refreshes.
	refreshErrorHandler ErrorHandler

	// refreshInterval is the period between automatic refreshes; zero or negative disables them.
	refreshInterval time.Duration

	// refreshRateLimit is the minimum time between refreshes triggered through refreshRequests; zero disables it.
	refreshRateLimit time.Duration

	// refreshRequests carries refresh requests to the background goroutine. Each value is called once its request
	// has been handled, so a waiting caller can continue.
	refreshRequests chan context.CancelFunc

	// refreshTimeout bounds each HTTP refresh request.
	refreshTimeout time.Duration

	// refreshUnknownKID makes Key request a refresh when asked for a key ID that is not in the set.
	refreshUnknownKID bool
}
// jwkImpl is the JWK implementation stored in a jwksImpl. It wraps one key as decoded from the key set's JSON.
type jwkImpl struct {
	// key is the decoded JSON representation of the key.
	key rawJWK
	// precomputed is not used in this file. NOTE(review): presumably it caches the parsed public key for the
	// rsa/ecdsa helpers defined elsewhere — confirm against their implementation.
	precomputed interface{}
}
// ID returns the key ID ("kid") of the key.
func (j *jwkImpl) ID() string { return j.key.Kid }

// Alg returns the "alg" member of the key, which is empty when the JSON omitted it.
func (j *jwkImpl) Alg() string { return j.key.Algorithm }

// Type returns the key type ("kty") of the key.
func (j *jwkImpl) Type() string { return j.key.KeyType }
// PublicKey builds the crypto.PublicKey described by this JWK.
//
// The key type comes from the "kty" member. "RSA" keys are built by rsa and elliptic-curve keys by ecdsa. RFC 7518
// section 6.1 registers "EC" as the elliptic-curve key type, and that is what real key sets publish. The
// non-standard "ecdsa" is still accepted so existing inputs keep working. Any other key type yields
// ErrUnknownKeyType.
func (j *jwkImpl) PublicKey() (crypto.PublicKey, error) {
	switch j.key.KeyType {
	case "RSA":
		return j.rsa()
	case "EC", "ecdsa":
		return j.ecdsa()
	default:
		return nil, ErrUnknownKeyType
	}
}
// rawJWK is a single JSON Web Key as it appears in the "keys" array of a key set document. The struct tags give
// the JSON member names; members missing from the JSON decode to zero values.
type rawJWK struct {
	Algorithm string `json:"alg"` // "alg": algorithm the key is intended for
	KeyType string `json:"kty"` // "kty": key type; PublicKey dispatches on it
	Use string `json:"use"` // "use": intended use; update keeps only keys whose use is "sig"
	Curve string `json:"crv"` // "crv": curve name for elliptic-curve keys
	Exponent string `json:"e"` // "e": RSA public exponent
	Kid string `json:"kid"` // "kid": key ID; update indexes keys by it
	Modulus string `json:"n"` // "n": RSA modulus
	X string `json:"x"` // "x": elliptic-curve x coordinate
	Y string `json:"y"` // "y": elliptic-curve y coordinate
	X5C []string `json:"x5c"` // "x5c": X.509 certificate chain
}
// rawJWKs is the top-level JSON object of a key set document. Its "keys" member holds the individual keys.
type rawJWKs struct {
	Keys []rawJWK `json:"keys"`
}
// NewFromJSON creates a JWKS from a raw JSON Web Key Set document.
//
// The returned JWKS is static: it never refreshes, and calling Cancel on it does nothing. Only keys whose "use"
// member is "sig" are kept (see update). An error is returned if jwksBytes cannot be decoded as a key set.
func NewFromJSON(jwksBytes []byte) (JWKS, error) {
	ks := &jwksImpl{
		keys: map[string]*jwkImpl{},
	}
	// Decode the document and install its keys.
	if err := ks.update(jwksBytes); err != nil {
		return nil, err
	}
	return ks, nil
}
// NewFromURL builds a JWKS by fetching the JSON Web Key Set at jwksURL.
//
// config is passed through applyConfigDefaults before use. The key set is fetched once, synchronously, and any error
// from that first fetch is returned. When config enables one of the following, a background goroutine keeps the keys
// current until Cancel is called:
//   - a refresh interval,
//   - a refresh rate limit,
//   - refreshing on unknown key IDs.
func NewFromURL(jwksURL string, config Config) (JWKS, error) {
	applyConfigDefaults(&config)

	ks := &jwksImpl{
		client:              config.Client,
		jwksURL:             jwksURL,
		keys:                map[string]*jwkImpl{},
		refreshErrorHandler: config.RefreshErrorHandler,
		refreshInterval:     config.RefreshInterval,
		refreshRateLimit:    config.RefreshRateLimit,
		refreshTimeout:      config.RefreshTimeout,
		refreshUnknownKID:   config.RefreshUnknownKID,
	}

	// Load the keys up front so the returned JWKS already holds them.
	if err := ks.refresh(); err != nil {
		return nil, err
	}

	needsBackground := ks.refreshInterval > 0 || ks.refreshRateLimit > 0 || ks.refreshUnknownKID
	if !needsBackground {
		return ks, nil
	}

	// Cancelling this context (via Cancel) ends the background goroutine.
	ks.ctx, ks.cancel = context.WithCancel(context.Background())
	// Capacity 1: at most one refresh request can be pending at a time.
	ks.refreshRequests = make(chan context.CancelFunc, 1)
	go ks.backgroundRefresh()

	return ks, nil
}
// Cancel stops the background refresh goroutine, if NewFromURL started one. On a JWKS without such a goroutine, and
// on every call after the first, it has no effect.
func (j *jwksImpl) Cancel() {
	if j.cancel == nil {
		return
	}
	j.cancel()
}
// Key returns the JWK whose key ID is kid.
//
// If kid is not in the current key set and the JWKS was configured to refresh on unknown key IDs, Key sends a
// refresh request to the background goroutine, waits until that request has been handled, and looks kid up again.
// The background goroutine may release the request without refreshing when rate limited. Key does not wait when a
// refresh request is already pending or the JWKS has been cancelled.
//
// ErrKIDNotFound is returned whenever no matching key is found. A nil JWK is never returned with a nil error.
func (j *jwksImpl) Key(kid string) (JWK, error) {
	lookup := func() (*jwkImpl, bool) {
		j.mux.RLock()
		defer j.mux.RUnlock()
		// Indexing a nil map is safe and reports ok == false.
		key, ok := j.keys[kid]
		return key, ok
	}

	if key, ok := lookup(); ok {
		return key, nil
	}
	if !j.refreshUnknownKID {
		return nil, ErrKIDNotFound
	}

	// The background goroutine calls cancel once this request has been handled. ctx derives from j.ctx, so it is also
	// done when the JWKS is cancelled. The deferred cancel releases the context on every return path; calling a
	// CancelFunc more than once is harmless.
	ctx, cancel := context.WithCancel(j.ctx)
	defer cancel()

	select {
	case <-j.ctx.Done():
		// Return the untyped nil, not a nil *jwkImpl, so callers never see a non-nil JWK holding a nil pointer.
		return nil, ErrKIDNotFound
	case j.refreshRequests <- cancel:
	default:
		// A refresh request is already pending; don't block this caller.
		return nil, ErrKIDNotFound
	}

	// Wait until the background goroutine has handled the request.
	<-ctx.Done()

	if key, ok := lookup(); ok {
		return key, nil
	}
	return nil, ErrKIDNotFound
}
// backgroundRefresh keeps the key set current. NewFromURL runs it as its own goroutine, and it returns when j.ctx is
// cancelled.
//
// Refreshes are driven by two sources:
//   - when refreshInterval is positive, a refresh request is queued on j.refreshRequests after each interval;
//   - requests sent on j.refreshRequests, for example by Key for an unknown key ID.
//
// Each request is a context.CancelFunc, which is called once the request has been dealt with so that a waiting
// caller can continue. When refreshRateLimit is positive and the last refresh was less than refreshRateLimit ago,
// the request is released immediately without refreshing. In that case at most one deferred refresh is scheduled for
// the end of the rate-limit window. Refresh errors are passed to refreshErrorHandler when it is set.
func (j *jwksImpl) backgroundRefresh() {
	// refreshMux guards lastRefresh and refreshQueued, which are shared with the deferred refresh goroutine.
	var (
		refreshMux    sync.Mutex
		lastRefresh   time.Time
		refreshQueued bool
	)
	if j.refreshRateLimit > 0 {
		// Start one full window in the past so the first request is not rate limited.
		lastRefresh = time.Now().Add(-j.refreshRateLimit)
	}

	// intervalC stays nil unless a refresh interval is configured; a nil channel never delivers in a select. One
	// timer is reused for the whole loop. Calling time.After on every iteration would leave a live timer behind for
	// each handled request until it expired.
	var intervalTimer *time.Timer
	var intervalC <-chan time.Time
	if j.refreshInterval > 0 {
		intervalTimer = time.NewTimer(j.refreshInterval)
		defer intervalTimer.Stop()
		intervalC = intervalTimer.C
	}

	for {
		// Restart the interval countdown on every iteration, so it is measured from the last handled event.
		if intervalTimer != nil {
			if !intervalTimer.Stop() {
				// Drop a tick that fired but was not consumed, so it cannot cause an early refresh.
				select {
				case <-intervalTimer.C:
				default:
				}
			}
			intervalTimer.Reset(j.refreshInterval)
		}

		select {
		case <-intervalC:
			// Queue a request for this goroutine to handle on its next iteration.
			select {
			case <-j.ctx.Done():
				return
			case j.refreshRequests <- func() {}:
			default:
				// A request is already pending, so the keys will be refreshed anyway.
			}

		case release := <-j.refreshRequests:
			refreshMux.Lock()
			if j.refreshRateLimit > 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
				// Rate limited: don't make the caller wait for the window to end.
				release()
				// Schedule at most one deferred refresh for the end of the window.
				if !refreshQueued {
					refreshQueued = true
					go func() {
						refreshMux.Lock()
						wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
						refreshMux.Unlock()

						select {
						case <-j.ctx.Done():
							return
						case <-time.After(wait):
						}

						refreshMux.Lock()
						defer refreshMux.Unlock()
						if err := j.refresh(); err != nil && j.refreshErrorHandler != nil {
							j.refreshErrorHandler(err)
						}
						lastRefresh = time.Now()
						// Allow a later rate-limited request to schedule another deferred refresh.
						refreshQueued = false
					}()
				}
			} else {
				if err := j.refresh(); err != nil && j.refreshErrorHandler != nil {
					j.refreshErrorHandler(err)
				}
				lastRefresh = time.Now()
				// Let the waiting caller look the keys up again.
				release()
			}
			refreshMux.Unlock()

		case <-j.ctx.Done():
			return
		}
	}
}
// refresh fetches the key set from j.jwksURL with an HTTP GET and replaces the current keys with the response.
//
// The request is bounded by j.refreshTimeout. When the JWKS has a background context, cancelling that context also
// aborts the request. A response whose status is not 200 OK is rejected before the current keys are touched.
// Otherwise a JSON error body such as {} would decode to zero keys and wipe out the previously loaded key set.
func (j *jwksImpl) refresh() error {
	parent := j.ctx
	if parent == nil {
		parent = context.Background()
	}
	ctx, cancel := context.WithTimeout(parent, j.refreshTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.jwksURL, bytes.NewReader(nil))
	if err != nil {
		return err
	}

	resp, err := j.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close() // The close error carries no useful information here.

	if resp.StatusCode != http.StatusOK {
		return errors.New("unexpected HTTP status fetching JWKs: " + resp.Status)
	}

	jwksBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	return j.update(jwksBytes)
}
// update decodes jwksBytes as a JSON Web Key Set and installs it as the current key set under the write lock.
//
// Only keys whose "use" member is exactly "sig" are kept. Keys with no "use" member, or any other value, are
// dropped. Keys are indexed by "kid"; when several keys share a kid, the last one wins. On a decoding error the
// current key set is left unchanged.
//
// NOTE(review): RFC 7517 makes "use" optional, so keys published without it are ignored here — confirm this is
// intended.
func (j *jwksImpl) update(jwksBytes []byte) error {
	var parsed rawJWKs
	err := json.Unmarshal(jwksBytes, &parsed)
	if err != nil {
		return err
	}

	fresh := make(map[string]*jwkImpl, len(parsed.Keys))
	for i := range parsed.Keys {
		if parsed.Keys[i].Use == "sig" {
			fresh[parsed.Keys[i].Kid] = &jwkImpl{key: parsed.Keys[i]}
		}
	}

	j.mux.Lock()
	j.keys = fresh
	j.mux.Unlock()
	return nil
}