mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-10-16 05:40:41 +08:00
176 lines
4.5 KiB
Go
176 lines
4.5 KiB
Go
package interceptor
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/zhufuyi/sponge/pkg/jwt"
|
|
|
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// ---------------------------------- client ----------------------------------
|
|
|
|
// SetJwtTokenToCtx set the token to the context in rpc client side
|
|
// Example:
|
|
//
|
|
// ctx := SetJwtTokenToCtx(ctx, "Bearer jwt-token")
|
|
// cli.GetByID(ctx, req)
|
|
func SetJwtTokenToCtx(ctx context.Context, authorization string) context.Context {
|
|
md, ok := metadata.FromOutgoingContext(ctx)
|
|
if ok {
|
|
md.Set(headerAuthorize, authorization)
|
|
} else {
|
|
md = metadata.Pairs(headerAuthorize, authorization)
|
|
}
|
|
return metadata.NewOutgoingContext(ctx, md)
|
|
}
|
|
|
|
// ---------------------------------- server interceptor ----------------------------------
|
|
|
|
var (
|
|
headerAuthorize = "authorization"
|
|
|
|
// auth Scheme
|
|
authScheme = "Bearer"
|
|
|
|
// authentication information in ctx key name
|
|
authCtxClaimsName = "tokenInfo"
|
|
|
|
// collection of skip authentication methods
|
|
authIgnoreMethods = map[string]struct{}{}
|
|
)
|
|
|
|
// AuthOption setting the Authentication Field
|
|
type AuthOption func(*AuthOptions)
|
|
|
|
// AuthOptions settings
|
|
type AuthOptions struct {
|
|
authScheme string
|
|
ctxClaimsName string
|
|
ignoreMethods map[string]struct{}
|
|
}
|
|
|
|
func defaultAuthOptions() *AuthOptions {
|
|
return &AuthOptions{
|
|
authScheme: authScheme,
|
|
ctxClaimsName: authCtxClaimsName,
|
|
ignoreMethods: make(map[string]struct{}), // ways to ignore forensics
|
|
}
|
|
}
|
|
|
|
func (o *AuthOptions) apply(opts ...AuthOption) {
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
}
|
|
|
|
// WithAuthScheme set the message prefix for authentication
|
|
func WithAuthScheme(scheme string) AuthOption {
|
|
return func(o *AuthOptions) {
|
|
o.authScheme = scheme
|
|
}
|
|
}
|
|
|
|
// WithAuthClaimsName set the key name of the information in ctx for authentication
|
|
func WithAuthClaimsName(claimsName string) AuthOption {
|
|
return func(o *AuthOptions) {
|
|
o.ctxClaimsName = claimsName
|
|
}
|
|
}
|
|
|
|
// WithAuthIgnoreMethods ways to ignore forensics
|
|
// fullMethodName format: /packageName.serviceName/methodName,
|
|
// example /api.userExample.v1.userExampleService/GetByID
|
|
func WithAuthIgnoreMethods(fullMethodNames ...string) AuthOption {
|
|
return func(o *AuthOptions) {
|
|
for _, method := range fullMethodNames {
|
|
o.ignoreMethods[method] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetAuthorization combining tokens into authentication information
|
|
func GetAuthorization(token string) string {
|
|
return authScheme + " " + token
|
|
}
|
|
|
|
// GetAuthCtxKey get the name of Claims
|
|
func GetAuthCtxKey() string {
|
|
return authCtxClaimsName
|
|
}
|
|
|
|
// JwtVerify get authorization from context to verify legitimacy, authorization composition format: authScheme token
|
|
func JwtVerify(ctx context.Context) (context.Context, error) {
|
|
token, err := grpc_auth.AuthFromMD(ctx, authScheme)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cc, err := jwt.VerifyToken(token)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Unauthenticated, "%v", err)
|
|
}
|
|
|
|
newCtx := context.WithValue(ctx, authCtxClaimsName, cc) //nolint
|
|
// get value by ctx.Value(interceptor.GetAuthCtxKey()).(*jwt.CustomClaims)
|
|
|
|
return newCtx, nil
|
|
}
|
|
|
|
// UnaryServerJwtAuth jwt unary interceptor
|
|
func UnaryServerJwtAuth(opts ...AuthOption) grpc.UnaryServerInterceptor {
|
|
o := defaultAuthOptions()
|
|
o.apply(opts...)
|
|
authScheme = o.authScheme
|
|
authCtxClaimsName = o.ctxClaimsName
|
|
authIgnoreMethods = o.ignoreMethods
|
|
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
|
var newCtx context.Context
|
|
var err error
|
|
|
|
if _, ok := authIgnoreMethods[info.FullMethod]; ok {
|
|
newCtx = ctx
|
|
} else {
|
|
newCtx, err = JwtVerify(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return handler(newCtx, req)
|
|
}
|
|
}
|
|
|
|
// StreamServerJwtAuth jwt stream interceptor
|
|
func StreamServerJwtAuth(opts ...AuthOption) grpc.StreamServerInterceptor {
|
|
o := defaultAuthOptions()
|
|
o.apply(opts...)
|
|
authScheme = o.authScheme
|
|
authCtxClaimsName = o.ctxClaimsName
|
|
authIgnoreMethods = o.ignoreMethods
|
|
|
|
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
var newCtx context.Context
|
|
var err error
|
|
|
|
if _, ok := authIgnoreMethods[info.FullMethod]; ok {
|
|
newCtx = stream.Context()
|
|
} else {
|
|
newCtx, err = JwtVerify(stream.Context())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
wrapped := grpc_middleware.WrapServerStream(stream)
|
|
wrapped.WrappedContext = newCtx
|
|
return handler(srv, wrapped)
|
|
}
|
|
}
|