fix: some api bug
This commit is contained in:
@@ -0,0 +1,601 @@
|
||||
# JWT 集成指南
|
||||
|
||||
指导如何将 JWT Manager 集成到 RPC Handlers 和业务逻辑中。
|
||||
|
||||
## 1. gRPC Unary Interceptor 实现
|
||||
|
||||
在 RPC 服务中添加 JWT 验证拦截器。
|
||||
|
||||
### 创建拦截器
|
||||
|
||||
创建文件 [app/users/rpc/internal/interceptor/jwt_interceptor.go](../../../app/users/rpc/internal/interceptor/jwt_interceptor.go):
|
||||
|
||||
```go
|
||||
package interceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"yourmodule/app/users/rpc/internal/svc"
|
||||
)
|
||||
|
||||
// JwtUnaryInterceptor 验证 gRPC 请求中的 JWT 令牌
|
||||
func JwtUnaryInterceptor(svcCtx *svc.ServiceContext) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
// 获取请求元数据
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
// 从 Authorization 头提取令牌
|
||||
tokens := md.Get("authorization")
|
||||
if len(tokens) == 0 {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
|
||||
}
|
||||
|
||||
token := tokens[0]
|
||||
|
||||
// 验证令牌
|
||||
claims, err := svcCtx.JwtManager.Valid(ctx, token)
|
||||
if err != nil {
|
||||
log.Printf("Token validation failed: %v", err)
|
||||
|
||||
// 尝试刷新令牌(如果过期但仍在 Redis 中)
|
||||
newToken, refreshErr := svcCtx.JwtManager.Renew(ctx, token)
|
||||
if refreshErr == nil && newToken != "" {
|
||||
// 在响应头中返回新令牌
|
||||
grpc.SetHeader(ctx, metadata.Pairs("authorization", newToken))
|
||||
// 继续处理请求,使用原令牌的声明
|
||||
// 注意:实际应用中需要重新验证新令牌
|
||||
newClaims, err := svcCtx.JwtManager.Valid(ctx, newToken)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "token refresh failed")
|
||||
}
|
||||
claims = newClaims
|
||||
} else {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
|
||||
}
|
||||
}
|
||||
|
||||
// 将声明附加到上下文,供处理器使用
|
||||
newCtx := context.WithValue(ctx, "claims", claims)
|
||||
|
||||
return handler(newCtx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// JwtStreamInterceptor 验证流式 gRPC 请求中的 JWT 令牌
|
||||
func JwtStreamInterceptor(svcCtx *svc.ServiceContext) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
md, ok := metadata.FromIncomingContext(ss.Context())
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
tokens := md.Get("authorization")
|
||||
if len(tokens) == 0 {
|
||||
return status.Error(codes.Unauthenticated, "missing authorization header")
|
||||
}
|
||||
|
||||
token := tokens[0]
|
||||
claims, err := svcCtx.JwtManager.Valid(ss.Context(), token)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// 创建包装流以注入上下文
|
||||
wrappedStream := &WrappedStream{
|
||||
ServerStream: ss,
|
||||
ctx: context.WithValue(ss.Context(), "claims", claims),
|
||||
}
|
||||
|
||||
return handler(srv, wrappedStream)
|
||||
}
|
||||
}
|
||||
|
||||
// WrappedStream 包装 grpc.ServerStream 以注入新的上下文
|
||||
type WrappedStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *WrappedStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
```
|
||||
|
||||
### 在 Server 中注册拦截器
|
||||
|
||||
修改 [app/users/rpc/usercenter/usercenter.go](../app/users/rpc/usercenter/usercenter.go):
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"yourmodule/app/users/rpc/internal/config"
|
||||
"yourmodule/app/users/rpc/internal/interceptor"
|
||||
"yourmodule/app/users/rpc/internal/server"
|
||||
"yourmodule/app/users/rpc/internal/svc"
|
||||
"yourmodule/app/users/rpc/pb"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var configFile = flag.String("f", "etc/pb.yaml", "the config file")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
var c config.Config
|
||||
conf.MustLoad(*configFile, &c)
|
||||
ctx := svc.NewServiceContext(c)
|
||||
|
||||
logx.DisableStat()
|
||||
|
||||
s := grpc.NewServer(
|
||||
grpc.UnaryInterceptor(interceptor.JwtUnaryInterceptor(ctx)),
|
||||
grpc.StreamInterceptor(interceptor.JwtStreamInterceptor(ctx)),
|
||||
)
|
||||
|
||||
pb.RegisterUsercenterServer(s, server.NewUsercenterServer(ctx))
|
||||
|
||||
logx.Infof("Starting gRPC server on %s:%d", c.Host, c.Port)
|
||||
if err := s.Serve(net.Listen("tcp", "0.0.0.0:"+fmt.Sprintf("%d", c.Port))); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2. 登录 Handler 实现
|
||||
|
||||
实现 [app/users/api/internal/handler/user/loginHandler.go](../../../app/users/api/internal/handler/user/loginHandler.go):
|
||||
|
||||
```go
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"yourmodule/app/users/api/internal/logic/user"
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
"yourmodule/app/users/api/internal/types"
|
||||
)
|
||||
|
||||
// LoginHandler 处理用户登录
|
||||
func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.LoginRequest
|
||||
|
||||
// 解析请求体...
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 调用业务逻辑
|
||||
resp, err := user.NewLoginLogic(r.Context(), svcCtx).Login(&req)
|
||||
if err != nil {
|
||||
log.Printf("Login failed: %v", err)
|
||||
http.Error(w, "Login failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回令牌
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
实现 [app/users/api/internal/logic/user/loginLogic.go](../../../app/users/api/internal/logic/user/loginLogic.go):
|
||||
|
||||
```go
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
"yourmodule/app/users/api/internal/types"
|
||||
)
|
||||
|
||||
type LoginLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
}
|
||||
|
||||
func NewLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginLogic {
|
||||
return &LoginLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LoginLogic) Login(req *types.LoginRequest) (*types.LoginResponse, error) {
|
||||
// 1. 验证用户凭证(密码等)
|
||||
user, err := l.svcCtx.UserModel.FindByEmail(l.ctx, req.Email)
|
||||
if err != nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
// 2. 验证密码
|
||||
if !user.VerifyPassword(req.Password) {
|
||||
return nil, errors.New("invalid password")
|
||||
}
|
||||
|
||||
// 3. 生成 JWT 令牌
|
||||
token, err := l.svcCtx.JwtManager.New(
|
||||
l.ctx,
|
||||
user.ID,
|
||||
user.Email,
|
||||
user.Name,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to generate token")
|
||||
}
|
||||
|
||||
// 4. 返回令牌
|
||||
return &types.LoginResponse{
|
||||
Token: token,
|
||||
User: types.User{
|
||||
ID: user.ID,
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
## 3. 在 Handlers 中使用声明
|
||||
|
||||
在 Protected Handlers 中提取并使用声明:
|
||||
|
||||
```go
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
"yourmodule/app/users/api/internal/types"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// GetUserInfoHandler 获取当前用户信息
|
||||
func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 从上下文提取声明(由拦截器设置)
|
||||
claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用声明中的用户信息
|
||||
userID := claims.Subject // 用户 ID 存储在 Subject 中
|
||||
log.Printf("User %s requested their info", userID)
|
||||
|
||||
// 查询用户信息
|
||||
user, err := svcCtx.UserModel.FindByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "User not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回用户信息
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 4. 令牌刷新端点
|
||||
|
||||
实现令牌刷新端点:
|
||||
|
||||
```go
|
||||
package user
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
"yourmodule/app/users/api/internal/types"
|
||||
)
|
||||
|
||||
// RefreshTokenHandler 刷新过期的 JWT 令牌
|
||||
func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.RefreshTokenRequest
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 提取旧令牌
|
||||
oldToken := req.Token
|
||||
|
||||
// 尝试刷新令牌
|
||||
newToken, err := svcCtx.JwtManager.Renew(r.Context(), oldToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Token refresh failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回新令牌
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(types.RefreshTokenResponse{
|
||||
Token: newToken,
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 登出处理
|
||||
|
||||
实现登出端点以撤销令牌:
|
||||
|
||||
```go
|
||||
package user
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
)
|
||||
|
||||
// LogoutHandler 登出用户(撤销令牌)
|
||||
func LogoutHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 从上下文提取声明
|
||||
claims, ok := r.Context().Value("claims").(*jwt.RegisteredClaims)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID := claims.Subject
|
||||
|
||||
// 获取用户当前令牌
|
||||
currentToken := r.Header.Get("Authorization")
|
||||
|
||||
// 撤销令牌
|
||||
err := svcCtx.JwtManager.Revoke(r.Context(), userID, currentToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Logout failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "logged out successfully"})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 6. 特定端点的 JWT 验证
|
||||
|
||||
对于 REST API,在需要的 handlers 中手动验证令牌:
|
||||
|
||||
### 在 Routes 中配置
|
||||
|
||||
修改 [app/users/api/internal/handler/routes.go](../app/users/api/internal/handler/routes.go):
|
||||
|
||||
```go
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"yourmodule/app/users/api/internal/middleware"
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
"yourmodule/app/users/api/internal/handler/user"
|
||||
)
|
||||
|
||||
// RegisterRoutes 注册所有路由
|
||||
func RegisterRoutes(router *http.ServeMux, svcCtx *svc.ServiceContext) {
|
||||
// 公开路由
|
||||
router.HandleFunc("POST /api/v1/auth/login", user.LoginHandler(svcCtx))
|
||||
router.HandleFunc("POST /api/v1/auth/refresh", user.RefreshTokenHandler(svcCtx))
|
||||
|
||||
// 受保护的路由(需要 JWT 验证)
|
||||
protected := middleware.JwtMiddleware(svcCtx)
|
||||
router.HandleFunc("GET /api/v1/users/me", protected(user.GetUserInfoHandler(svcCtx)))
|
||||
router.HandleFunc("POST /api/v1/users/logout", protected(user.LogoutHandler(svcCtx)))
|
||||
router.HandleFunc("PUT /api/v1/users/me", protected(user.UpdateUserInfoHandler(svcCtx)))
|
||||
}
|
||||
```
|
||||
|
||||
### 创建 JWT 中间件
|
||||
|
||||
创建 [app/users/api/internal/middleware/jwt.go](../../../app/users/api/internal/middleware/jwt.go):
|
||||
|
||||
```go
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"yourmodule/app/users/api/internal/svc"
|
||||
)
|
||||
|
||||
// JwtMiddleware 为 HTTP 处理器添加 JWT 验证
|
||||
func JwtMiddleware(svcCtx *svc.ServiceContext) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 从 Authorization 头提取令牌
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "Missing authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 期望格式: "Bearer <token>"
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
http.Error(w, "Invalid authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
// 验证令牌
|
||||
claims, err := svcCtx.JwtManager.Valid(r.Context(), token)
|
||||
if err != nil {
|
||||
// 尝试刷新
|
||||
newToken, refreshErr := svcCtx.JwtManager.Renew(r.Context(), token)
|
||||
if refreshErr != nil {
|
||||
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 在响应头返回新令牌
|
||||
w.Header().Set("X-New-Token", newToken)
|
||||
|
||||
// 重新验证新令牌
|
||||
claims, err = svcCtx.JwtManager.Valid(r.Context(), newToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Token refresh failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将声明附加到上下文
|
||||
newCtx := context.WithValue(r.Context(), "claims", claims)
|
||||
next.ServeHTTP(w, r.WithContext(newCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 7. 错误处理最佳实践
|
||||
|
||||
```go
|
||||
package logic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"yourmodule/app/users/rpc/internal/utils"
|
||||
)
|
||||
|
||||
// HandleJwtError 处理 JWT 相关错误
|
||||
func HandleJwtError(err error) error {
|
||||
if errors.Is(err, utils.ErrTokenExpired) {
|
||||
log.Println("Token has expired, user needs to refresh")
|
||||
return errors.New("token expired - use refresh endpoint")
|
||||
}
|
||||
|
||||
if errors.Is(err, utils.ErrTokenInvalid) {
|
||||
log.Println("Token is invalid or malformed")
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
|
||||
if errors.Is(err, utils.ErrTokenNotFound) {
|
||||
log.Println("Token not found in Redis (revoked or expired)")
|
||||
return errors.New("token revoked or expired")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
## 8. 测试 JWT 集成
|
||||
|
||||
### 单元测试示例
|
||||
|
||||
```go
|
||||
package interceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestJwtUnaryInterceptor_ValidToken(t *testing.T) {
|
||||
// 1. 创建有效的令牌
|
||||
token, err := svcCtx.JwtManager.New(context.Background(), "user123", "user@example.com", "John")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token: %v", err)
|
||||
}
|
||||
|
||||
// 2. 创建包含令牌的上下文
|
||||
md := metadata.Pairs("authorization", token)
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
// 3. 调用拦截器
|
||||
_, err = JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJwtUnaryInterceptor_ExpiredToken(t *testing.T) {
|
||||
// 1. 创建过期的令牌或使用无效令牌
|
||||
token := "invalid.token.here"
|
||||
|
||||
// 2. 创建包含令牌的上下文
|
||||
md := metadata.Pairs("authorization", token)
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
// 3. 调用拦截器
|
||||
_, err := JwtUnaryInterceptor(svcCtx)(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "success", nil
|
||||
})
|
||||
|
||||
// 4. 验证错误
|
||||
st, ok := status.FromError(err)
|
||||
if !ok || st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("Expected Unauthenticated error, got: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 9. 生产部署清单
|
||||
|
||||
在将 JWT 集成部署到生产环境前:
|
||||
|
||||
- [ ] 所有令牌端点都进行了压力测试
|
||||
- [ ] 令牌刷新逻辑已验证
|
||||
- [ ] 错误处理覆盖了所有 JWT 失败情况
|
||||
- [ ] 审计日志记录了所有认证尝试
|
||||
- [ ] 密钥轮换计划已确定
|
||||
- [ ] 监控和告警已配置
|
||||
- [ ] 灾难恢复流程已文档化
|
||||
- [ ] 所有依赖于 JWT 的服务都已更新
|
||||
|
||||
## 相关文件
|
||||
|
||||
- [app/users/rpc/internal/utils/jwt.go](../app/users/rpc/internal/utils/jwt.go) - JWT Manager 实现
|
||||
- [app/users/rpc/internal/config/config.go](../app/users/rpc/internal/config/config.go) - JWT 配置
|
||||
- [app/users/rpc/internal/svc/serviceContext.go](../app/users/rpc/internal/svc/serviceContext.go) - 依赖注入
|
||||
- [deploy/k8s/secrets/jwt-secret.yaml](./jwt-secret.yaml) - Secret 和 RBAC
|
||||
- [deploy/k8s/secrets/DEPLOYMENT.md](./DEPLOYMENT.md) - 部署指南
|
||||
Reference in New Issue
Block a user