Files
juwan-backend/docs/INTEGRATION.md
2026-03-31 22:12:06 +08:00

15 KiB
Raw Permalink Blame History

JWT 集成指南

指导如何将 JWT Manager 集成到 RPC Handlers 和业务逻辑中。

1. gRPC 拦截器实现(Unary 与 Stream)

在 RPC 服务中添加 JWT 验证拦截器。

创建拦截器

创建文件 app/users/rpc/internal/interceptor/jwt_interceptor.go

package interceptor

import (
	"context"
	"log"
	
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	
	"yourmodule/app/users/rpc/internal/svc"
)

// JwtUnaryInterceptor 验证 gRPC 请求中的 JWT 令牌
func JwtUnaryInterceptor(svcCtx *svc.ServiceContext) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// 获取请求元数据
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, status.Error(codes.Unauthenticated, "missing metadata")
		}

		// 从 Authorization 头提取令牌
		tokens := md.Get("authorization")
		if len(tokens) == 0 {
			return nil, status.Error(codes.Unauthenticated, "missing authorization header")
		}

		token := tokens[0]
		
		// 验证令牌
		claims, err := svcCtx.JwtManager.Valid(ctx, token)
		if err != nil {
			log.Printf("Token validation failed: %v", err)
			
			// 尝试刷新令牌(如果过期但仍在 Redis 中)
			newToken, refreshErr := svcCtx.JwtManager.Renew(ctx, token)
			if refreshErr == nil && newToken != "" {
				// 在响应头中返回新令牌
				grpc.SetHeader(ctx, metadata.Pairs("authorization", newToken))
				// 继续处理请求,使用原令牌的声明
				// 注意:实际应用中需要重新验证新令牌
				newClaims, err := svcCtx.JwtManager.Valid(ctx, newToken)
				if err != nil {
					return nil, status.Error(codes.Unauthenticated, "token refresh failed")
				}
				claims = newClaims
			} else {
				return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
			}
		}

		// 将声明附加到上下文,供处理器使用
		newCtx := context.WithValue(ctx, "claims", claims)
		
		return handler(newCtx, req)
	}
}

// JwtStreamInterceptor 验证流式 gRPC 请求中的 JWT 令牌
func JwtStreamInterceptor(svcCtx *svc.ServiceContext) grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		md, ok := metadata.FromIncomingContext(ss.Context())
		if !ok {
			return status.Error(codes.Unauthenticated, "missing metadata")
		}

		tokens := md.Get("authorization")
		if len(tokens) == 0 {
			return status.Error(codes.Unauthenticated, "missing authorization header")
		}

		token := tokens[0]
		claims, err := svcCtx.JwtManager.Valid(ss.Context(), token)
		if err != nil {
			return status.Error(codes.Unauthenticated, "invalid token")
		}

		// 创建包装流以注入上下文
		wrappedStream := &WrappedStream{
			ServerStream: ss,
			ctx:          context.WithValue(ss.Context(), "claims", claims),
		}

		return handler(srv, wrappedStream)
	}
}

// WrappedStream overrides the context of a grpc.ServerStream.
//
// Every ServerStream method is promoted from the embedded stream except
// Context, which returns the injected ctx. JwtStreamInterceptor uses it to
// expose a context that carries the validated JWT claims.
type WrappedStream struct {
	grpc.ServerStream
	ctx context.Context // context returned by Context()
}

// Context implements grpc.ServerStream, returning the injected context rather
// than the embedded stream's own.
func (ws *WrappedStream) Context() context.Context { return ws.ctx }

在 Server 中注册拦截器

修改 app/users/rpc/usercenter/usercenter.go

package main

import (
	"flag"
	"fmt"
	"log"
	
	"yourmodule/app/users/rpc/internal/config"
	"yourmodule/app/users/rpc/internal/interceptor"
	"yourmodule/app/users/rpc/internal/server"
	"yourmodule/app/users/rpc/internal/svc"
	"yourmodule/app/users/rpc/pb"
	
	"github.com/zeromicro/go-zero/core/conf"
	"github.com/zeromicro/go-zero/core/logx"
	"google.golang.org/grpc"
)

var configFile = flag.String("f", "etc/pb.yaml", "the config file")

func main() {
	flag.Parse()

	var c config.Config
	conf.MustLoad(*configFile, &c)
	ctx := svc.NewServiceContext(c)

	logx.DisableStat()

	s := grpc.NewServer(
		grpc.UnaryInterceptor(interceptor.JwtUnaryInterceptor(ctx)),
		grpc.StreamInterceptor(interceptor.JwtStreamInterceptor(ctx)),
	)
	
	pb.RegisterUsercenterServer(s, server.NewUsercenterServer(ctx))

	logx.Infof("Starting gRPC server on %s:%d", c.Host, c.Port)
	if err := s.Serve(net.Listen("tcp", "0.0.0.0:"+fmt.Sprintf("%d", c.Port))); err != nil {
		logx.Error(err)
	}
}

2. 登录 Handler 实现

实现 app/users/api/internal/handler/user/loginHandler.go

package user

import (
	"context"
	"log"
	"net/http"
	
	"yourmodule/app/users/api/internal/logic/user"
	"yourmodule/app/users/api/internal/svc"
	"yourmodule/app/users/api/internal/types"
)

// LoginHandler 处理用户登录
func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var req types.LoginRequest
		
		// 解析请求体...
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request", http.StatusBadRequest)
			return
		}

		// 调用业务逻辑
		resp, err := user.NewLoginLogic(r.Context(), svcCtx).Login(&req)
		if err != nil {
			log.Printf("Login failed: %v", err)
			http.Error(w, "Login failed", http.StatusUnauthorized)
			return
		}

		// 返回令牌
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}
}

实现 app/users/api/internal/logic/user/loginLogic.go

package user

import (
	"context"
	"errors"
	
	"yourmodule/app/users/api/internal/svc"
	"yourmodule/app/users/api/internal/types"
)

type LoginLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

func NewLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginLogic {
	return &LoginLogic{
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

func (l *LoginLogic) Login(req *types.LoginRequest) (*types.LoginResponse, error) {
	// 1. 验证用户凭证(密码等)
	user, err := l.svcCtx.UserModel.FindByEmail(l.ctx, req.Email)
	if err != nil {
		return nil, errors.New("user not found")
	}

	// 2. 验证密码
	if !user.VerifyPassword(req.Password) {
		return nil, errors.New("invalid password")
	}

	// 3. 生成 JWT 令牌
	token, err := l.svcCtx.JwtManager.New(
		l.ctx,
		user.ID,
		user.Email,
		user.Name,
	)
	if err != nil {
		return nil, errors.New("failed to generate token")
	}

	// 4. 返回令牌
	return &types.LoginResponse{
		Token: token,
		User: types.User{
			ID:    user.ID,
			Email: user.Email,
			Name:  user.Name,
		},
	}, nil
}

3. 在 Handlers 中使用声明

在受保护的 Handlers 中提取并使用声明(声明由第 6 节的 JwtMiddleware 在验证令牌后写入请求上下文,键为 "claims"):

package user

import (
	"context"
	"log"
	"net/http"
	
	"yourmodule/app/users/api/internal/svc"
	"yourmodule/app/users/api/internal/types"
	"github.com/golang-jwt/jwt/v4"
)

// GetUserInfoHandler 获取当前用户信息
func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// 从上下文提取声明(由拦截器设置)
		claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
		if !ok {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		// 使用声明中的用户信息
		userID := claims.Subject // 用户 ID 存储在 Subject 中
		log.Printf("User %s requested their info", userID)

		// 查询用户信息
		user, err := svcCtx.UserModel.FindByID(r.Context(), userID)
		if err != nil {
			http.Error(w, "User not found", http.StatusNotFound)
			return
		}

		// 返回用户信息
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(user)
	}
}

4. 令牌刷新端点

实现令牌刷新端点:

package user

import (
	"net/http"
	
	"yourmodule/app/users/api/internal/svc"
	"yourmodule/app/users/api/internal/types"
)

// RefreshTokenHandler 刷新过期的 JWT 令牌
func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var req types.RefreshTokenRequest
		
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request", http.StatusBadRequest)
			return
		}

		// 提取旧令牌
		oldToken := req.Token

		// 尝试刷新令牌
		newToken, err := svcCtx.JwtManager.Renew(r.Context(), oldToken)
		if err != nil {
			http.Error(w, "Token refresh failed", http.StatusUnauthorized)
			return
		}

		// 返回新令牌
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(types.RefreshTokenResponse{
			Token: newToken,
		})
	}
}

5. 登出处理

实现登出端点以撤销令牌:

package user

import (
	"net/http"
	
	"yourmodule/app/users/api/internal/svc"
)

// LogoutHandler 登出用户(撤销令牌)
func LogoutHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// 从上下文提取声明
		claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
		if !ok {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		userID := claims.Subject

		// 获取用户当前令牌
		currentToken := r.Header.Get("Authorization")
		
		// 撤销令牌
		err := svcCtx.JwtManager.Revoke(r.Context(), userID, currentToken)
		if err != nil {
			http.Error(w, "Logout failed", http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"message": "logged out successfully"})
	}
}

6. 特定端点的 JWT 验证

对于 REST API,通过 JWT 中间件只为需要认证的路由验证令牌(公开路由如登录、刷新不经过中间件):

在 Routes 中配置

修改 app/users/api/internal/handler/routes.go

package handler

import (
	"net/http"
	
	"yourmodule/app/users/api/internal/middleware"
	"yourmodule/app/users/api/internal/svc"
	"yourmodule/app/users/api/internal/handler/user"
)

// RegisterRoutes registers all user API routes on router.
//
// Login and refresh are public. The remaining routes are wrapped by
// middleware.JwtMiddleware, which returns an http.Handler, so they must be
// registered with router.Handle — passing an http.Handler to HandleFunc does
// not compile. Method-qualified patterns such as "POST /api/v1/auth/login"
// require Go 1.22 or newer.
func RegisterRoutes(router *http.ServeMux, svcCtx *svc.ServiceContext) {
	// Public routes.
	router.HandleFunc("POST /api/v1/auth/login", user.LoginHandler(svcCtx))
	router.HandleFunc("POST /api/v1/auth/refresh", user.RefreshTokenHandler(svcCtx))

	// Protected routes (JWT required).
	protected := middleware.JwtMiddleware(svcCtx)
	router.Handle("GET /api/v1/users/me", protected(user.GetUserInfoHandler(svcCtx)))
	router.Handle("POST /api/v1/users/logout", protected(user.LogoutHandler(svcCtx)))
	// NOTE: UpdateUserInfoHandler is not shown in this guide; implement it
	// alongside GetUserInfoHandler before registering this route.
	router.Handle("PUT /api/v1/users/me", protected(user.UpdateUserInfoHandler(svcCtx)))
}

创建 JWT 中间件

创建 app/users/api/internal/middleware/jwt.go

package middleware

import (
	"context"
	"net/http"
	"strings"
	
	"yourmodule/app/users/api/internal/svc"
)

// JwtMiddleware returns middleware that requires a valid JWT on every request.
//
// The token must be sent as "Authorization: Bearer <token>"; the scheme is
// matched case-insensitively, as RFC 7235 specifies for auth schemes. If
// JwtManager.Valid rejects the token, JwtManager.Renew is tried once. When
// renewal succeeds and the new token validates, the new token is returned in
// the X-New-Token response header and the request proceeds with the new
// token's claims. Claims are stored in the request context under the string
// key "claims". Every failure ends the request with 401.
func JwtMiddleware(svcCtx *svc.ServiceContext) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				http.Error(w, "Missing authorization header", http.StatusUnauthorized)
				return
			}

			// Expected format: "Bearer <token>". An empty token is rejected here
			// instead of being handed to the JwtManager.
			scheme, token, found := strings.Cut(authHeader, " ")
			if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
				http.Error(w, "Invalid authorization header", http.StatusUnauthorized)
				return
			}

			claims, err := svcCtx.JwtManager.Valid(r.Context(), token)
			if err != nil {
				// NOTE(review): renewal is attempted for every validation error,
				// not only expiry; confirm Renew rejects forged tokens itself.
				newToken, refreshErr := svcCtx.JwtManager.Renew(r.Context(), token)
				if refreshErr != nil || newToken == "" {
					http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
					return
				}

				claims, err = svcCtx.JwtManager.Valid(r.Context(), newToken)
				if err != nil {
					http.Error(w, "Token refresh failed", http.StatusUnauthorized)
					return
				}

				// Advertise the renewed token only once it is known to be valid;
				// setting it earlier would leak an unusable token on the 401 path.
				w.Header().Set("X-New-Token", newToken)
			}

			// A plain string key is kept because handlers read
			// r.Context().Value("claims"); staticcheck (SA1029) recommends an
			// unexported key type instead.
			ctx := context.WithValue(r.Context(), "claims", claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

7. 错误处理最佳实践

package logic

import (
	"errors"
	"log"
	
	"yourmodule/app/users/rpc/internal/utils"
)

// HandleJwtError maps a JWT error from the utils package to a short,
// client-facing error and logs which case occurred.
//
// The sentinels utils.ErrTokenExpired, utils.ErrTokenInvalid and
// utils.ErrTokenNotFound are checked in that order with errors.Is, so wrapped
// errors are recognized too. Any other error, including nil, is returned
// unchanged.
func HandleJwtError(err error) error {
	switch {
	case errors.Is(err, utils.ErrTokenExpired):
		log.Println("Token has expired, user needs to refresh")
		return errors.New("token expired - use refresh endpoint")
	case errors.Is(err, utils.ErrTokenInvalid):
		log.Println("Token is invalid or malformed")
		return errors.New("invalid token")
	case errors.Is(err, utils.ErrTokenNotFound):
		log.Println("Token not found in Redis (revoked or expired)")
		return errors.New("token revoked or expired")
	default:
		return err
	}
}

8. 测试 JWT 集成

单元测试示例

package interceptor

import (
	"context"
	"testing"
	
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// TestJwtUnaryInterceptor_ValidToken checks three things for a freshly issued
// token: the interceptor accepts it, the wrapped handler actually runs and its
// result is passed through, and the claims are attached to the handler's
// context under the "claims" key.
//
// NOTE: svcCtx is not declared in this snippet; it must be a package-level
// test fixture (e.g. built in TestMain against a test Redis instance).
func TestJwtUnaryInterceptor_ValidToken(t *testing.T) {
	// 1. Issue a valid token.
	token, err := svcCtx.JwtManager.New(context.Background(), "user123", "user@example.com", "John")
	if err != nil {
		t.Fatalf("Failed to create token: %v", err)
	}

	// 2. Build an incoming context carrying the token.
	ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", token))

	// 3. Run the interceptor with a handler that records it was reached.
	called := false
	resp, err := JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
		called = true
		if ctx.Value("claims") == nil {
			t.Error("claims missing from handler context")
		}
		return "success", nil
	})

	// 4. Verify the outcome.
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if !called {
		t.Fatal("handler was not invoked")
	}
	if resp != "success" {
		t.Errorf("resp = %v, want %q", resp, "success")
	}
}

// TestJwtUnaryInterceptor_ExpiredToken checks rejection of a token the manager
// cannot validate. A malformed string stands in for an expired token, since
// both must fail Valid and Renew. The call must fail with
// codes.Unauthenticated, and the handler must never run.
//
// NOTE: svcCtx is a package-level test fixture (see the ValidToken test).
func TestJwtUnaryInterceptor_ExpiredToken(t *testing.T) {
	// 1. Build an incoming context carrying an unusable token.
	ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "invalid.token.here"))

	// 2. Run the interceptor; reaching the handler is itself a failure.
	_, err := JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
		t.Error("handler must not be called for an invalid token")
		return "success", nil
	})

	// 3. status.FromError(nil) reports codes.OK, so a nil error fails here too.
	if st, ok := status.FromError(err); !ok || st.Code() != codes.Unauthenticated {
		t.Errorf("Expected Unauthenticated error, got: %v", err)
	}
}

9. 生产部署清单

在将 JWT 集成部署到生产环境前:

  • 所有令牌端点都进行了压力测试
  • 令牌刷新逻辑已验证
  • 错误处理覆盖了所有 JWT 失败情况
  • 审计日志记录了所有认证尝试
  • 密钥轮换计划已确定
  • 监控和告警已配置
  • 灾难恢复流程已文档化
  • 所有依赖于 JWT 的服务都已更新

相关文件