Files
sponge/pkg/grpc/interceptor/jwtAuth.go
2023-08-06 17:41:28 +08:00

176 lines
4.5 KiB
Go

package interceptor
import (
"context"
"github.com/zhufuyi/sponge/pkg/jwt"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ---------------------------------- client ----------------------------------
// SetJwtTokenToCtx stores authorization in the outgoing gRPC metadata of ctx
// under the headerAuthorize key ("authorization"), replacing any value already
// set for that key, and returns the derived context. Use it on the rpc client
// side before issuing a call.
//
// Example:
//
//	ctx := SetJwtTokenToCtx(ctx, "Bearer jwt-token")
//	cli.GetByID(ctx, req)
func SetJwtTokenToCtx(ctx context.Context, authorization string) context.Context {
	md, found := metadata.FromOutgoingContext(ctx)
	if !found {
		// no outgoing metadata yet: start a fresh set holding only the credentials
		return metadata.NewOutgoingContext(ctx, metadata.Pairs(headerAuthorize, authorization))
	}
	md.Set(headerAuthorize, authorization)
	return metadata.NewOutgoingContext(ctx, md)
}
// ---------------------------------- server interceptor ----------------------------------
var (
	// headerAuthorize is the gRPC metadata key that carries the credentials.
	headerAuthorize = "authorization"

	// authScheme is the scheme prefix expected before the token in the
	// authorization value. UnaryServerJwtAuth and StreamServerJwtAuth
	// overwrite it with their configured scheme.
	authScheme = "Bearer"

	// authCtxClaimsName is the context key under which JwtVerify stores the
	// verified claims. The interceptor constructors overwrite it as well.
	authCtxClaimsName = "tokenInfo"

	// authIgnoreMethods is the set of full method names that skip
	// authentication; the interceptor constructors replace it.
	authIgnoreMethods = make(map[string]struct{})
)
// AuthOption is a functional option that mutates an AuthOptions value; it is
// the option type accepted by UnaryServerJwtAuth and StreamServerJwtAuth.
type AuthOption func(*AuthOptions)
// AuthOptions holds the settings consumed by the JWT server interceptors.
// Build one with defaultAuthOptions and customize it through AuthOption values.
type AuthOptions struct {
	// authScheme is the scheme prefix expected before the token, e.g. "Bearer".
	authScheme string
	// ctxClaimsName is the context key under which verified claims are stored.
	ctxClaimsName string
	// ignoreMethods is the set of full method names that skip authentication.
	ignoreMethods map[string]struct{}
}
// defaultAuthOptions returns a new AuthOptions seeded from the current
// package-level scheme and claims name. Earlier UnaryServerJwtAuth or
// StreamServerJwtAuth calls may already have overwritten those values.
// It always starts with an empty, non-nil set of methods that skip
// authentication.
func defaultAuthOptions() *AuthOptions {
	o := new(AuthOptions)
	o.authScheme = authScheme
	o.ctxClaimsName = authCtxClaimsName
	o.ignoreMethods = map[string]struct{}{} // methods exempt from authentication
	return o
}
// apply invokes every option on o in the order given, so a later option wins
// when two options set the same scalar field.
func (o *AuthOptions) apply(opts ...AuthOption) {
	for i := range opts {
		opts[i](o)
	}
}
// WithAuthScheme sets the scheme prefix (such as "Bearer") that the
// interceptors pass to grpc_auth.AuthFromMD when extracting the token from
// the authorization metadata.
func WithAuthScheme(scheme string) AuthOption {
	setScheme := func(o *AuthOptions) { o.authScheme = scheme }
	return setScheme
}
// WithAuthClaimsName sets the context key under which the verified JWT claims
// are stored for downstream handlers to read.
func WithAuthClaimsName(claimsName string) AuthOption {
	setClaimsName := func(o *AuthOptions) { o.ctxClaimsName = claimsName }
	return setClaimsName
}
// WithAuthIgnoreMethods lists methods that bypass JWT authentication.
// Each fullMethodName has the form /packageName.serviceName/methodName,
// for example /api.userExample.v1.userExampleService/GetByID.
// Applying the option more than once adds to the set rather than replacing it.
func WithAuthIgnoreMethods(fullMethodNames ...string) AuthOption {
	return func(o *AuthOptions) {
		for i := 0; i < len(fullMethodNames); i++ {
			o.ignoreMethods[fullMethodNames[i]] = struct{}{}
		}
	}
}
// GetAuthorization builds an authorization value of the form
// "<scheme> <token>", using the current package-level scheme. The result is
// suitable for SetJwtTokenToCtx.
func GetAuthorization(token string) string {
	prefix := authScheme + " "
	return prefix + token
}
// GetAuthCtxKey returns the context key under which JwtVerify stores the
// verified claims, i.e. the current package-level claims name.
func GetAuthCtxKey() string {
	key := authCtxClaimsName
	return key
}
// JwtVerify extracts the token from the incoming authorization metadata,
// whose value must be "<authScheme> <token>". It verifies the token with
// jwt.VerifyToken and returns a child context that carries the resulting
// claims under the package-level claims name (see GetAuthCtxKey). Read them
// in a handler with
//
//	ctx.Value(interceptor.GetAuthCtxKey()).(*jwt.CustomClaims)
//
// Errors from grpc_auth.AuthFromMD are returned unchanged. Token verification
// failures are reported with codes.Unauthenticated. On any error the returned
// context is nil.
func JwtVerify(ctx context.Context) (context.Context, error) {
	token, err := grpc_auth.AuthFromMD(ctx, authScheme)
	if err != nil {
		return nil, err
	}

	claims, verifyErr := jwt.VerifyToken(token)
	if verifyErr != nil {
		return nil, status.Errorf(codes.Unauthenticated, "%v", verifyErr)
	}

	// A plain string key is kept deliberately so callers can look the claims
	// up with GetAuthCtxKey().
	return context.WithValue(ctx, authCtxClaimsName, claims), nil //nolint
}
// UnaryServerJwtAuth jwt unary interceptor
func UnaryServerJwtAuth(opts ...AuthOption) grpc.UnaryServerInterceptor {
o := defaultAuthOptions()
o.apply(opts...)
authScheme = o.authScheme
authCtxClaimsName = o.ctxClaimsName
authIgnoreMethods = o.ignoreMethods
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var newCtx context.Context
var err error
if _, ok := authIgnoreMethods[info.FullMethod]; ok {
newCtx = ctx
} else {
newCtx, err = JwtVerify(ctx)
if err != nil {
return nil, err
}
}
return handler(newCtx, req)
}
}
// StreamServerJwtAuth returns a stream server interceptor that authenticates
// every stream with JwtVerify, except streams whose full method name was
// listed through WithAuthIgnoreMethods. The handler always receives a wrapped
// stream whose Context() is either the original context (ignored methods) or
// the context returned by JwtVerify, which carries the verified claims.
//
// The configured scheme and claims name are still published to the
// package-level settings read by JwtVerify, GetAuthorization and
// GetAuthCtxKey. NOTE: those two settings remain process-wide, so the most
// recent interceptor constructor call wins.
//
// The ignore list, however, is captured per interceptor. Building another JWT
// interceptor later (e.g. UnaryServerJwtAuth with different options)
// therefore cannot silently replace the methods this interceptor skips.
func StreamServerJwtAuth(opts ...AuthOption) grpc.StreamServerInterceptor {
	o := defaultAuthOptions()
	o.apply(opts...)

	// Kept for backward compatibility: JwtVerify and the Get* helpers read these.
	authScheme = o.authScheme
	authCtxClaimsName = o.ctxClaimsName
	authIgnoreMethods = o.ignoreMethods

	// Use this interceptor's own set rather than re-reading the mutable
	// package-level map on every call, which other constructors reassign.
	ignoreMethods := o.ignoreMethods

	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ctx := stream.Context()
		if _, skip := ignoreMethods[info.FullMethod]; !skip {
			var err error
			ctx, err = JwtVerify(ctx)
			if err != nil {
				return err
			}
		}

		// Replace the stream's context so handlers observe the authenticated ctx.
		wrapped := grpc_middleware.WrapServerStream(stream)
		wrapped.WrappedContext = ctx
		return handler(srv, wrapped)
	}
}