package utils

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
)

// TokenPayload is the application data carried inside every token and cached
// in Redis next to it.
type TokenPayload struct {
	// UserId keeps its historical spelling: the field is exported and, having
	// no json tag, its name is also the JSON key of the cached payload.
	UserId  int64
	IsAdmin bool
}

// Claims is the complete JWT claim set: the application payload plus the
// registered claims (exp, iat, iss).
type Claims struct {
	TokenPayload
	jwt.RegisteredClaims
}

const (
	// Redis key prefixes: user ID -> current token string, and
	// token string -> JSON-encoded TokenPayload.
	tokenCachePrefixUser  = "jwt:user:"
	tokenCachePrefixToken = "jwt:token:"

	// tokenCacheTTL is how long both cache entries live. An expired JWT can be
	// renewed only while its entry is still cached, so this is the renewal
	// window and should exceed tokenLifetime.
	tokenCacheTTL = 30 * 24 * time.Hour
	// tokenLifetime is the validity period written into the JWT "exp" claim.
	tokenLifetime = 7 * 24 * time.Hour
)

var (
	errMissingToken    = errors.New("token missing in request")
	errInvalidToken    = errors.New("invalid token claims")
	errTokenNotInCache = errors.New("token not found in cache")
	errNoRedisClient   = errors.New("redis client not configured")
	errInvalidUserID   = errors.New("invalid user id")
)

// jwtDeleteIfEqual deletes KEYS[1] only while its value equals ARGV[1]. It
// touches a single key, so it runs atomically on one Redis Cluster node.
var jwtDeleteIfEqual = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0`)

// jwtUserKey returns the Redis key holding the current token of userID.
func jwtUserKey(userID int64) string {
	return tokenCachePrefixUser + strconv.FormatInt(userID, 10)
}

// jwtTokenKey returns the Redis key holding the cached payload of tokenString.
func jwtTokenKey(tokenString string) string {
	return tokenCachePrefixToken + tokenString
}

// JwtManager issues, validates, renews and revokes HS256 JWTs, using a Redis
// Cluster as the source of truth for which tokens are still live.
type JwtManager struct {
	redisCluster *redis.ClusterClient
	secretKey    string
	issuer       string
}

// NewJwtManager builds a JwtManager. A nil redisCluster is accepted here, but
// every Redis-backed method then fails with errNoRedisClient.
func NewJwtManager(redisCluster *redis.ClusterClient, secretKey, issuer string) *JwtManager {
	return &JwtManager{
		redisCluster: redisCluster,
		secretKey:    secretKey,
		issuer:       issuer,
	}
}

// keyFunc supplies the HMAC secret to the JWT parser. It rejects tokens whose
// header names any algorithm other than HS256 (the one New signs with), so a
// crafted header cannot steer verification to a different algorithm.
func (m *JwtManager) keyFunc(token *jwt.Token) (interface{}, error) {
	if token.Method == nil || token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
		return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
	}
	return []byte(m.secretKey), nil
}

// New signs a fresh HS256 token for payload (valid for tokenLifetime) and
// caches it in Redis under both lookup keys, user -> token and
// token -> payload, each with tokenCacheTTL.
//
// NOTE(review): a token previously issued to the same user is not removed; it
// stays valid until its cache entry expires, and Logout can no longer reach it
// because the user key now points at the new token. Confirm that multiple
// concurrent sessions per user are intended.
func (m *JwtManager) New(ctx context.Context, payload *TokenPayload) (string, error) {
	if m.redisCluster == nil {
		return "", errNoRedisClient
	}
	if payload == nil {
		return "", errInvalidToken
	}

	now := time.Now()
	claims := &Claims{
		TokenPayload: *payload,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenLifetime)),
			IssuedAt:  jwt.NewNumericDate(now),
			Issuer:    m.issuer,
		},
	}
	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(m.secretKey))
	if err != nil {
		return "", err
	}
	tokenData, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}

	// The two keys generally live in different cluster hash slots, so they are
	// written as separate commands in one (non-transactional) pipeline.
	pipe := m.redisCluster.Pipeline()
	pipe.Set(ctx, jwtUserKey(payload.UserId), tokenString, tokenCacheTTL)
	pipe.Set(ctx, jwtTokenKey(tokenString), string(tokenData), tokenCacheTTL)
	if _, err := pipe.Exec(ctx); err != nil {
		return "", err
	}
	return tokenString, nil
}

// Valid checks tokenString and returns its payload. A token is accepted only
// if it is still cached in Redis (issued by New and not yet revoked, renewed
// or logged out) and its HS256 signature verifies.
//
// If the signature is valid but the JWT has expired, the token is renewed via
// Renew: its cache entry is deleted and a replacement is issued and recorded
// under the user's key. The expired token's payload is returned.
//
// NOTE(review): the renewed token string is not returned from here; callers
// presumably fetch it with GetUserToken. Confirm this. A client that keeps
// sending the old token will then get errTokenNotInCache.
func (m *JwtManager) Valid(ctx context.Context, tokenString string) (*TokenPayload, error) {
	if m.redisCluster == nil {
		return nil, errNoRedisClient
	}
	if tokenString == "" {
		return nil, errMissingToken
	}

	tokenData, err := m.redisCluster.Get(ctx, jwtTokenKey(tokenString)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, errTokenNotInCache
	case err != nil:
		return nil, err
	}
	var cached TokenPayload
	if err := json.Unmarshal([]byte(tokenData), &cached); err != nil {
		return nil, errInvalidToken
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, m.keyFunc)
	if err != nil {
		// jwt v4 reports every failed check at once. Renew only when expiry is
		// the problem and the signature itself verified.
		if !errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenSignatureInvalid) {
			return nil, err
		}
		if _, err := m.Renew(ctx, tokenString); err != nil {
			return nil, err
		}
		if token != nil {
			if claims, ok := token.Claims.(*Claims); ok {
				return &claims.TokenPayload, nil
			}
		}
		return &cached, nil
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errInvalidToken
	}
	return &claims.TokenPayload, nil
}

// Renew exchanges a token that is still cached in Redis for a freshly issued
// one, whatever the old JWT's expiry. It deletes the old token's cache entry,
// then issues a replacement through New (which also repoints the user key)
// and returns the new token string.
//
// The signature is not checked here. Presence in Redis is taken as proof that
// New issued the token.
func (m *JwtManager) Renew(ctx context.Context, tokenString string) (string, error) {
	if m.redisCluster == nil {
		return "", errNoRedisClient
	}

	tokenKey := jwtTokenKey(tokenString)
	tokenData, err := m.redisCluster.Get(ctx, tokenKey).Result()
	if errors.Is(err, redis.Nil) {
		return "", errTokenNotInCache
	}
	if err != nil {
		return "", err
	}
	var payload TokenPayload
	if err := json.Unmarshal([]byte(tokenData), &payload); err != nil {
		return "", errInvalidToken
	}

	// Single-key DEL on purpose. A multi-key DEL of the token and user keys is
	// rejected by Redis Cluster (CROSSSLOT) whenever they hash to different
	// slots, which left the old token renewable. New overwrites the user key.
	if err := m.redisCluster.Del(ctx, tokenKey).Err(); err != nil {
		return "", err
	}
	return m.New(ctx, &payload)
}

// Extract verifies tokenString's signature and returns its payload WITHOUT
// validating time-based claims, so expired tokens are accepted (used for
// auto-renewal and revocation). Redis is not consulted, so a revoked token
// still extracts successfully.
func (m *JwtManager) Extract(_ context.Context, tokenString string) (*TokenPayload, error) {
	parser := jwt.NewParser(jwt.WithoutClaimsValidation())
	token, err := parser.ParseWithClaims(tokenString, &Claims{}, m.keyFunc)
	if err != nil {
		return nil, err
	}
	claims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, errInvalidToken
	}
	return &claims.TokenPayload, nil
}

// Exists reports whether tokenString is still cached in Redis, i.e. it has not
// been revoked, renewed, logged out or dropped by the cache TTL.
func (m *JwtManager) Exists(ctx context.Context, tokenString string) (bool, error) {
	if m.redisCluster == nil {
		return false, errNoRedisClient
	}
	exists, err := m.redisCluster.Exists(ctx, jwtTokenKey(tokenString)).Result()
	if err != nil {
		return false, err
	}
	return exists > 0, nil
}

// ClaimsToPayload returns the application payload embedded in claims, or nil
// when claims is nil.
func (m *JwtManager) ClaimsToPayload(claims *Claims) *TokenPayload {
	if claims == nil {
		return nil
	}
	return &claims.TokenPayload
}

// Revoke invalidates tokenString by deleting its token -> payload entry. The
// user -> token entry is deleted too, but only while it still points at this
// token, so revoking an older token does not orphan the user's current one.
//
// The signature must verify; expiry is not checked (see Extract), so expired
// tokens that are still cached can be revoked.
func (m *JwtManager) Revoke(ctx context.Context, tokenString string) error {
	if m.redisCluster == nil {
		return errNoRedisClient
	}
	payload, err := m.Extract(ctx, tokenString)
	if err != nil {
		return err
	}

	if err := m.redisCluster.Del(ctx, jwtTokenKey(tokenString)).Err(); err != nil {
		return err
	}
	return jwtDeleteIfEqual.Run(ctx, m.redisCluster, []string{jwtUserKey(payload.UserId)}, tokenString).Err()
}

// GetUserToken returns the token currently recorded for userID in Redis.
func (m *JwtManager) GetUserToken(ctx context.Context, userID int64) (string, error) {
	if m.redisCluster == nil {
		return "", errNoRedisClient
	}
	token, err := m.redisCluster.Get(ctx, jwtUserKey(userID)).Result()
	if errors.Is(err, redis.Nil) {
		return "", fmt.Errorf("user %v has no token", userID)
	}
	if err != nil {
		return "", err
	}
	return token, nil
}

// Logout ends userID's current session by deleting the user -> token entry
// and, if a token was recorded, that token's token -> payload entry. A user
// with no recorded token is not an error.
func (m *JwtManager) Logout(ctx context.Context, userID int64) error {
	if m.redisCluster == nil {
		return errNoRedisClient
	}
	if userID <= 0 {
		return errInvalidUserID
	}

	userKey := jwtUserKey(userID)
	tokenString, err := m.redisCluster.Get(ctx, userKey).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}

	// Separate DELs in a pipeline: the two keys generally live in different
	// cluster hash slots. On redis.Nil, tokenString is "".
	pipe := m.redisCluster.Pipeline()
	pipe.Del(ctx, userKey)
	if tokenString != "" {
		pipe.Del(ctx, jwtTokenKey(tokenString))
	}
	_, err = pipe.Exec(ctx)
	return err
}