602 lines
16 KiB
Markdown
602 lines
16 KiB
Markdown
# JWT 集成指南
|
||
|
||
本文档说明如何将 JWT Manager 集成到 gRPC 拦截器、HTTP Handlers 和业务逻辑中。
|
||
|
||
## 1. gRPC Unary Interceptor 实现
|
||
|
||
在 RPC 服务中添加 JWT 验证拦截器。
|
||
|
||
### 创建拦截器
|
||
|
||
创建文件 [app/users/rpc/internal/interceptor/jwt_interceptor.go](../../../app/users/rpc/internal/interceptor/jwt_interceptor.go):
|
||
|
||
```go
|
||
package interceptor
|
||
|
||
import (
|
||
"context"
|
||
"log"
|
||
|
||
"google.golang.org/grpc"
|
||
"google.golang.org/grpc/codes"
|
||
"google.golang.org/grpc/metadata"
|
||
"google.golang.org/grpc/status"
|
||
|
||
"yourmodule/app/users/rpc/internal/svc"
|
||
)
|
||
|
||
// JwtUnaryInterceptor 验证 gRPC 请求中的 JWT 令牌
|
||
func JwtUnaryInterceptor(svcCtx *svc.ServiceContext) grpc.UnaryServerInterceptor {
|
||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
// 获取请求元数据
|
||
md, ok := metadata.FromIncomingContext(ctx)
|
||
if !ok {
|
||
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
||
}
|
||
|
||
// 从 Authorization 头提取令牌
|
||
tokens := md.Get("authorization")
|
||
if len(tokens) == 0 {
|
||
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
|
||
}
|
||
|
||
token := tokens[0]
|
||
|
||
// 验证令牌
|
||
claims, err := svcCtx.JwtManager.Valid(ctx, token)
|
||
if err != nil {
|
||
log.Printf("Token validation failed: %v", err)
|
||
|
||
// 尝试刷新令牌(如果过期但仍在 Redis 中)
|
||
newToken, refreshErr := svcCtx.JwtManager.Renew(ctx, token)
|
||
if refreshErr == nil && newToken != "" {
|
||
// 在响应头中返回新令牌
|
||
grpc.SetHeader(ctx, metadata.Pairs("authorization", newToken))
|
||
// 继续处理请求,使用原令牌的声明
|
||
// 注意:实际应用中需要重新验证新令牌
|
||
newClaims, err := svcCtx.JwtManager.Valid(ctx, newToken)
|
||
if err != nil {
|
||
return nil, status.Error(codes.Unauthenticated, "token refresh failed")
|
||
}
|
||
claims = newClaims
|
||
} else {
|
||
return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
|
||
}
|
||
}
|
||
|
||
// 将声明附加到上下文,供处理器使用
|
||
newCtx := context.WithValue(ctx, "claims", claims)
|
||
|
||
return handler(newCtx, req)
|
||
}
|
||
}
|
||
|
||
// JwtStreamInterceptor 验证流式 gRPC 请求中的 JWT 令牌
|
||
func JwtStreamInterceptor(svcCtx *svc.ServiceContext) grpc.StreamServerInterceptor {
|
||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||
md, ok := metadata.FromIncomingContext(ss.Context())
|
||
if !ok {
|
||
return status.Error(codes.Unauthenticated, "missing metadata")
|
||
}
|
||
|
||
tokens := md.Get("authorization")
|
||
if len(tokens) == 0 {
|
||
return status.Error(codes.Unauthenticated, "missing authorization header")
|
||
}
|
||
|
||
token := tokens[0]
|
||
claims, err := svcCtx.JwtManager.Valid(ss.Context(), token)
|
||
if err != nil {
|
||
return status.Error(codes.Unauthenticated, "invalid token")
|
||
}
|
||
|
||
// 创建包装流以注入上下文
|
||
wrappedStream := &WrappedStream{
|
||
ServerStream: ss,
|
||
ctx: context.WithValue(ss.Context(), "claims", claims),
|
||
}
|
||
|
||
return handler(srv, wrappedStream)
|
||
}
|
||
}
|
||
|
||
// WrappedStream 包装 grpc.ServerStream 以注入新的上下文
|
||
type WrappedStream struct {
|
||
grpc.ServerStream
|
||
ctx context.Context
|
||
}
|
||
|
||
func (w *WrappedStream) Context() context.Context {
|
||
return w.ctx
|
||
}
|
||
```
|
||
|
||
### 在 Server 中注册拦截器
|
||
|
||
修改 [app/users/rpc/usercenter/usercenter.go](../../../app/users/rpc/usercenter/usercenter.go):
|
||
|
||
```go
|
||
package main
|
||
|
||
import (
	"flag"
	"fmt"
	"log"
	"net"

	"github.com/zeromicro/go-zero/core/conf"
	"github.com/zeromicro/go-zero/core/logx"
	"google.golang.org/grpc"

	"yourmodule/app/users/rpc/internal/config"
	"yourmodule/app/users/rpc/internal/interceptor"
	"yourmodule/app/users/rpc/internal/server"
	"yourmodule/app/users/rpc/internal/svc"
	"yourmodule/app/users/rpc/pb"
)
|
||
|
||
var configFile = flag.String("f", "etc/pb.yaml", "the config file")
|
||
|
||
func main() {
|
||
flag.Parse()
|
||
|
||
var c config.Config
|
||
conf.MustLoad(*configFile, &c)
|
||
ctx := svc.NewServiceContext(c)
|
||
|
||
logx.DisableStat()
|
||
|
||
s := grpc.NewServer(
|
||
grpc.UnaryInterceptor(interceptor.JwtUnaryInterceptor(ctx)),
|
||
grpc.StreamInterceptor(interceptor.JwtStreamInterceptor(ctx)),
|
||
)
|
||
|
||
pb.RegisterUsercenterServer(s, server.NewUsercenterServer(ctx))
|
||
|
||
logx.Infof("Starting gRPC server on %s:%d", c.Host, c.Port)
|
||
if err := s.Serve(net.Listen("tcp", "0.0.0.0:"+fmt.Sprintf("%d", c.Port))); err != nil {
|
||
logx.Error(err)
|
||
}
|
||
}
|
||
```
|
||
|
||
## 2. 登录 Handler 实现
|
||
|
||
实现 [app/users/api/internal/handler/user/loginHandler.go](../../../app/users/api/internal/handler/user/loginHandler.go):
|
||
|
||
```go
|
||
package user
|
||
|
||
import (
|
||
"context"
|
||
"log"
|
||
"net/http"
|
||
|
||
"yourmodule/app/users/api/internal/logic/user"
|
||
"yourmodule/app/users/api/internal/svc"
|
||
"yourmodule/app/users/api/internal/types"
|
||
)
|
||
|
||
// LoginHandler 处理用户登录
|
||
func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
var req types.LoginRequest
|
||
|
||
// 解析请求体...
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 调用业务逻辑
|
||
resp, err := user.NewLoginLogic(r.Context(), svcCtx).Login(&req)
|
||
if err != nil {
|
||
log.Printf("Login failed: %v", err)
|
||
http.Error(w, "Login failed", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 返回令牌
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(resp)
|
||
}
|
||
}
|
||
```
|
||
|
||
实现 [app/users/api/internal/logic/user/loginLogic.go](../../../app/users/api/internal/logic/user/loginLogic.go):
|
||
|
||
```go
|
||
package user
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
|
||
"yourmodule/app/users/api/internal/svc"
|
||
"yourmodule/app/users/api/internal/types"
|
||
)
|
||
|
||
type LoginLogic struct {
|
||
ctx context.Context
|
||
svcCtx *svc.ServiceContext
|
||
}
|
||
|
||
func NewLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginLogic {
|
||
return &LoginLogic{
|
||
ctx: ctx,
|
||
svcCtx: svcCtx,
|
||
}
|
||
}
|
||
|
||
func (l *LoginLogic) Login(req *types.LoginRequest) (*types.LoginResponse, error) {
|
||
// 1. 验证用户凭证(密码等)
|
||
user, err := l.svcCtx.UserModel.FindByEmail(l.ctx, req.Email)
|
||
if err != nil {
|
||
return nil, errors.New("user not found")
|
||
}
|
||
|
||
// 2. 验证密码
|
||
if !user.VerifyPassword(req.Password) {
|
||
return nil, errors.New("invalid password")
|
||
}
|
||
|
||
// 3. 生成 JWT 令牌
|
||
token, err := l.svcCtx.JwtManager.New(
|
||
l.ctx,
|
||
user.ID,
|
||
user.Email,
|
||
user.Name,
|
||
)
|
||
if err != nil {
|
||
return nil, errors.New("failed to generate token")
|
||
}
|
||
|
||
// 4. 返回令牌
|
||
return &types.LoginResponse{
|
||
Token: token,
|
||
User: types.User{
|
||
ID: user.ID,
|
||
Email: user.Email,
|
||
Name: user.Name,
|
||
},
|
||
}, nil
|
||
}
|
||
```
|
||
|
||
## 3. 在 Handlers 中使用声明
|
||
|
||
在受保护的 Handlers 中从请求上下文提取并使用声明(声明由第 6 节的 JWT 中间件写入上下文):
|
||
|
||
```go
|
||
package user
|
||
|
||
import (
|
||
"context"
|
||
"log"
|
||
"net/http"
|
||
|
||
"yourmodule/app/users/api/internal/svc"
|
||
"yourmodule/app/users/api/internal/types"
|
||
"github.com/golang-jwt/jwt/v4"
|
||
)
|
||
|
||
// GetUserInfoHandler 获取当前用户信息
|
||
func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
// 从上下文提取声明(由拦截器设置)
|
||
claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
|
||
if !ok {
|
||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 使用声明中的用户信息
|
||
userID := claims.Subject // 用户 ID 存储在 Subject 中
|
||
log.Printf("User %s requested their info", userID)
|
||
|
||
// 查询用户信息
|
||
user, err := svcCtx.UserModel.FindByID(r.Context(), userID)
|
||
if err != nil {
|
||
http.Error(w, "User not found", http.StatusNotFound)
|
||
return
|
||
}
|
||
|
||
// 返回用户信息
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(user)
|
||
}
|
||
}
|
||
```
|
||
|
||
## 4. 令牌刷新端点
|
||
|
||
实现令牌刷新端点:
|
||
|
||
```go
|
||
package user
|
||
|
||
import (
|
||
"net/http"
|
||
|
||
"yourmodule/app/users/api/internal/svc"
|
||
"yourmodule/app/users/api/internal/types"
|
||
)
|
||
|
||
// RefreshTokenHandler 刷新过期的 JWT 令牌
|
||
func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
var req types.RefreshTokenRequest
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 提取旧令牌
|
||
oldToken := req.Token
|
||
|
||
// 尝试刷新令牌
|
||
newToken, err := svcCtx.JwtManager.Renew(r.Context(), oldToken)
|
||
if err != nil {
|
||
http.Error(w, "Token refresh failed", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 返回新令牌
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(types.RefreshTokenResponse{
|
||
Token: newToken,
|
||
})
|
||
}
|
||
}
|
||
```
|
||
|
||
## 5. 登出处理
|
||
|
||
实现登出端点以撤销令牌:
|
||
|
||
```go
|
||
package user
|
||
|
||
import (
|
||
"net/http"
|
||
|
||
"yourmodule/app/users/api/internal/svc"
|
||
)
|
||
|
||
// LogoutHandler 登出用户(撤销令牌)
|
||
func LogoutHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
// 从上下文提取声明
|
||
claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
|
||
if !ok {
|
||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
userID := claims.Subject
|
||
|
||
// 获取用户当前令牌
|
||
currentToken := r.Header.Get("Authorization")
|
||
|
||
// 撤销令牌
|
||
err := svcCtx.JwtManager.Revoke(r.Context(), userID, currentToken)
|
||
if err != nil {
|
||
http.Error(w, "Logout failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]string{"message": "logged out successfully"})
|
||
}
|
||
}
|
||
```
|
||
|
||
## 6. 特定端点的 JWT 验证
|
||
|
||
对于 REST API,使用 JWT 中间件只包装需要认证的路由,由中间件统一验证令牌;公开路由(登录、刷新)不经过中间件:
|
||
|
||
### 在 Routes 中配置
|
||
|
||
修改 [app/users/api/internal/handler/routes.go](../../../app/users/api/internal/handler/routes.go):
|
||
|
||
```go
|
||
package handler
|
||
|
||
import (
|
||
"net/http"
|
||
|
||
"yourmodule/app/users/api/internal/middleware"
|
||
"yourmodule/app/users/api/internal/svc"
|
||
"yourmodule/app/users/api/internal/handler/user"
|
||
)
|
||
|
||
// RegisterRoutes 注册所有路由
|
||
func RegisterRoutes(router *http.ServeMux, svcCtx *svc.ServiceContext) {
|
||
// 公开路由
|
||
router.HandleFunc("POST /api/v1/auth/login", user.LoginHandler(svcCtx))
|
||
router.HandleFunc("POST /api/v1/auth/refresh", user.RefreshTokenHandler(svcCtx))
|
||
|
||
// 受保护的路由(需要 JWT 验证)
|
||
protected := middleware.JwtMiddleware(svcCtx)
|
||
router.HandleFunc("GET /api/v1/users/me", protected(user.GetUserInfoHandler(svcCtx)))
|
||
router.HandleFunc("POST /api/v1/users/logout", protected(user.LogoutHandler(svcCtx)))
|
||
router.HandleFunc("PUT /api/v1/users/me", protected(user.UpdateUserInfoHandler(svcCtx)))
|
||
}
|
||
```
|
||
|
||
### 创建 JWT 中间件
|
||
|
||
创建 [app/users/api/internal/middleware/jwt.go](../../../app/users/api/internal/middleware/jwt.go):
|
||
|
||
```go
|
||
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"yourmodule/app/users/api/internal/svc"
|
||
)
|
||
|
||
// JwtMiddleware 为 HTTP 处理器添加 JWT 验证
|
||
func JwtMiddleware(svcCtx *svc.ServiceContext) func(http.Handler) http.Handler {
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 从 Authorization 头提取令牌
|
||
authHeader := r.Header.Get("Authorization")
|
||
if authHeader == "" {
|
||
http.Error(w, "Missing authorization header", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 期望格式: "Bearer <token>"
|
||
parts := strings.SplitN(authHeader, " ", 2)
|
||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
http.Error(w, "Invalid authorization header", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
token := parts[1]
|
||
|
||
// 验证令牌
|
||
claims, err := svcCtx.JwtManager.Valid(r.Context(), token)
|
||
if err != nil {
|
||
// 尝试刷新
|
||
newToken, refreshErr := svcCtx.JwtManager.Renew(r.Context(), token)
|
||
if refreshErr != nil {
|
||
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 在响应头返回新令牌
|
||
w.Header().Set("X-New-Token", newToken)
|
||
|
||
// 重新验证新令牌
|
||
claims, err = svcCtx.JwtManager.Valid(r.Context(), newToken)
|
||
if err != nil {
|
||
http.Error(w, "Token refresh failed", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 将声明附加到上下文
|
||
newCtx := context.WithValue(r.Context(), "claims", claims)
|
||
next.ServeHTTP(w, r.WithContext(newCtx))
|
||
})
|
||
}
|
||
}
|
||
```
|
||
|
||
## 7. 错误处理最佳实践
|
||
|
||
```go
|
||
package logic
|
||
|
||
import (
|
||
"errors"
|
||
"log"
|
||
|
||
"yourmodule/app/users/rpc/internal/utils"
|
||
)
|
||
|
||
// HandleJwtError 处理 JWT 相关错误
|
||
func HandleJwtError(err error) error {
|
||
if errors.Is(err, utils.ErrTokenExpired) {
|
||
log.Println("Token has expired, user needs to refresh")
|
||
return errors.New("token expired - use refresh endpoint")
|
||
}
|
||
|
||
if errors.Is(err, utils.ErrTokenInvalid) {
|
||
log.Println("Token is invalid or malformed")
|
||
return errors.New("invalid token")
|
||
}
|
||
|
||
if errors.Is(err, utils.ErrTokenNotFound) {
|
||
log.Println("Token not found in Redis (revoked or expired)")
|
||
return errors.New("token revoked or expired")
|
||
}
|
||
|
||
return err
|
||
}
|
||
```
|
||
|
||
## 8. 测试 JWT 集成
|
||
|
||
### 单元测试示例

以下测试假设测试包中已初始化包级变量 `svcCtx`(例如在 `TestMain` 中构造 `*svc.ServiceContext`),示例代码中未包含这部分。
|
||
|
||
```go
|
||
package interceptor
|
||
|
||
import (
|
||
"context"
|
||
"testing"
|
||
|
||
"google.golang.org/grpc/codes"
|
||
"google.golang.org/grpc/metadata"
|
||
"google.golang.org/grpc/status"
|
||
)
|
||
|
||
func TestJwtUnaryInterceptor_ValidToken(t *testing.T) {
|
||
// 1. 创建有效的令牌
|
||
token, err := svcCtx.JwtManager.New(context.Background(), "user123", "user@example.com", "John")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create token: %v", err)
|
||
}
|
||
|
||
// 2. 创建包含令牌的上下文
|
||
md := metadata.Pairs("authorization", token)
|
||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||
|
||
// 3. 调用拦截器
|
||
_, err = JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||
return "success", nil
|
||
})
|
||
|
||
if err != nil {
|
||
t.Errorf("Unexpected error: %v", err)
|
||
}
|
||
}
|
||
|
||
func TestJwtUnaryInterceptor_ExpiredToken(t *testing.T) {
|
||
// 1. 创建过期的令牌或使用无效令牌
|
||
token := "invalid.token.here"
|
||
|
||
// 2. 创建包含令牌的上下文
|
||
md := metadata.Pairs("authorization", token)
|
||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||
|
||
// 3. 调用拦截器
|
||
_, err := JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||
return "success", nil
|
||
})
|
||
|
||
// 4. 验证错误
|
||
st, ok := status.FromError(err)
|
||
if !ok || st.Code() != codes.Unauthenticated {
|
||
t.Errorf("Expected Unauthenticated error, got: %v", err)
|
||
}
|
||
}
|
||
```
|
||
|
||
## 9. 生产部署清单
|
||
|
||
在将 JWT 集成部署到生产环境前:
|
||
|
||
- [ ] 所有令牌端点都进行了压力测试
|
||
- [ ] 令牌刷新逻辑已验证
|
||
- [ ] 错误处理覆盖了所有 JWT 失败情况
|
||
- [ ] 审计日志记录了所有认证尝试
|
||
- [ ] 密钥轮换计划已确定
|
||
- [ ] 监控和告警已配置
|
||
- [ ] 灾难恢复流程已文档化
|
||
- [ ] 所有依赖于 JWT 的服务都已更新
|
||
|
||
## 相关文件
|
||
|
||
- [app/users/rpc/internal/utils/jwt.go](../../../app/users/rpc/internal/utils/jwt.go) - JWT Manager 实现
|
||
- [app/users/rpc/internal/config/config.go](../../../app/users/rpc/internal/config/config.go) - JWT 配置
|
||
- [app/users/rpc/internal/svc/serviceContext.go](../../../app/users/rpc/internal/svc/serviceContext.go) - 依赖注入
|
||
- [deploy/k8s/secrets/jwt-secret.yaml](./jwt-secret.yaml) - Secret 和 RBAC
|
||
- [deploy/k8s/secrets/DEPLOYMENT.md](./DEPLOYMENT.md) - 部署指南
|