288 lines
7.0 KiB
Go
288 lines
7.0 KiB
Go
package utils
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// TokenPayload is the application data carried by every token.
//
// It is embedded in Claims, so its fields are flattened into the JWT claim
// set. It is also stored as JSON in Redis under the "jwt:token:<token>" key
// written by New.
//
// The fields have no json tags, so they are encoded under their Go names
// ("UserId", "IsAdmin"). Renaming or retagging them would make already-issued
// tokens and cache entries unreadable.
type TokenPayload struct {
	// UserId identifies the token owner and is used to build the
	// "jwt:user:<id>" key.
	UserId int64

	// IsAdmin flags an administrator. NOTE(review): nothing in this file reads
	// it; authorization is presumably enforced by callers.
	IsAdmin bool
}
|
|
|
|
// Claims is the full JWT claim set signed by New.
//
// It holds the application payload plus the standard registered claims; New
// sets "exp", "iat" and "iss". Both parts are embedded, so encoding/json
// flattens their fields into a single JSON object. The promoted Valid method
// of jwt.RegisteredClaims makes *Claims usable as a jwt.Claims.
type Claims struct {
	TokenPayload
	jwt.RegisteredClaims
}
|
|
|
|
// Redis key prefixes; the suffix is either a decimal user id or a raw token string.
const (
	// tokenCachePrefixToken + <token string> -> JSON-encoded TokenPayload.
	tokenCachePrefixToken = "jwt:token:"
	// tokenCachePrefixUser + <user id> -> the user's current token string.
	tokenCachePrefixUser = "jwt:user:"
)

// Token and cache lifetimes.
const (
	// tokenLifetime is the "exp" horizon of a freshly signed JWT (one week).
	tokenLifetime = time.Hour * 24 * 7
	// tokenCacheTTL is how long both Redis keys live (thirty days). It
	// deliberately outlives the JWT so an expired token can still be renewed
	// while its cache entry exists.
	tokenCacheTTL = time.Hour * 24 * 30
)
|
|
|
|
// Sentinel errors returned by JwtManager methods. They are unexported, so only
// code inside this package can match them with errors.Is.
var (
	// Misconfiguration: the manager was built without a Redis client.
	errNoRedisClient = errors.New("redis client not configured")

	// Bad caller input.
	errMissingToken  = errors.New("token missing in request")
	errInvalidUserID = errors.New("invalid user id")

	// Token state problems.
	errInvalidToken    = errors.New("invalid token claims")
	errTokenNotInCache = errors.New("token not found in cache")
)
|
|
|
|
// JwtManager issues, validates, renews and revokes HS256-signed JWTs.
//
// A Redis Cluster records which tokens are live: a token is only accepted
// while its "jwt:token:<token>" key exists.
type JwtManager struct {
	// redisCluster stores the jwt:user:* and jwt:token:* keys. Methods that
	// touch Redis return errNoRedisClient when it is nil.
	redisCluster *redis.ClusterClient

	// secretKey is the HMAC key used to sign and verify tokens.
	secretKey string

	// issuer is written into the "iss" claim of new tokens. This file does
	// not check it when validating.
	issuer string
}
|
|
|
|
func NewJwtManager(redisCluster *redis.ClusterClient, secretKey, issuer string) *JwtManager {
|
|
return &JwtManager{
|
|
redisCluster: redisCluster,
|
|
secretKey: secretKey,
|
|
issuer: issuer,
|
|
}
|
|
}
|
|
|
|
// New 生成新的 JWT token
|
|
func (m *JwtManager) New(ctx context.Context, payload *TokenPayload) (string, error) {
|
|
if m.redisCluster == nil {
|
|
return "", errNoRedisClient
|
|
}
|
|
|
|
now := time.Now()
|
|
expiresAt := now.Add(tokenLifetime)
|
|
|
|
claims := &Claims{
|
|
TokenPayload: *payload,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
Issuer: m.issuer,
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString([]byte(m.secretKey))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
userKey := tokenCachePrefixUser + strconv.FormatInt(claims.UserId, 10)
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
|
|
tokenData, _ := json.Marshal(payload)
|
|
|
|
// 同时存储两个 key:用户 -> token 和 token -> payload
|
|
pipe := m.redisCluster.Pipeline()
|
|
pipe.Set(ctx, userKey, tokenString, tokenCacheTTL)
|
|
pipe.Set(ctx, tokenKey, string(tokenData), tokenCacheTTL)
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return tokenString, nil
|
|
}
|
|
|
|
// Valid 验证 token 有效性,支持自动换票
|
|
func (m *JwtManager) Valid(ctx context.Context, tokenString string) (*TokenPayload, error) {
|
|
if m.redisCluster == nil {
|
|
return nil, errNoRedisClient
|
|
}
|
|
|
|
if tokenString == "" {
|
|
return nil, errMissingToken
|
|
}
|
|
|
|
// 检查 token 是否在 Redis 中
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
tokenData, err := m.redisCluster.Get(ctx, tokenKey).Result()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
var payload TokenPayload
|
|
if errors.Is(err, redis.Nil) {
|
|
return nil, errTokenNotInCache
|
|
}
|
|
|
|
err = json.Unmarshal([]byte(tokenData), &payload)
|
|
if err != nil {
|
|
return nil, errInvalidToken
|
|
}
|
|
|
|
// 解析 JWT 并验证签名和过期时间
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(m.secretKey), nil
|
|
})
|
|
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
if _, renewErr := m.Renew(ctx, tokenString); renewErr != nil {
|
|
return nil, renewErr
|
|
}
|
|
|
|
if token != nil {
|
|
if claims, ok := token.Claims.(*Claims); ok {
|
|
return &claims.TokenPayload, nil
|
|
}
|
|
}
|
|
|
|
return &payload, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok || !token.Valid {
|
|
return nil, errInvalidToken
|
|
}
|
|
|
|
return &claims.TokenPayload, nil
|
|
}
|
|
|
|
// Renew 换票逻辑:如果 token 过期但 Redis 中还存在,则生成新 token
|
|
func (m *JwtManager) Renew(ctx context.Context, tokenString string) (string, error) {
|
|
if m.redisCluster == nil {
|
|
return "", errNoRedisClient
|
|
}
|
|
|
|
// 检查 token 是否在 Redis 中(不检查过期时间)
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
tokenData, err := m.redisCluster.Get(ctx, tokenKey).Result()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) {
|
|
return "", errTokenNotInCache
|
|
}
|
|
return "", err
|
|
}
|
|
|
|
var payload TokenPayload
|
|
err = json.Unmarshal([]byte(tokenData), &payload)
|
|
if err != nil {
|
|
return "", errInvalidToken
|
|
}
|
|
|
|
// 删除旧 token 记录
|
|
userKey := tokenCachePrefixUser + strconv.FormatInt(payload.UserId, 10)
|
|
m.redisCluster.Del(ctx, tokenKey, userKey)
|
|
|
|
// 生成新 token
|
|
return m.New(ctx, &payload)
|
|
}
|
|
|
|
// Extract payload from token without validating expiration (used for auto-renewal)
|
|
func (m *JwtManager) Extract(_ context.Context, tokenString string) (*TokenPayload, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(m.secretKey), nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok {
|
|
return nil, errInvalidToken
|
|
}
|
|
|
|
return &claims.TokenPayload, nil
|
|
}
|
|
|
|
// Exists check if token exists in Redis (i.e. is valid and not revoked)
|
|
func (m *JwtManager) Exists(ctx context.Context, tokenString string) (bool, error) {
|
|
if m.redisCluster == nil {
|
|
return false, errNoRedisClient
|
|
}
|
|
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
exists, err := m.redisCluster.Exists(ctx, tokenKey).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return exists > 0, nil
|
|
}
|
|
|
|
// ClaimsToPayload extract payload from JWT claims
|
|
func (m *JwtManager) ClaimsToPayload(claims *Claims) *TokenPayload {
|
|
return &claims.TokenPayload
|
|
}
|
|
|
|
// Revoke revoke token by deleting both user -> token and token -> payload keys from Redis
|
|
func (m *JwtManager) Revoke(ctx context.Context, tokenString string) error {
|
|
if m.redisCluster == nil {
|
|
return errNoRedisClient
|
|
}
|
|
|
|
payload, err := m.Extract(ctx, tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userKey := tokenCachePrefixUser + strconv.FormatInt(payload.UserId, 10)
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
|
|
pipe := m.redisCluster.Pipeline()
|
|
pipe.Del(ctx, userKey)
|
|
pipe.Del(ctx, tokenKey)
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (m *JwtManager) GetUserToken(ctx context.Context, userID int64) (string, error) {
|
|
if m.redisCluster == nil {
|
|
return "", errNoRedisClient
|
|
}
|
|
//userID, err := strconv.FormatInt(userID, 10)
|
|
id := strconv.FormatInt(userID, 10)
|
|
|
|
userKey := tokenCachePrefixUser + id
|
|
token, err := m.redisCluster.Get(ctx, userKey).Result()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) {
|
|
return "", fmt.Errorf("user %v has no token", userID)
|
|
}
|
|
return "", err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// Logout 按用户登出:删除 user->token 和 token->payload 两类缓存数据
|
|
func (m *JwtManager) Logout(ctx context.Context, userID int64) error {
|
|
if m.redisCluster == nil {
|
|
return errNoRedisClient
|
|
}
|
|
|
|
if userID <= 0 {
|
|
return errInvalidUserID
|
|
}
|
|
|
|
userKey := tokenCachePrefixUser + strconv.FormatInt(userID, 10)
|
|
tokenString, err := m.redisCluster.Get(ctx, userKey).Result()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
return err
|
|
}
|
|
|
|
pipe := m.redisCluster.Pipeline()
|
|
pipe.Del(ctx, userKey)
|
|
if !errors.Is(err, redis.Nil) && tokenString != "" {
|
|
tokenKey := tokenCachePrefixToken + tokenString
|
|
pipe.Del(ctx, tokenKey)
|
|
}
|
|
|
|
_, execErr := pipe.Exec(ctx)
|
|
return execErr
|
|
}
|