package contextj

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/zeromicro/go-zero/core/logx"
)

// Sentinel errors for illegal request-scoped values; compare with errors.Is.
// The upper-case names do not follow Go naming conventions but are kept so
// existing callers keep compiling.
var (
	ERRILLEGALUSER      = errors.New("illegal user")
	ERRILLEGALTOKEN     = errors.New("illegal token")
	ERRILLEGALREQUESTID = errors.New("illegal request id")
	ERRILLEGALISADMIN   = errors.New("illegal is_admin")
)

// NOTE(review): every accessor in this package uses plain string keys
// ("request_id", "token", "user_id", ...) rather than a private key type.
// Presumably this is deliberate so values placed in the context elsewhere
// under string keys (e.g. JWT claims copied in by go-zero's auth handler)
// remain readable here. Confirm before switching to an unexported key type,
// which would make those values invisible.

// WithRequestId returns a copy of c that carries requestId under the
// "request_id" key.
//
// Deprecated: use WithRequestID, which stores the value under the same key.
func WithRequestId(c context.Context, requestId string) context.Context {
	return context.WithValue(c, "request_id", requestId)
}

// RequestIdFrom returns the string stored under the "request_id" key, or an
// error if it is absent or not a string.
//
// Deprecated: use RequestIDFrom, which reads the same key.
func RequestIdFrom(c context.Context) (string, error) {
	requestID, ok := c.Value("request_id").(string)
	if !ok {
		return "", errors.New("request_id not found in context")
	}
	return requestID, nil
}

// WithToken returns a copy of c that carries token under the "token" key.
func WithToken(c context.Context, token string) context.Context {
	return context.WithValue(c, "token", token)
}

// TokenFrom returns the string stored under the "token" key, or an error if
// it is absent or not a string.
func TokenFrom(c context.Context) (string, error) {
	token, ok := c.Value("token").(string)
	if !ok {
		return "", errors.New("token not found in context")
	}
	return token, nil
}

// WithUserID returns a copy of c that carries id (as an int64) under the
// "user_id" key.
func WithUserID(c context.Context, id int64) context.Context {
	return context.WithValue(c, "user_id", id)
}

// UserIDFrom returns the user id stored under the "user_id" key.
//
// Accepted value types:
//   - int64, as stored by WithUserID;
//   - json.Number, parsed as a base-10 int64. This lets numeric JWT claims
//     that were copied into the context as json.Number be read too
//     (NOTE(review): go-zero's JWT handler does this; confirm against the
//     project's auth middleware).
//
// It returns an error if the key is absent, holds any other type, or holds a
// json.Number that is not a valid int64.
func UserIDFrom(c context.Context) (int64, error) {
	switch v := c.Value("user_id").(type) {
	case int64:
		return v, nil
	case json.Number:
		id, err := v.Int64()
		if err != nil {
			return 0, fmt.Errorf("parse user_id from context: %w", err)
		}
		return id, nil
	default:
		return 0, errors.New("user_id not found in context")
	}
}

// WithRequestID returns a copy of c that carries requestID under the
// "request_id" key.
//
// The request id is used for tracing and logging, not for authentication or
// authorization, so it can be set by clients or generated by the server.
func WithRequestID(c context.Context, requestID string) context.Context { return context.WithValue(c, "request_id", requestID) } func RequestIDFrom(c context.Context) (string, error) { if requestID, ok := c.Value("request_id").(string); !ok { return "", errors.New("request_id not found in context") } else { return requestID, nil } } func RIdFrom(c context.Context) (string, error) { if rid, ok := c.Value("rid").(string); !ok { return "", errors.New("rid not found in context") } else { return rid, nil } } func WithIsAdmin(c context.Context, isAdmin bool) context.Context { return context.WithValue(c, "is_admin", isAdmin) } func IsAdminFrom(c context.Context) (bool, error) { if isAdmin, ok := c.Value("is_admin").(bool); !ok { return false, errors.New("is_admin not found in context") } else { return isAdmin, nil } } func AdminIdFrom(c context.Context) (adminId int64, err error) { adminId, err = UserIDFrom(c) if err != nil { logx.Errorf("get user id from context: %v", err) return 0, ERRILLEGALUSER } isAdmin, err := IsAdminFrom(c) if err != nil { logx.Errorf("get isAdmin from context: %v", err) return 0, ERRILLEGALUSER } if !isAdmin { logx.Errorf("user %d is not admin", adminId) return 0, ERRILLEGALUSER } return }