From 736f3409796b6ac95b09c81d8b196cdd1ecda475 Mon Sep 17 00:00:00 2001 From: MaysWind Date: Sat, 3 Jun 2023 16:49:19 +0800 Subject: [PATCH] code refactor --- pkg/api/map_image_proxies.go | 1 - pkg/api/tokens.go | 2 +- pkg/middlewares/authorization.go | 49 ++++++++++--- pkg/services/tokens.go | 117 +++++++++++++++++-------------- 4 files changed, 103 insertions(+), 66 deletions(-) diff --git a/pkg/api/map_image_proxies.go b/pkg/api/map_image_proxies.go index 0c31feb2..4a63928c 100644 --- a/pkg/api/map_image_proxies.go +++ b/pkg/api/map_image_proxies.go @@ -31,7 +31,6 @@ func (p *MapImageProxy) OpenStreetMapTileImageProxyHandler(c *core.Context) (*ht imageRawUrl := fmt.Sprintf(openStreetMapTileImageUrlFormat, zoomLevel, coordinateX, fileName) imageUrl, _ := url.Parse(imageRawUrl) - req.Header.Del("Authorization") req.URL = imageUrl req.RequestURI = req.URL.RequestURI() req.Host = imageUrl.Host diff --git a/pkg/api/tokens.go b/pkg/api/tokens.go index cb08137d..ecbc4c62 100644 --- a/pkg/api/tokens.go +++ b/pkg/api/tokens.go @@ -62,7 +62,7 @@ func (a *TokensApi) TokenListHandler(c *core.Context) (interface{}, *errs.Error) // TokenRevokeCurrentHandler revokes current token of current user func (a *TokensApi) TokenRevokeCurrentHandler(c *core.Context) (interface{}, *errs.Error) { - _, claims, err := a.tokens.ParseToken(c) + _, claims, err := a.tokens.ParseTokenByHeader(c) if err != nil { return nil, errs.Or(err, errs.NewIncompleteOrIncorrectSubmissionError(err)) diff --git a/pkg/middlewares/authorization.go b/pkg/middlewares/authorization.go index 39c75487..33263e42 100644 --- a/pkg/middlewares/authorization.go +++ b/pkg/middlewares/authorization.go @@ -1,6 +1,8 @@ package middlewares import ( + "github.com/golang-jwt/jwt/v5" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" @@ -8,11 +10,20 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/utils" ) +// TokenSourceType represents token source +type TokenSourceType byte + +// Token source types +const ( + TOKEN_SOURCE_TYPE_HEADER TokenSourceType = 1 + TOKEN_SOURCE_TYPE_ARGUMENT TokenSourceType = 2 +) + const tokenQueryStringParam = "token" // JWTAuthorization verifies whether current request is valid by jwt token func JWTAuthorization(c *core.Context) { - claims, err := getTokenClaims(c) + claims, err := getTokenClaims(c, TOKEN_SOURCE_TYPE_HEADER) if err != nil { utils.PrintJsonErrorResult(c, err) @@ -37,22 +48,32 @@ func JWTAuthorization(c *core.Context) { // JWTAuthorizationByQueryString verifies whether current request is valid by jwt token func JWTAuthorizationByQueryString(c *core.Context) { - token, exists := c.GetQuery(tokenQueryStringParam) + claims, err := getTokenClaims(c, TOKEN_SOURCE_TYPE_ARGUMENT) - if !exists { - log.WarnfWithRequestId(c, "[authorization.JWTAuthorizationByQueryString] no token provided") - utils.PrintJsonErrorResult(c, errs.ErrUnauthorizedAccess) + if err != nil { + utils.PrintJsonErrorResult(c, err) return } - c.Request.Header.Set("Authorization", token) + if claims.Type == core.USER_TOKEN_TYPE_REQUIRE_2FA { + log.WarnfWithRequestId(c, "[authorization.JWTAuthorizationByQueryString] user \"uid:%d\" token requires 2fa", claims.Uid) + utils.PrintJsonErrorResult(c, errs.ErrCurrentTokenRequire2FA) + return + } - JWTAuthorization(c) + if claims.Type != core.USER_TOKEN_TYPE_NORMAL { + log.WarnfWithRequestId(c, "[authorization.JWTAuthorizationByQueryString] user \"uid:%d\" token type is invalid", claims.Uid) + utils.PrintJsonErrorResult(c, errs.ErrCurrentInvalidTokenType) + return + } + + c.SetTokenClaims(claims) + c.Next() } // JWTTwoFactorAuthorization verifies whether current request is valid by 2fa passcode func JWTTwoFactorAuthorization(c *core.Context) { - claims, err := getTokenClaims(c) + claims, err := getTokenClaims(c, TOKEN_SOURCE_TYPE_HEADER) if err != nil { utils.PrintJsonErrorResult(c, err) @@ -69,8 +90,8 @@ func JWTTwoFactorAuthorization(c *core.Context) { c.Next() } -func getTokenClaims(c *core.Context) (*core.UserTokenClaims, *errs.Error) { - token, claims, err := services.Tokens.ParseToken(c) +func getTokenClaims(c *core.Context, source TokenSourceType) (*core.UserTokenClaims, *errs.Error) { + token, claims, err := parseToken(c, source) if err != nil { log.WarnfWithRequestId(c, "[authorization.getTokenClaims] failed to parse token, because %s", err.Error()) @@ -89,3 +110,11 @@ func getTokenClaims(c *core.Context) (*core.UserTokenClaims, *errs.Error) { return claims, nil } + +func parseToken(c *core.Context, source TokenSourceType) (*jwt.Token, *core.UserTokenClaims, error) { + if source == TOKEN_SOURCE_TYPE_ARGUMENT { + return services.Tokens.ParseTokenByArgument(c, tokenQueryStringParam) + } + + return services.Tokens.ParseTokenByHeader(c) +} diff --git a/pkg/services/tokens.go b/pkg/services/tokens.go index 316ed85d..6f695e33 100644 --- a/pkg/services/tokens.go +++ b/pkg/services/tokens.go @@ -63,61 +63,14 @@ func (s *TokenService) GetAllUnexpiredNormalTokensByUid(uid int64) ([]*models.To return tokenRecords, err } -// ParseToken returns the token model according to request data -func (s *TokenService) ParseToken(c *core.Context) (*jwt.Token, *core.UserTokenClaims, error) { - claims := &core.UserTokenClaims{} +// ParseTokenByHeader returns the token model according to request data +func (s *TokenService) ParseTokenByHeader(c *core.Context) (*jwt.Token, *core.UserTokenClaims, error) { + return s.parseToken(c, request.BearerExtractor{}) +} - token, err := request.ParseFromRequest(c.Request, request.AuthorizationHeaderExtractor, - func(token *jwt.Token) (interface{}, error) { - now := time.Now().Unix() - userTokenId, err := utils.StringToInt64(claims.UserTokenId) - - if err != nil { - log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error()) - return nil, errs.ErrInvalidUserTokenId - } - - tokenRecord, err := s.getTokenRecord(claims.Uid, userTokenId, claims.IssuedAt) - - if err != nil { - log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error()) - return nil, errs.ErrTokenRecordNotFound - } - - if tokenRecord.ExpiredUnixTime < now { - log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid) - return nil, errs.ErrTokenExpired - } - - return []byte(tokenRecord.Secret), nil - }, - request.WithClaims(claims), - request.WithParser(jwt.NewParser(jwt.WithIssuedAt())), - ) - - if err != nil { - if err == request.ErrNoTokenInRequest { - return nil, nil, errs.ErrTokenIsEmpty - } - - if err == jwt.ErrTokenMalformed || err == jwt.ErrTokenUnverifiable || err == jwt.ErrTokenSignatureInvalid { - log.WarnfWithRequestId(c, "[tokens.ParseToken] token is invalid, because %s", err.Error()) - return nil, nil, errs.ErrCurrentInvalidToken - } - - if err == jwt.ErrTokenExpired { - return nil, nil, errs.ErrCurrentTokenExpired - } - - if err == jwt.ErrTokenUsedBeforeIssued { - log.WarnfWithRequestId(c, "[tokens.ParseToken] token is invalid, because issue time is later than now") - return nil, nil, errs.ErrCurrentInvalidToken - } - - return nil, nil, err - } - - return token, claims, err +// ParseTokenByArgument returns the token model according to request data +func (s *TokenService) ParseTokenByArgument(c *core.Context, tokenParameterName string) (*jwt.Token, *core.UserTokenClaims, error) { + return s.parseToken(c, request.ArgumentExtractor{tokenParameterName}) } // CreateToken generates a new normal token and saves to database @@ -242,6 +195,62 @@ func (s *TokenService) GenerateTokenId(tokenRecord *models.TokenRecord) string { return fmt.Sprintf("%d:%d:%d", tokenRecord.Uid, tokenRecord.CreatedUnixTime, tokenRecord.UserTokenId) } +func (s *TokenService) parseToken(c *core.Context, extractor request.Extractor) (*jwt.Token, *core.UserTokenClaims, error) { + claims := &core.UserTokenClaims{} + + token, err := request.ParseFromRequest(c.Request, extractor, + func(token *jwt.Token) (interface{}, error) { + now := time.Now().Unix() + userTokenId, err := utils.StringToInt64(claims.UserTokenId) + + if err != nil { + log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error()) + return nil, errs.ErrInvalidUserTokenId + } + + tokenRecord, err := s.getTokenRecord(claims.Uid, userTokenId, claims.IssuedAt) + + if err != nil { + log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error()) + return nil, errs.ErrTokenRecordNotFound + } + + if tokenRecord.ExpiredUnixTime < now { + log.WarnfWithRequestId(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid) + return nil, errs.ErrTokenExpired + } + + return []byte(tokenRecord.Secret), nil + }, + request.WithClaims(claims), + request.WithParser(jwt.NewParser(jwt.WithIssuedAt())), + ) + + if err != nil { + if err == request.ErrNoTokenInRequest { + return nil, nil, errs.ErrTokenIsEmpty + } + + if err == jwt.ErrTokenMalformed || err == jwt.ErrTokenUnverifiable || err == jwt.ErrTokenSignatureInvalid { + log.WarnfWithRequestId(c, "[tokens.ParseToken] token is invalid, because %s", err.Error()) + return nil, nil, errs.ErrCurrentInvalidToken + } + + if err == jwt.ErrTokenExpired { + return nil, nil, errs.ErrCurrentTokenExpired + } + + if err == jwt.ErrTokenUsedBeforeIssued { + log.WarnfWithRequestId(c, "[tokens.ParseToken] token is invalid, because issue time is later than now") + return nil, nil, errs.ErrCurrentInvalidToken + } + + return nil, nil, err + } + + return token, claims, err +} + func (s *TokenService) createToken(user *models.User, tokenType core.TokenType, userAgent string, expiryDate time.Duration) (string, *core.UserTokenClaims, error) { var err error now := time.Now()