105 lines
2.6 KiB
Go
105 lines
2.6 KiB
Go
package contextj
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
// Sentinel errors exported by this package. Each value is created exactly
// once, so callers may compare against it with == or errors.Is.
//
// NOTE(review): these identifiers do not follow the Go ErrXxx naming
// convention; they are kept unchanged because renaming exported names would
// break existing callers.
var (
	// ERRILLEGALUSER carries the message "illegal user". AdminIdFrom returns
	// it whenever the caller cannot be confirmed as an admin.
	ERRILLEGALUSER = errors.New("illegal user")

	// ERRILLEGALTOKEN carries the message "illegal token".
	ERRILLEGALTOKEN = errors.New("illegal token")

	// ERRILLEGALREQUESTID carries the message "illegal request id".
	ERRILLEGALREQUESTID = errors.New("illegal request id")

	// ERRILLEGALISADMIN carries the message "illegal is_admin".
	ERRILLEGALISADMIN = errors.New("illegal is_admin")
)
|
|
|
|
func WithRequestId(c context.Context, requestId string) context.Context {
|
|
return context.WithValue(c, "request_id", requestId)
|
|
}
|
|
|
|
func RequestIdFrom(c context.Context) (string, error) {
|
|
requestID, ok := c.Value("request_id").(string)
|
|
if !ok {
|
|
return "", errors.New("request_id not found in context")
|
|
}
|
|
return requestID, nil
|
|
}
|
|
|
|
// WithToken returns a child of c that carries token under the plain string
// key "token"; TokenFrom reads it back. The parent context is not modified.
//
// NOTE(review): staticcheck SA1029 discourages built-in string keys because
// any package may collide with them. The key is left as-is because other code
// may rely on the literal "token" key.
func WithToken(c context.Context, token string) context.Context {
	const key = "token"
	derived := context.WithValue(c, key, token)
	return derived
}
|
|
|
|
// TokenFrom returns the string stored in c under the "token" key (as set by
// WithToken). If the key is missing, or holds something other than a string,
// it returns "" and an error.
func TokenFrom(c context.Context) (string, error) {
	switch tok := c.Value("token").(type) {
	case string:
		return tok, nil
	default:
		return "", errors.New("token not found in context")
	}
}
|
|
|
|
// WithUserID returns a child of c that carries id under the "user_id" key.
// The value is stored as an int64, which is exactly the type UserIDFrom
// asserts when reading it back.
func WithUserID(c context.Context, id int64) context.Context {
	const key = "user_id"
	next := context.WithValue(c, key, id)
	return next
}
|
|
|
|
// UserIDFrom returns the int64 stored in c under the "user_id" key, as set by
// WithUserID.
//
// It returns 0 and an error when the key is absent or when the stored value
// is not exactly an int64 (an int, json number, or string is rejected).
func UserIDFrom(c context.Context) (int64, error) {
	userID, ok := c.Value("user_id").(int64)
	if !ok {
		return 0, errors.New("user_id not found in context")
	}
	return userID, nil
}
|
|
|
|
// WithRequestID returns a child of c that carries requestID under the
// "request_id" key; RequestIDFrom reads it back.
//
// The request ID exists for tracing and logging only. It is not used for
// authentication or authorization, so it may be supplied by the client or
// generated by the server.
func WithRequestID(c context.Context, requestID string) context.Context {
	ctx := context.WithValue(c, "request_id", requestID)
	return ctx
}
|
|
|
|
// RequestIDFrom returns the string stored in c under the "request_id" key, as
// set by WithRequestID.
//
// It returns "" and an error when the key is absent or the stored value is
// not a string.
func RequestIDFrom(c context.Context) (string, error) {
	requestID, ok := c.Value("request_id").(string)
	if !ok {
		return "", errors.New("request_id not found in context")
	}
	return requestID, nil
}
|
|
|
|
// RIdFrom returns the string stored in c under the "rid" key.
//
// It returns "" and an error when the key is absent or the stored value is
// not a string. Note that "rid" is a different key from the "request_id" key
// used by WithRequestID/RequestIDFrom.
//
// NOTE(review): no setter for "rid" exists alongside this function;
// presumably it is populated elsewhere (e.g. by middleware) — confirm.
func RIdFrom(c context.Context) (string, error) {
	rid, ok := c.Value("rid").(string)
	if !ok {
		return "", errors.New("rid not found in context")
	}
	return rid, nil
}
|
|
|
|
// WithIsAdmin returns a child of c that carries the isAdmin flag under the
// "is_admin" key. IsAdminFrom reads it back, and AdminIdFrom consults it
// through IsAdminFrom.
func WithIsAdmin(c context.Context, isAdmin bool) context.Context {
	const key = "is_admin"
	flagged := context.WithValue(c, key, isAdmin)
	return flagged
}
|
|
|
|
// IsAdminFrom returns the bool stored in c under the "is_admin" key, as set
// by WithIsAdmin.
//
// It returns false and an error when the key is absent or the stored value is
// not a bool. Callers should therefore check the error rather than treating a
// false result alone as "not an admin".
func IsAdminFrom(c context.Context) (bool, error) {
	isAdmin, ok := c.Value("is_admin").(bool)
	if !ok {
		return false, errors.New("is_admin not found in context")
	}
	return isAdmin, nil
}
|
|
|
|
func AdminIdFrom(c context.Context) (adminId int64, err error) {
|
|
|
|
adminId, err = UserIDFrom(c)
|
|
if err != nil {
|
|
logx.Errorf("get user id from context: %v", err)
|
|
return 0, ERRILLEGALUSER
|
|
}
|
|
isAdmin, err := IsAdminFrom(c)
|
|
if err != nil {
|
|
logx.Errorf("get isAdmin from context: %v", err)
|
|
return 0, ERRILLEGALUSER
|
|
}
|
|
if !isAdmin {
|
|
logx.Errorf("user %d is not admin", adminId)
|
|
return 0, ERRILLEGALUSER
|
|
}
|
|
return
|
|
}
|