code refactor

This commit is contained in:
MaysWind
2025-07-24 23:58:24 +08:00
parent d6ee8a416f
commit d385358aa3
6 changed files with 76 additions and 60 deletions
+7 -1
View File
@@ -130,7 +130,13 @@ func (a *TokensApi) TokenGenerateMCPHandler(c *core.WebContext) (any, *errs.Erro
// TokenRevokeCurrentHandler revokes current token of current user // TokenRevokeCurrentHandler revokes current token of current user
func (a *TokensApi) TokenRevokeCurrentHandler(c *core.WebContext) (any, *errs.Error) { func (a *TokensApi) TokenRevokeCurrentHandler(c *core.WebContext) (any, *errs.Error) {
_, claims, err := a.tokens.ParseTokenByHeader(c) tokenString := c.GetTokenStringFromHeader()
if tokenString == "" {
return false, errs.ErrTokenIsEmpty
}
_, claims, err := a.tokens.ParseToken(c, tokenString)
if err != nil { if err != nil {
return nil, errs.Or(err, errs.NewIncompleteOrIncorrectSubmissionError(err)) return nil, errs.Or(err, errs.NewIncompleteOrIncorrectSubmissionError(err))
+41
View File
@@ -3,6 +3,7 @@ package core
import ( import (
"net" "net"
"strconv" "strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -23,6 +24,11 @@ const RemoteClientPortHeader = "X-Real-Port"
// ClientTimezoneOffsetHeaderName represents the header name of client timezone offset // ClientTimezoneOffsetHeaderName represents the header name of client timezone offset
const ClientTimezoneOffsetHeaderName = "X-Timezone-Offset" const ClientTimezoneOffsetHeaderName = "X-Timezone-Offset"
const tokenHeaderName = "Authorization"
const tokenHeaderValuePrefix = "bearer "
const tokenQueryStringParam = "token"
const tokenCookieParam = "ebk_auth_token"
// WebContext represents the request and response context // WebContext represents the request and response context
type WebContext struct { type WebContext struct {
*gin.Context *gin.Context
@@ -118,6 +124,41 @@ func (c *WebContext) GetCurrentUid() int64 {
return claims.Uid return claims.Uid
} }
// GetTokenStringFromHeader returns the token string from the request header
func (c *WebContext) GetTokenStringFromHeader() string {
tokenHeader := c.GetHeader(tokenHeaderName)
if len(tokenHeader) < 7 || !strings.EqualFold(tokenHeader[:7], tokenHeaderValuePrefix) {
return ""
}
return tokenHeader[7:]
}
// GetTokenStringFromQueryString returns the token string from the request query string
func (c *WebContext) GetTokenStringFromQueryString() string {
return c.Query(tokenQueryStringParam)
}
// GetTokenStringFromCookie returns the token string from the request cookie
func (c *WebContext) GetTokenStringFromCookie() string {
tokenCookie, err := c.Cookie(tokenCookieParam)
if err != nil {
return ""
}
return tokenCookie
}
func (c *WebContext) SetTokenStringToCookie(token string, tokenExpiredTime int, path string) {
if token != "" {
c.SetCookie(tokenCookieParam, token, tokenExpiredTime, path, "", false, true)
} else {
c.SetCookie(tokenCookieParam, "", -1, path, "", false, true)
}
}
// GetClientLocale returns the client locale name // GetClientLocale returns the client locale name
func (c *WebContext) GetClientLocale() string { func (c *WebContext) GetClientLocale() string {
value := c.GetHeader(AcceptLanguageHeaderName) value := c.GetHeader(AcceptLanguageHeaderName)
@@ -5,15 +5,8 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
) )
const tokenCookieParam = "ebk_auth_token"
// AmapApiProxyAuthCookie adds amap api proxy auth cookie to cookies in response // AmapApiProxyAuthCookie adds amap api proxy auth cookie to cookies in response
func AmapApiProxyAuthCookie(c *core.WebContext, config *settings.Config) { func AmapApiProxyAuthCookie(c *core.WebContext, config *settings.Config) {
token := c.GetTextualToken() token := c.GetTextualToken()
c.SetTokenStringToCookie(token, int(config.TokenExpiredTime), "/_AMapService")
if token != "" {
c.SetCookie(tokenCookieParam, token, int(config.TokenExpiredTime), "/_AMapService", "", false, true)
} else {
c.SetCookie(tokenCookieParam, "", -1, "/_AMapService", "", false, true)
}
} }
+11 -5
View File
@@ -20,8 +20,6 @@ const (
TOKEN_SOURCE_TYPE_COOKIE TokenSourceType = 3 TOKEN_SOURCE_TYPE_COOKIE TokenSourceType = 3
) )
const tokenQueryStringParam = "token"
// JWTAuthorization verifies whether current request is valid by jwt token in header // JWTAuthorization verifies whether current request is valid by jwt token in header
func JWTAuthorization(c *core.WebContext) { func JWTAuthorization(c *core.WebContext) {
jwtAuthorization(c, TOKEN_SOURCE_TYPE_HEADER) jwtAuthorization(c, TOKEN_SOURCE_TYPE_HEADER)
@@ -159,11 +157,19 @@ func getTokenClaims(c *core.WebContext, source TokenSourceType) (*core.UserToken
} }
func parseToken(c *core.WebContext, source TokenSourceType) (*jwt.Token, *core.UserTokenClaims, error) { func parseToken(c *core.WebContext, source TokenSourceType) (*jwt.Token, *core.UserTokenClaims, error) {
tokenString := ""
if source == TOKEN_SOURCE_TYPE_ARGUMENT { if source == TOKEN_SOURCE_TYPE_ARGUMENT {
return services.Tokens.ParseTokenByArgument(c, tokenQueryStringParam) tokenString = c.GetTokenStringFromQueryString()
} else if source == TOKEN_SOURCE_TYPE_COOKIE { } else if source == TOKEN_SOURCE_TYPE_COOKIE {
return services.Tokens.ParseTokenByCookie(c, tokenCookieParam) tokenString = c.GetTokenStringFromCookie()
} else { // if source == TOKEN_SOURCE_TYPE_HEADER
tokenString = c.GetTokenStringFromHeader()
} }
return services.Tokens.ParseTokenByHeader(c) if tokenString == "" {
return nil, nil, errs.ErrTokenIsEmpty
}
return services.Tokens.ParseToken(c, tokenString)
} }
+16 -26
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"strings" "strings"
@@ -71,19 +72,9 @@ func (s *TokenService) GetAllUnexpiredNormalAndMCPTokensByUid(c core.Context, ui
return tokenRecords, err return tokenRecords, err
} }
// ParseTokenByHeader returns the token model according to request data // ParseToken returns the token model according to token content
func (s *TokenService) ParseTokenByHeader(c *core.WebContext) (*jwt.Token, *core.UserTokenClaims, error) { func (s *TokenService) ParseToken(c core.Context, token string) (*jwt.Token, *core.UserTokenClaims, error) {
return s.parseToken(c, request.BearerExtractor{}) return s.parseToken(c, token)
}
// ParseTokenByArgument returns the token model according to request data
func (s *TokenService) ParseTokenByArgument(c *core.WebContext, tokenParameterName string) (*jwt.Token, *core.UserTokenClaims, error) {
return s.parseToken(c, request.ArgumentExtractor{tokenParameterName})
}
// ParseTokenByCookie returns the token model according to request data
func (s *TokenService) ParseTokenByCookie(c *core.WebContext, tokenCookieName string) (*jwt.Token, *core.UserTokenClaims, error) {
return s.parseToken(c, utils.CookieExtractor{tokenCookieName})
} }
// CreateTokenViaCli generates a new normal token and saves to database // CreateTokenViaCli generates a new normal token and saves to database
@@ -328,53 +319,52 @@ func (s *TokenService) GenerateTokenId(tokenRecord *models.TokenRecord) string {
return fmt.Sprintf("%d:%d:%d", tokenRecord.Uid, tokenRecord.CreatedUnixTime, tokenRecord.UserTokenId) return fmt.Sprintf("%d:%d:%d", tokenRecord.Uid, tokenRecord.CreatedUnixTime, tokenRecord.UserTokenId)
} }
func (s *TokenService) parseToken(c *core.WebContext, extractor request.Extractor) (*jwt.Token, *core.UserTokenClaims, error) { func (s *TokenService) parseToken(c core.Context, tokenString string) (*jwt.Token, *core.UserTokenClaims, error) {
claims := &core.UserTokenClaims{} claims := &core.UserTokenClaims{}
token, err := request.ParseFromRequest(c.Request, extractor, token, err := jwt.ParseWithClaims(tokenString, claims,
func(token *jwt.Token) (any, error) { func(token *jwt.Token) (any, error) {
now := time.Now().Unix() now := time.Now().Unix()
userTokenId, err := utils.StringToInt64(claims.UserTokenId) userTokenId, err := utils.StringToInt64(claims.UserTokenId)
if err != nil { if err != nil {
log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error()) log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" in token of user \"uid:%d\" is invalid, because %s", claims.UserTokenId, claims.Uid, err.Error())
return nil, errs.ErrInvalidUserTokenId return nil, errs.ErrInvalidUserTokenId
} }
tokenRecord, err := s.getTokenRecord(c, claims.Uid, userTokenId, claims.IssuedAt) tokenRecord, err := s.getTokenRecord(c, claims.Uid, userTokenId, claims.IssuedAt)
if err != nil { if err != nil {
log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error()) log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" of user \"uid:%d\" record not found, because %s", claims.UserTokenId, claims.Uid, err.Error())
return nil, errs.ErrTokenRecordNotFound return nil, errs.ErrTokenRecordNotFound
} }
if tokenRecord.ExpiredUnixTime < now { if tokenRecord.ExpiredUnixTime < now {
log.Warnf(c, "[tokens.ParseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid) log.Warnf(c, "[tokens.parseToken] token \"utid:%s\" of user \"uid:%d\" record is expired", claims.UserTokenId, claims.Uid)
return nil, errs.ErrTokenExpired return nil, errs.ErrTokenExpired
} }
return []byte(tokenRecord.Secret), nil return []byte(tokenRecord.Secret), nil
}, },
request.WithClaims(claims), jwt.WithIssuedAt(),
request.WithParser(jwt.NewParser(jwt.WithIssuedAt())),
) )
if err != nil { if err != nil {
if err == request.ErrNoTokenInRequest { if errors.Is(err, request.ErrNoTokenInRequest) {
return nil, nil, errs.ErrTokenIsEmpty return nil, nil, errs.ErrTokenIsEmpty
} }
if err == jwt.ErrTokenMalformed || err == jwt.ErrTokenUnverifiable || err == jwt.ErrTokenSignatureInvalid { if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) || errors.Is(err, jwt.ErrTokenSignatureInvalid) {
log.Warnf(c, "[tokens.ParseToken] token is invalid, because %s", err.Error()) log.Warnf(c, "[tokens.parseToken] token is invalid, because %s", err.Error())
return nil, nil, errs.ErrCurrentInvalidToken return nil, nil, errs.ErrCurrentInvalidToken
} }
if err == jwt.ErrTokenExpired { if errors.Is(err, jwt.ErrTokenExpired) {
return nil, nil, errs.ErrCurrentTokenExpired return nil, nil, errs.ErrCurrentTokenExpired
} }
if err == jwt.ErrTokenUsedBeforeIssued { if errors.Is(err, jwt.ErrTokenUsedBeforeIssued) {
log.Warnf(c, "[tokens.ParseToken] token is invalid, because issue time is later than now") log.Warnf(c, "[tokens.parseToken] token is invalid, because issue time is later than now")
return nil, nil, errs.ErrCurrentInvalidToken return nil, nil, errs.ErrCurrentInvalidToken
} }
-20
View File
@@ -1,20 +0,0 @@
package utils
import (
"net/http"
"github.com/golang-jwt/jwt/v5/request"
)
// CookieExtractor extracts a token from request cookies
type CookieExtractor []string
func (e CookieExtractor) ExtractToken(req *http.Request) (string, error) {
for _, arg := range e {
if cookie, _ := req.Cookie(arg); cookie != nil {
return cookie.Value, nil
}
}
return "", request.ErrNoTokenInRequest
}