add: user auth accomplished
This commit is contained in:
@@ -2,9 +2,11 @@ package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"errors"
|
||||
"juwan-backend/app/users/rpc/internal/svc"
|
||||
utils2 "juwan-backend/app/users/rpc/internal/utils"
|
||||
"juwan-backend/app/users/rpc/pb"
|
||||
"juwan-backend/common/utils"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
@@ -24,7 +26,29 @@ func NewLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginLogic
|
||||
}
|
||||
|
||||
func (l *LoginLogic) Login(in *pb.LoginReq) (*pb.LoginResp, error) {
|
||||
// todo: add your logic here and delete this line
|
||||
user, err := l.svcCtx.UsersModelRO.FindOneByUsername(l.ctx, in.Username)
|
||||
if err != nil {
|
||||
logx.WithContext(l.ctx).Errorf("LoginLogic.Login error:%v", err)
|
||||
return nil, err
|
||||
}
|
||||
if !utils.VerifyPassword(user.Passwd, in.Passwd) {
|
||||
logx.WithContext(l.ctx).Errorf("User %s Login failed", user.Username)
|
||||
return nil, errors.New("incorrect password")
|
||||
}
|
||||
|
||||
return &pb.LoginResp{}, nil
|
||||
token, err := l.svcCtx.JwtManager.New(l.ctx, &utils2.TokenPayload{
|
||||
UserId: user.UserId,
|
||||
IsAdmin: false,
|
||||
})
|
||||
if err != nil {
|
||||
logx.Errorf("LoginLogic.Login gen jwt for user %v error:%v", user.UserId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.LoginResp{
|
||||
Token: token,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Id: user.UserId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"juwan-backend/app/users/rpc/internal/svc"
|
||||
"juwan-backend/app/users/rpc/pb"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
type LogoutLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
logx.Logger
|
||||
}
|
||||
|
||||
func NewLogoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LogoutLogic {
|
||||
return &LogoutLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
Logger: logx.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LogoutLogic) Logout(in *pb.LogoutReq) (*pb.LogoutResp, error) {
|
||||
// todo: add your logic here and delete this line
|
||||
err := l.svcCtx.JwtManager.Logout(l.ctx, in.UserId)
|
||||
if err != nil {
|
||||
logx.WithContext(l.ctx).Errorf("Logout failed: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return &pb.LogoutResp{}, nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"juwan-backend/app/snowflake/rpc/snowflake"
|
||||
"juwan-backend/app/users/rpc/internal/models"
|
||||
"juwan-backend/app/users/rpc/internal/svc"
|
||||
"juwan-backend/app/users/rpc/pb"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
type RegisterLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
logx.Logger
|
||||
}
|
||||
|
||||
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
|
||||
return &RegisterLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
Logger: logx.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func mustNewRandomNickname() string {
|
||||
bytes := make([]byte, 5)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "NewUser"
|
||||
}
|
||||
nickname := strings.Builder{}
|
||||
nickname.WriteString("user_")
|
||||
nickname.WriteString(hex.EncodeToString(bytes))
|
||||
return nickname.String()
|
||||
}
|
||||
|
||||
func (l *RegisterLogic) Register(in *pb.RegisterReq) (*pb.RegisterResp, error) {
|
||||
// todo: add your logic here and delete this line
|
||||
if in.Phone == "" || in.Username == "" || in.Passwd == "" {
|
||||
logx.Error("invalid input")
|
||||
return nil, errors.New("invalid input")
|
||||
}
|
||||
|
||||
redisKey := fmt.Sprintf("vcode:%s:%s:%s", in.RequestId, "register", in.Email)
|
||||
vcode, err := l.svcCtx.RedisCluster.Get(l.ctx, redisKey).Result()
|
||||
logx.Infof("vcode:%s, err:%v", vcode, err)
|
||||
if err != nil {
|
||||
logx.Error("invalid verification code")
|
||||
return nil, errors.New("invalid verification code")
|
||||
}
|
||||
|
||||
code, err := strconv.ParseInt(vcode, 10, 32)
|
||||
if err != nil || int32(code) != in.Vcode {
|
||||
logx.Error("invalid verification code")
|
||||
return nil, errors.New("invalid verification code")
|
||||
}
|
||||
|
||||
resp, err := l.svcCtx.Snowflake.NextId(l.ctx, &snowflake.NextIdReq{})
|
||||
if err != nil {
|
||||
return nil, errors.New("generate user ID failed")
|
||||
}
|
||||
|
||||
user := models.Users{
|
||||
UserId: resp.Id,
|
||||
Username: in.Username,
|
||||
Nickname: mustNewRandomNickname(),
|
||||
Passwd: in.Passwd,
|
||||
Phone: in.Phone,
|
||||
Email: in.Email,
|
||||
RoleType: 0,
|
||||
IsVerified: false,
|
||||
}
|
||||
|
||||
_, err = l.svcCtx.UsersModelRW.Insert(l.ctx, &user)
|
||||
if err != nil {
|
||||
logx.Error("failed to create user: ", err)
|
||||
return nil, errors.New("failed to create user")
|
||||
}
|
||||
|
||||
return &pb.RegisterResp{
|
||||
Res: "user registered successfully",
|
||||
}, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"juwan-backend/app/users/rpc/internal/svc"
|
||||
"juwan-backend/app/users/rpc/pb"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
var USER_TOKEN_TEMP = "jwt:%v"
|
||||
|
||||
type ValidateTokenLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
@@ -24,7 +27,20 @@ func NewValidateTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Val
|
||||
}
|
||||
|
||||
func (l *ValidateTokenLogic) ValidateToken(in *pb.ValidateTokenReq) (*pb.ValidateTokenResp, error) {
|
||||
// todo: add your logic here and delete this line
|
||||
redisKey := fmt.Sprintf(USER_TOKEN_TEMP, in.UserId)
|
||||
_, err := l.svcCtx.JwtManager.Valid(l.ctx, redisKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := l.svcCtx.UsersModelRO.FindOne(l.ctx, in.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.ValidateTokenResp{}, nil
|
||||
return &pb.ValidateTokenResp{
|
||||
Valid: true,
|
||||
Message: "OK",
|
||||
UserId: in.UserId,
|
||||
RoleType: users.RoleType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
+32
-6
@@ -25,6 +25,7 @@ var (
|
||||
usersRowsWithPlaceHolder = builder.PostgreSqlJoin(stringx.Remove(usersFieldNames, "user_id", "create_at", "create_time", "created_at", "update_at", "update_time", "updated_at"))
|
||||
|
||||
cachePublicUsersUserIdPrefix = "cache:public:users:userId:"
|
||||
cachePublicUsersEmailPrefix = "cache:public:users:email:"
|
||||
cachePublicUsersPhonePrefix = "cache:public:users:phone:"
|
||||
cachePublicUsersUsernamePrefix = "cache:public:users:username:"
|
||||
)
|
||||
@@ -33,6 +34,7 @@ type (
|
||||
usersModel interface {
|
||||
Insert(ctx context.Context, data *Users) (sql.Result, error)
|
||||
FindOne(ctx context.Context, userId int64) (*Users, error)
|
||||
FindOneByEmail(ctx context.Context, email string) (*Users, error)
|
||||
FindOneByPhone(ctx context.Context, phone string) (*Users, error)
|
||||
FindOneByUsername(ctx context.Context, username string) (*Users, error)
|
||||
Update(ctx context.Context, data *Users) error
|
||||
@@ -56,6 +58,7 @@ type (
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
DeletedAt sql.NullTime `db:"deleted_at"`
|
||||
Email string `db:"email"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -72,13 +75,14 @@ func (m *defaultUsersModel) Delete(ctx context.Context, userId int64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
publicUsersEmailKey := fmt.Sprintf("%s%v", cachePublicUsersEmailPrefix, data.Email)
|
||||
publicUsersPhoneKey := fmt.Sprintf("%s%v", cachePublicUsersPhonePrefix, data.Phone)
|
||||
publicUsersUserIdKey := fmt.Sprintf("%s%v", cachePublicUsersUserIdPrefix, userId)
|
||||
publicUsersUsernameKey := fmt.Sprintf("%s%v", cachePublicUsersUsernamePrefix, data.Username)
|
||||
_, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("delete from %s where user_id = $1", m.table)
|
||||
return conn.ExecCtx(ctx, query, userId)
|
||||
}, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
}, publicUsersEmailKey, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -99,6 +103,26 @@ func (m *defaultUsersModel) FindOne(ctx context.Context, userId int64) (*Users,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUsersModel) FindOneByEmail(ctx context.Context, email string) (*Users, error) {
|
||||
publicUsersEmailKey := fmt.Sprintf("%s%v", cachePublicUsersEmailPrefix, email)
|
||||
var resp Users
|
||||
err := m.QueryRowIndexCtx(ctx, &resp, publicUsersEmailKey, m.formatPrimary, func(ctx context.Context, conn sqlx.SqlConn, v any) (i any, e error) {
|
||||
query := fmt.Sprintf("select %s from %s where email = $1 limit 1", usersRows, m.table)
|
||||
if err := conn.QueryRowCtx(ctx, &resp, query, email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.UserId, nil
|
||||
}, m.queryPrimary)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUsersModel) FindOneByPhone(ctx context.Context, phone string) (*Users, error) {
|
||||
publicUsersPhoneKey := fmt.Sprintf("%s%v", cachePublicUsersPhonePrefix, phone)
|
||||
var resp Users
|
||||
@@ -140,13 +164,14 @@ func (m *defaultUsersModel) FindOneByUsername(ctx context.Context, username stri
|
||||
}
|
||||
|
||||
func (m *defaultUsersModel) Insert(ctx context.Context, data *Users) (sql.Result, error) {
|
||||
publicUsersEmailKey := fmt.Sprintf("%s%v", cachePublicUsersEmailPrefix, data.Email)
|
||||
publicUsersPhoneKey := fmt.Sprintf("%s%v", cachePublicUsersPhonePrefix, data.Phone)
|
||||
publicUsersUserIdKey := fmt.Sprintf("%s%v", cachePublicUsersUserIdPrefix, data.UserId)
|
||||
publicUsersUsernameKey := fmt.Sprintf("%s%v", cachePublicUsersUsernamePrefix, data.Username)
|
||||
ret, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("insert into %s (%s) values ($1, $2, $3, $4, $5, $6, $7, $8, $9)", m.table, usersRowsExpectAutoSet)
|
||||
return conn.ExecCtx(ctx, query, data.UserId, data.Username, data.Passwd, data.Nickname, data.Phone, data.RoleType, data.IsVerified, data.State, data.DeletedAt)
|
||||
}, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
query := fmt.Sprintf("insert into %s (%s) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", m.table, usersRowsExpectAutoSet)
|
||||
return conn.ExecCtx(ctx, query, data.UserId, data.Username, data.Passwd, data.Nickname, data.Phone, data.RoleType, data.IsVerified, data.State, data.DeletedAt, data.Email)
|
||||
}, publicUsersEmailKey, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
@@ -156,13 +181,14 @@ func (m *defaultUsersModel) Update(ctx context.Context, newData *Users) error {
|
||||
return err
|
||||
}
|
||||
|
||||
publicUsersEmailKey := fmt.Sprintf("%s%v", cachePublicUsersEmailPrefix, data.Email)
|
||||
publicUsersPhoneKey := fmt.Sprintf("%s%v", cachePublicUsersPhonePrefix, data.Phone)
|
||||
publicUsersUserIdKey := fmt.Sprintf("%s%v", cachePublicUsersUserIdPrefix, data.UserId)
|
||||
publicUsersUsernameKey := fmt.Sprintf("%s%v", cachePublicUsersUsernamePrefix, data.Username)
|
||||
_, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("update %s set %s where user_id = $1", m.table, usersRowsWithPlaceHolder)
|
||||
return conn.ExecCtx(ctx, query, newData.UserId, newData.Username, newData.Passwd, newData.Nickname, newData.Phone, newData.RoleType, newData.IsVerified, newData.State, newData.DeletedAt)
|
||||
}, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
return conn.ExecCtx(ctx, query, newData.UserId, newData.Username, newData.Passwd, newData.Nickname, newData.Phone, newData.RoleType, newData.IsVerified, newData.State, newData.DeletedAt, newData.Email)
|
||||
}, publicUsersEmailKey, publicUsersPhoneKey, publicUsersUserIdKey, publicUsersUsernameKey)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -59,6 +59,11 @@ func (s *UsercenterServer) Login(ctx context.Context, in *pb.LoginReq) (*pb.Logi
|
||||
return l.Login(in)
|
||||
}
|
||||
|
||||
func (s *UsercenterServer) Register(ctx context.Context, in *pb.RegisterReq) (*pb.RegisterResp, error) {
|
||||
l := logic.NewRegisterLogic(ctx, s.svcCtx)
|
||||
return l.Register(in)
|
||||
}
|
||||
|
||||
func (s *UsercenterServer) ValidateToken(ctx context.Context, in *pb.ValidateTokenReq) (*pb.ValidateTokenResp, error) {
|
||||
l := logic.NewValidateTokenLogic(ctx, s.svcCtx)
|
||||
return l.ValidateToken(in)
|
||||
@@ -68,3 +73,8 @@ func (s *UsercenterServer) CheckPermission(ctx context.Context, in *pb.CheckPerm
|
||||
l := logic.NewCheckPermissionLogic(ctx, s.svcCtx)
|
||||
return l.CheckPermission(in)
|
||||
}
|
||||
|
||||
func (s *UsercenterServer) Logout(ctx context.Context, in *pb.LogoutReq) (*pb.LogoutResp, error) {
|
||||
l := logic.NewLogoutLogic(ctx, s.svcCtx)
|
||||
return l.Logout(in)
|
||||
}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// JWKS (JSON Web Key Set) 结构
|
||||
type JWKSKey struct {
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
K string `json:"k,omitempty"` // 对称密钥
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
|
||||
type JWKS struct {
|
||||
Keys []JWKSKey `json:"keys"`
|
||||
}
|
||||
|
||||
// GenerateJWKSFromSecret 从密钥生成 JWKS(用于对称加密 HS256)
|
||||
func GenerateJWKSFromSecret(secretKey string, keyID string) *JWKS {
|
||||
// 对于 HS256,将密钥进行 base64 编码
|
||||
encodedSecret := base64.RawURLEncoding.EncodeToString([]byte(secretKey))
|
||||
|
||||
return &JWKS{
|
||||
Keys: []JWKSKey{
|
||||
{
|
||||
Kty: "oct",
|
||||
Use: "sig",
|
||||
Kid: keyID,
|
||||
K: encodedSecret,
|
||||
Alg: "HS256",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateJWKSEndpoint 生成可以被 Envoy 使用的 JWKS JSON
|
||||
// 此端点应在 user-rpc 中暴露,URL 为 /.well-known/jwks.json
|
||||
func GenerateJWKSEndpoint(secretKey string, keyID string) (string, error) {
|
||||
if secretKey == "" {
|
||||
return "", fmt.Errorf("secret key cannot be empty")
|
||||
}
|
||||
|
||||
jwks := GenerateJWKSFromSecret(secretKey, keyID)
|
||||
|
||||
jsonData, err := json.MarshalIndent(jwks, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(jsonData), nil
|
||||
}
|
||||
|
||||
// TokenPayload 令牌负载
|
||||
type TokenMetadata struct {
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Subject string // userId
|
||||
Issuer string
|
||||
Audience string
|
||||
}
|
||||
|
||||
// ExtractTokenMetadata 从 token 中提取元数据(不验证签名)
|
||||
func ExtractTokenMetadata(tokenString string) (*TokenMetadata, error) {
|
||||
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &Claims{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims type")
|
||||
}
|
||||
|
||||
return &TokenMetadata{
|
||||
IssuedAt: claims.IssuedAt.Time,
|
||||
ExpiresAt: claims.ExpiresAt.Time,
|
||||
Subject: claims.UserId,
|
||||
Issuer: claims.Issuer,
|
||||
Audience: "", // 如果需要,可以增加到 Claims 中
|
||||
}, nil
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type TokenPayload struct {
|
||||
UserId string
|
||||
UserId int64
|
||||
IsAdmin bool
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ var (
|
||||
errInvalidToken = errors.New("invalid token claims")
|
||||
errTokenNotInCache = errors.New("token not found in cache")
|
||||
errNoRedisClient = errors.New("redis client not configured")
|
||||
errInvalidUserID = errors.New("invalid user id")
|
||||
// errExpiredToken = errors.New("token expired")
|
||||
)
|
||||
|
||||
@@ -74,8 +76,7 @@ func (m *JwtManager) New(ctx context.Context, payload *TokenPayload) (string, er
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 存储 token 到 Redis,TTL 为 30 天
|
||||
userKey := tokenCachePrefixUser + payload.UserId
|
||||
userKey := tokenCachePrefixUser + strconv.FormatInt(claims.UserId, 10)
|
||||
tokenKey := tokenCachePrefixToken + tokenString
|
||||
|
||||
tokenData, _ := json.Marshal(payload)
|
||||
@@ -105,12 +106,12 @@ func (m *JwtManager) Valid(ctx context.Context, tokenString string) (*TokenPaylo
|
||||
// 检查 token 是否在 Redis 中
|
||||
tokenKey := tokenCachePrefixToken + tokenString
|
||||
tokenData, err := m.redisCluster.Get(ctx, tokenKey).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var payload TokenPayload
|
||||
if err == redis.Nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, errTokenNotInCache
|
||||
}
|
||||
|
||||
@@ -125,6 +126,20 @@ func (m *JwtManager) Valid(ctx context.Context, tokenString string) (*TokenPaylo
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
if _, renewErr := m.Renew(ctx, tokenString); renewErr != nil {
|
||||
return nil, renewErr
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
if claims, ok := token.Claims.(*Claims); ok {
|
||||
return &claims.TokenPayload, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -146,7 +161,7 @@ func (m *JwtManager) Renew(ctx context.Context, tokenString string) (string, err
|
||||
tokenKey := tokenCachePrefixToken + tokenString
|
||||
tokenData, err := m.redisCluster.Get(ctx, tokenKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", errTokenNotInCache
|
||||
}
|
||||
return "", err
|
||||
@@ -159,15 +174,15 @@ func (m *JwtManager) Renew(ctx context.Context, tokenString string) (string, err
|
||||
}
|
||||
|
||||
// 删除旧 token 记录
|
||||
userKey := tokenCachePrefixUser + payload.UserId
|
||||
userKey := tokenCachePrefixUser + strconv.FormatInt(payload.UserId, 10)
|
||||
m.redisCluster.Del(ctx, tokenKey, userKey)
|
||||
|
||||
// 生成新 token
|
||||
return m.New(ctx, &payload)
|
||||
}
|
||||
|
||||
// extract payload from token without validating expiration (used for auto-renewal)
|
||||
func (m *JwtManager) Extract(ctx context.Context, tokenString string) (*TokenPayload, error) {
|
||||
// Extract payload from token without validating expiration (used for auto-renewal)
|
||||
func (m *JwtManager) Extract(_ context.Context, tokenString string) (*TokenPayload, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(m.secretKey), nil
|
||||
})
|
||||
@@ -184,7 +199,7 @@ func (m *JwtManager) Extract(ctx context.Context, tokenString string) (*TokenPay
|
||||
return &claims.TokenPayload, nil
|
||||
}
|
||||
|
||||
// check if token exists in Redis (i.e. is valid and not revoked)
|
||||
// Exists check if token exists in Redis (i.e. is valid and not revoked)
|
||||
func (m *JwtManager) Exists(ctx context.Context, tokenString string) (bool, error) {
|
||||
if m.redisCluster == nil {
|
||||
return false, errNoRedisClient
|
||||
@@ -199,12 +214,12 @@ func (m *JwtManager) Exists(ctx context.Context, tokenString string) (bool, erro
|
||||
return exists > 0, nil
|
||||
}
|
||||
|
||||
// extract payload from JWT claims
|
||||
// ClaimsToPayload extract payload from JWT claims
|
||||
func (m *JwtManager) ClaimsToPayload(claims *Claims) *TokenPayload {
|
||||
return &claims.TokenPayload
|
||||
}
|
||||
|
||||
// revoke token by deleting both user -> token and token -> payload keys from Redis
|
||||
// Revoke revoke token by deleting both user -> token and token -> payload keys from Redis
|
||||
func (m *JwtManager) Revoke(ctx context.Context, tokenString string) error {
|
||||
if m.redisCluster == nil {
|
||||
return errNoRedisClient
|
||||
@@ -215,7 +230,7 @@ func (m *JwtManager) Revoke(ctx context.Context, tokenString string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
userKey := tokenCachePrefixUser + payload.UserId
|
||||
userKey := tokenCachePrefixUser + strconv.FormatInt(payload.UserId, 10)
|
||||
tokenKey := tokenCachePrefixToken + tokenString
|
||||
|
||||
pipe := m.redisCluster.Pipeline()
|
||||
@@ -225,19 +240,48 @@ func (m *JwtManager) Revoke(ctx context.Context, tokenString string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *JwtManager) GetUserToken(ctx context.Context, userID string) (string, error) {
|
||||
func (m *JwtManager) GetUserToken(ctx context.Context, userID int64) (string, error) {
|
||||
if m.redisCluster == nil {
|
||||
return "", errNoRedisClient
|
||||
}
|
||||
//userID, err := strconv.FormatInt(userID, 10)
|
||||
id := strconv.FormatInt(userID, 10)
|
||||
|
||||
userKey := tokenCachePrefixUser + userID
|
||||
userKey := tokenCachePrefixUser + id
|
||||
token, err := m.redisCluster.Get(ctx, userKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", fmt.Errorf("user %s has no token", userID)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", fmt.Errorf("user %v has no token", userID)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Logout 按用户登出:删除 user->token 和 token->payload 两类缓存数据
|
||||
func (m *JwtManager) Logout(ctx context.Context, userID int64) error {
|
||||
if m.redisCluster == nil {
|
||||
return errNoRedisClient
|
||||
}
|
||||
|
||||
if userID <= 0 {
|
||||
return errInvalidUserID
|
||||
}
|
||||
|
||||
userKey := tokenCachePrefixUser + strconv.FormatInt(userID, 10)
|
||||
tokenString, err := m.redisCluster.Get(ctx, userKey).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := m.redisCluster.Pipeline()
|
||||
pipe.Del(ctx, userKey)
|
||||
if !errors.Is(err, redis.Nil) && tokenString != "" {
|
||||
tokenKey := tokenCachePrefixToken + tokenString
|
||||
pipe.Del(ctx, tokenKey)
|
||||
}
|
||||
|
||||
_, execErr := pipe.Exec(ctx)
|
||||
return execErr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user