diff --git a/pkg/api/tokens.go b/pkg/api/tokens.go index 1b17bc48..d49cd1e2 100644 --- a/pkg/api/tokens.go +++ b/pkg/api/tokens.go @@ -130,7 +130,13 @@ func (a *TokensApi) TokenGenerateMCPHandler(c *core.WebContext) (any, *errs.Erro // TokenRevokeCurrentHandler revokes current token of current user func (a *TokensApi) TokenRevokeCurrentHandler(c *core.WebContext) (any, *errs.Error) { - _, claims, err := a.tokens.ParseTokenByHeader(c) + tokenString := c.GetTokenStringFromHeader() + + if tokenString == "" { + return false, errs.ErrTokenIsEmpty + } + + _, claims, err := a.tokens.ParseToken(c, tokenString) if err != nil { return nil, errs.Or(err, errs.NewIncompleteOrIncorrectSubmissionError(err)) diff --git a/pkg/core/context_web.go b/pkg/core/context_web.go index b7dc770c..7927c3bf 100644 --- a/pkg/core/context_web.go +++ b/pkg/core/context_web.go @@ -3,6 +3,7 @@ package core import ( "net" "strconv" + "strings" "github.com/gin-gonic/gin" @@ -23,6 +24,11 @@ const RemoteClientPortHeader = "X-Real-Port" // ClientTimezoneOffsetHeaderName represents the header name of client timezone offset const ClientTimezoneOffsetHeaderName = "X-Timezone-Offset" +const tokenHeaderName = "Authorization" +const tokenHeaderValuePrefix = "bearer " +const tokenQueryStringParam = "token" +const tokenCookieParam = "ebk_auth_token" + // WebContext represents the request and response context type WebContext struct { *gin.Context @@ -118,6 +124,41 @@ func (c *WebContext) GetCurrentUid() int64 { return claims.Uid } +// GetTokenStringFromHeader returns the token string from the request header +func (c *WebContext) GetTokenStringFromHeader() string { + tokenHeader := c.GetHeader(tokenHeaderName) + + if len(tokenHeader) < 7 || !strings.EqualFold(tokenHeader[:7], tokenHeaderValuePrefix) { + return "" + } + + return tokenHeader[7:] +} + +// GetTokenStringFromQueryString returns the token string from the request query string +func (c *WebContext) GetTokenStringFromQueryString() string { + return c.Query(tokenQueryStringParam) +} + +// GetTokenStringFromCookie returns the token string from the request cookie +func (c *WebContext) GetTokenStringFromCookie() string { + tokenCookie, err := c.Cookie(tokenCookieParam) + + if err != nil { + return "" + } + + return tokenCookie +} + +func (c *WebContext) SetTokenStringToCookie(token string, tokenExpiredTime int, path string) { + if token != "" { + c.SetCookie(tokenCookieParam, token, tokenExpiredTime, path, "", false, true) + } else { + c.SetCookie(tokenCookieParam, "", -1, path, "", false, true) + } +} + // GetClientLocale returns the client locale name func (c *WebContext) GetClientLocale() string { value := c.GetHeader(AcceptLanguageHeaderName) diff --git a/pkg/middlewares/amap_api_proxy_auth_cookie.go b/pkg/middlewares/amap_api_proxy_auth_cookie.go index 5964fc73..33cae4f5 100644 --- a/pkg/middlewares/amap_api_proxy_auth_cookie.go +++ b/pkg/middlewares/amap_api_proxy_auth_cookie.go @@ -5,15 +5,8 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/settings" ) -const tokenCookieParam = "ebk_auth_token" - // AmapApiProxyAuthCookie adds amap api proxy auth cookie to cookies in response func AmapApiProxyAuthCookie(c *core.WebContext, config *settings.Config) { token := c.GetTextualToken() - - if token != "" { - c.SetCookie(tokenCookieParam, token, int(config.TokenExpiredTime), "/_AMapService", "", false, true) - } else { - c.SetCookie(tokenCookieParam, "", -1, "/_AMapService", "", false, true) - } + c.SetTokenStringToCookie(token, int(config.TokenExpiredTime), "/_AMapService") } diff --git a/pkg/middlewares/authorization.go b/pkg/middlewares/authorization.go index 11bc5fd5..c3629c1c 100644 --- a/pkg/middlewares/authorization.go +++ b/pkg/middlewares/authorization.go @@ -20,8 +20,6 @@ const ( TOKEN_SOURCE_TYPE_COOKIE TokenSourceType = 3 ) -const tokenQueryStringParam = "token" - // JWTAuthorization verifies whether current request is valid by jwt token in header func JWTAuthorization(c *core.WebContext) { jwtAuthorization(c, TOKEN_SOURCE_TYPE_HEADER) @@ -159,11 +157,19 @@ func getTokenClaims(c *core.WebContext, source TokenSourceType) (*core.UserToken } func parseToken(c *core.WebContext, source TokenSourceType) (*jwt.Token, *core.UserTokenClaims, error) { + tokenString := "" + if source == TOKEN_SOURCE_TYPE_ARGUMENT { - return services.Tokens.ParseTokenByArgument(c, tokenQueryStringParam) + tokenString = c.GetTokenStringFromQueryString() } else if source == TOKEN_SOURCE_TYPE_COOKIE { - return services.Tokens.ParseTokenByCookie(c, tokenCookieParam) + tokenString = c.GetTokenStringFromCookie() + } else { // if source == TOKEN_SOURCE_TYPE_HEADER + tokenString = c.GetTokenStringFromHeader() } - return services.Tokens.ParseTokenByHeader(c) + if tokenString == "" { + return nil, nil, errs.ErrTokenIsEmpty + } + + return services.Tokens.ParseToken(c, tokenString) } diff --git a/pkg/services/tokens.go b/pkg/services/tokens.go index a7c99c47..f32f673b 100644 --- a/pkg/services/tokens.go +++ b/pkg/services/tokens.go @@ -1,6 +1,7 @@ package services import ( + "errors" "fmt" "math" "strings" @@ -71,19 +72,9 @@ func (s *TokenService) GetAllUnexpiredNormalAndMCPTokensByUid(c core.Context, ui return tokenRecords, err } -// ParseTokenByHeader returns the token model according to request data -func (s *TokenService) ParseTokenByHeader(c *core.WebContext) (*jwt.Token, *core.UserTokenClaims, error) { - return s.parseToken(c, request.BearerExtractor{}) -} - -// ParseTokenByArgument returns the token model according to request data -func (s *TokenService) ParseTokenByArgument(c *core.WebContext, tokenParameterName string) (*jwt.Token, *core.UserTokenClaims, error) { - return s.parseToken(c, request.ArgumentExtractor{tokenParameterName}) -} - -// ParseTokenByCookie returns the token model according to request data -func (s *TokenService) ParseTokenByCookie(c *core.WebContext, tokenCookieName string) (*jwt.Token, *core.UserTokenClaims, error) { - return s.parseToken(c, utils.CookieExtractor{tokenCookieName}) +// ParseToken returns the token model according to token content +func (s *TokenService) ParseToken(c core.Context, token string) (*jwt.Token, *core.UserTokenClaims, error) { + return s.parseToken(c, token) } // CreateTokenViaCli generates a new normal token and saves to database @@ -328,53 +319,52 @@ func (s *TokenService) GenerateTokenId(tokenRecord *models.TokenRecord) string { return fmt.Sprintf("%d:%d:%d", tokenRecord.Uid, tokenRecord.CreatedUnixTime, tokenRecord.UserTokenId) } -func (s *TokenService) parseToken(c *core.WebContext, extractor request.Extractor) (*jwt.Token, *core.UserTokenClaims, error) { +func (s *TokenService) parseToken(c core.Context, tokenString string) (*jwt.Token, *core.UserTokenClaims, error) { claims := &core.UserTokenClaims{} - token, err := request.ParseFromRequest(c.Request, extractor, + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { now := time.Now().Unix() userTokenId, err := utils.StringToInt64(claims.UserTokenId) if err != nil { - log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error()) + log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error()) return nil, errs.ErrInvalidUserTokenId } tokenRecord, err := s.getTokenRecord(c, claims.Uid, userTokenId, claims.IssuedAt) if err != nil { - log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error()) + log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error()) return nil, errs.ErrTokenRecordNotFound } if tokenRecord.ExpiredUnixTime < now { - log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid) + log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid) return nil, errs.ErrTokenExpired } return []byte(tokenRecord.Secret), nil }, - request.WithClaims(claims), - request.WithParser(jwt.NewParser(jwt.WithIssuedAt())), + jwt.WithIssuedAt(), ) if err != nil { - if err == request.ErrNoTokenInRequest { + if errors.Is(err, request.ErrNoTokenInRequest) { return nil, nil, errs.ErrTokenIsEmpty } - if err == jwt.ErrTokenMalformed || err == jwt.ErrTokenUnverifiable || err == jwt.ErrTokenSignatureInvalid { - log.Warnf(c, "[tokens.ParseToken] token is invalid, because %s", err.Error()) + if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) || errors.Is(err, jwt.ErrTokenSignatureInvalid) { + log.Warnf(c, "[tokens.parseToken] token is invalid, because %s", err.Error()) return nil, nil, errs.ErrCurrentInvalidToken } - if err == jwt.ErrTokenExpired { + if errors.Is(err, jwt.ErrTokenExpired) { return nil, nil, errs.ErrCurrentTokenExpired } - if err == jwt.ErrTokenUsedBeforeIssued { - log.Warnf(c, "[tokens.ParseToken] token is invalid, because issue time is later than now") + if errors.Is(err, jwt.ErrTokenUsedBeforeIssued) { + log.Warnf(c, "[tokens.parseToken] token is invalid, because issue time is later than now") return nil, nil, errs.ErrCurrentInvalidToken } diff --git a/pkg/utils/extractor.go b/pkg/utils/extractor.go deleted file mode 100644 index 367b7445..00000000 --- a/pkg/utils/extractor.go +++ /dev/null @@ -1,20 +0,0 @@ -package utils - -import ( - "net/http" - - "github.com/golang-jwt/jwt/v5/request" -) - -// CookieExtractor extracts a token from request cookies -type CookieExtractor []string - -func (e CookieExtractor) ExtractToken(req *http.Request) (string, error) { - for _, arg := range e { - if cookie, _ := req.Cookie(arg); cookie != nil { - return cookie.Value, nil - } - } - - return "", request.ErrNoTokenInRequest -}