support Nextcloud OAuth 2.0 authentication

This commit is contained in:
MaysWind
2025-10-21 01:52:28 +08:00
parent 600ae2bd58
commit 53a8ad71c6
74 changed files with 2046 additions and 241 deletions
+142 -3
View File
@@ -22,6 +22,7 @@ type AuthorizationsApi struct {
userAppCloudSettings *services.UserApplicationCloudSettingsService
tokens *services.TokenService
twoFactorAuthorizations *services.TwoFactorAuthorizationService
userExternalAuths *services.UserExternalAuthService
}
// Initialize a authorization api singleton instance
@@ -48,11 +49,16 @@ var (
userAppCloudSettings: services.UserApplicationCloudSettings,
tokens: services.Tokens,
twoFactorAuthorizations: services.TwoFactorAuthorizations,
userExternalAuths: services.UserExternalAuths,
}
)
// AuthorizeHandler verifies and authorizes current login request
func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
if !a.CurrentConfig().EnableInternalAuth {
return nil, errs.ErrCannotLoginByPassword
}
var credential models.UserLoginRequest
err := c.ShouldBindJSON(&credential)
@@ -151,7 +157,7 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Err
applicationCloudSettingSlice = &userApplicationCloudSettings.Settings
}
log.Infof(c, "[authorizations.AuthorizeHandler] user \"uid:%d\" has logined, token type is %d, token will be expired at %d", user.Uid, claims.Type, claims.ExpiresAt)
log.Infof(c, "[authorizations.AuthorizeHandler] user \"uid:%d\" has logged in, token type is %d, token will be expired at %d", user.Uid, claims.Type, claims.ExpiresAt)
authResp := a.getAuthResponse(c, token, twoFactorEnable, user, applicationCloudSettingSlice)
return authResp, nil
@@ -159,6 +165,10 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Err
// TwoFactorAuthorizeHandler verifies and authorizes current 2fa login by passcode
func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
if !a.CurrentConfig().EnableInternalAuth {
return nil, errs.ErrCannotLoginByPassword
}
var credential models.TwoFactorLoginRequest
err := c.ShouldBindJSON(&credential)
@@ -198,7 +208,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any,
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Errorf(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error())
log.Errorf(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
@@ -246,6 +256,10 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any,
// TwoFactorAuthorizeByRecoveryCodeHandler verifies and authorizes current 2fa login by recovery code
func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebContext) (any, *errs.Error) {
if !a.CurrentConfig().EnableInternalAuth {
return nil, errs.ErrCannotLoginByPassword
}
var credential models.TwoFactorRecoveryCodeLoginRequest
err := c.ShouldBindJSON(&credential)
@@ -276,7 +290,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebC
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Errorf(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error())
log.Errorf(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
@@ -338,6 +352,131 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebC
return authResp, nil
}
// OAuth2CallbackAuthorizeHandler verifies and authorizes current OAuth 2.0 callback login
func (a *AuthorizationsApi) OAuth2CallbackAuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
if !a.CurrentConfig().EnableOAuth2Login {
return nil, errs.ErrOAuth2NotEnabled
}
var credential models.OAuth2CallbackLoginRequest
err := c.ShouldBindJSON(&credential)
if err != nil {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] parse request failed, because %s", err.Error())
return nil, errs.NewIncompleteOrIncorrectSubmissionError(err)
}
userExternalAuthType := core.UserExternalAuthType(credential.Provider)
if !userExternalAuthType.IsValid() {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] provider \"%s\" is invalid", credential.Provider)
return nil, errs.ErrInvalidOAuth2Provider
}
uid := c.GetCurrentUid()
err = a.CheckFailureCount(c, uid)
if err != nil {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] cannot auth for user \"uid:%d\", because %s", uid, err.Error())
return nil, errs.Or(err, errs.ErrFailureCountLimitReached)
}
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.Disabled {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" is disabled", user.Uid)
return nil, errs.ErrUserIsDisabled
}
if a.CurrentConfig().EnableUserForceVerifyEmail && !user.EmailVerified {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" has not verified email", user.Uid)
return nil, errs.ErrEmailIsNotVerified
}
oldTokenClaims := c.GetTokenClaims()
if oldTokenClaims.Type == core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY {
if credential.Password == "" {
return nil, errs.ErrPasswordIsEmpty
}
if !a.users.IsPasswordEqualsUserPassword(credential.Password, user) {
failureCheckErr := a.CheckAndIncreaseFailureCount(c, uid)
if failureCheckErr != nil {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] cannot login for user \"uid:%d\", because %s", user.Uid, failureCheckErr.Error())
return nil, errs.Or(failureCheckErr, errs.ErrFailureCountLimitReached)
}
return nil, errs.ErrUserPasswordWrong
}
userExternalAuth := &models.UserExternalAuth{
Uid: user.Uid,
ExternalAuthType: userExternalAuthType,
}
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
userExternalAuth.ExternalEmail = user.Email
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
userExternalAuth.ExternalUsername = user.Username
}
err = a.userExternalAuths.CreateUserExternalAuth(c, userExternalAuth)
if err != nil {
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to create user external auth for user \"uid:%d\", because %s", user.Uid, err.Error())
return nil, errs.Or(err, errs.ErrOperationFailed)
}
log.Infof(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user external auth has been created for user \"uid:%d\"", user.Uid)
} else if oldTokenClaims.Type == core.USER_TOKEN_TYPE_OAUTH2_CALLBACK {
_, err = a.userExternalAuths.GetUserExternalAuthByUid(c, uid, userExternalAuthType)
if err != nil {
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get user external auth for user \"uid:%d\", because %s", uid, err.Error())
return nil, errs.Or(err, errs.ErrUserExternalAuthNotFound)
}
} else {
return nil, errs.ErrSystemError
}
err = a.tokens.DeleteTokenByClaims(c, oldTokenClaims)
if err != nil {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to revoke temporary token \"utid:%s\" for user \"uid:%d\", because %s", oldTokenClaims.UserTokenId, user.Uid, err.Error())
}
token, claims, err := a.tokens.CreateToken(c, user)
if err != nil {
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error())
return nil, errs.ErrTokenGenerating
}
c.SetTextualToken(token)
c.SetTokenClaims(claims)
userApplicationCloudSettings, err := a.userAppCloudSettings.GetUserApplicationCloudSettingsByUid(c, user.Uid)
var applicationCloudSettingSlice *models.ApplicationCloudSettingSlice = nil
if err != nil {
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get latest user application cloud settings for user \"uid:%d\", because %s", user.Uid, err.Error())
} else if userApplicationCloudSettings != nil && len(userApplicationCloudSettings.Settings) > 0 {
applicationCloudSettingSlice = &userApplicationCloudSettings.Settings
}
log.Infof(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" has logged in, token will be expired at %d", user.Uid, claims.ExpiresAt)
authResp := a.getAuthResponse(c, token, false, user, applicationCloudSettingSlice)
return authResp, nil
}
func (a *AuthorizationsApi) getAuthResponse(c *core.WebContext, token string, need2FA bool, user *models.User, applicationCloudSettings *models.ApplicationCloudSettingSlice) *models.AuthResponse {
return &models.AuthResponse{
Token: token,
+8
View File
@@ -3,6 +3,7 @@ package api
import (
"fmt"
"sort"
"time"
"github.com/mayswind/ezbookkeeping/pkg/avatars"
"github.com/mayswind/ezbookkeeping/pkg/core"
@@ -120,6 +121,13 @@ func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType dupli
}
}
// SetSubmissionRemarkWithCustomExpirationIfEnable saves the identification and remark by the current duplicate checker with custom expiration time if the duplicate submission check is enabled
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpirationIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
}
}
// RemoveSubmissionRemarkIfEnable removes the identification and remark by the current duplicate checker if the duplicate submission check is enabled
func (a *ApiUsingDuplicateChecker) RemoveSubmissionRemarkIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string) {
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
+313
View File
@@ -0,0 +1,313 @@
package api
import (
"errors"
"fmt"
"net/url"
"strings"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/duplicatechecker"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/locales"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/services"
"github.com/mayswind/ezbookkeeping/pkg/settings"
"github.com/mayswind/ezbookkeeping/pkg/utils"
"github.com/mayswind/ezbookkeeping/pkg/validators"
)
const oauth2CallbackPageUrlSuccessFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&token=%s"
const oauth2CallbackPageUrlNeedVerifyFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&userName=%s&token=%s"
const oauth2CallbackPageUrlFailedFormat = "%sdesktop/#/oauth2_callback?error=%s"
// OAuth2AuthenticationApi represents OAuth 2.0 authorization api
type OAuth2AuthenticationApi struct {
ApiUsingConfig
ApiUsingDuplicateChecker
users *services.UserService
tokens *services.TokenService
userExternalAuths *services.UserExternalAuthService
}
// Initialize a OAuth 2.0 authentication api singleton instance
var (
OAuth2Authentications = &OAuth2AuthenticationApi{
ApiUsingConfig: ApiUsingConfig{
container: settings.Container,
},
ApiUsingDuplicateChecker: ApiUsingDuplicateChecker{
ApiUsingConfig: ApiUsingConfig{
container: settings.Container,
},
container: duplicatechecker.Container,
},
users: services.Users,
tokens: services.Tokens,
userExternalAuths: services.UserExternalAuths,
}
)
// LoginHandler handles user login request via OAuth 2.0
func (a *OAuth2AuthenticationApi) LoginHandler(c *core.WebContext) (string, *errs.Error) {
var oauth2LoginReq models.OAuth2LoginRequest
err := c.ShouldBindQuery(&oauth2LoginReq)
if err != nil {
log.Warnf(c, "[oauth2_authentications.LoginHandler] parse request failed, because %s", err.Error())
return "", errs.NewIncompleteOrIncorrectSubmissionError(err)
}
if oauth2LoginReq.Platform != "mobile" && oauth2LoginReq.Platform != "desktop" {
return "", errs.ErrInvalidOAuth2LoginRequest
}
state := fmt.Sprintf("%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId)
remark := ""
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
found := false
found, remark = a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
if found {
log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId)
return "", errs.ErrRepeatedRequest
}
randomString, err := utils.GetRandomNumberOrLowercaseLetter(32)
if err != nil {
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error())
return "", errs.ErrSystemError
}
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, randomString)
state = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
}
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state)
if err != nil {
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to get oauth 2.0 auth url, because %s", err.Error())
return "", errs.Or(err, errs.ErrSystemError)
}
a.SetSubmissionRemarkWithCustomExpirationIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
return redirectUrl, nil
}
// CallbackHandler handles OAuth 2.0 callback request
func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *errs.Error) {
var oauth2CallbackReq models.OAuth2CallbackRequest
err := c.ShouldBindQuery(&oauth2CallbackReq)
if err != nil {
log.Warnf(c, "[oauth2_authentications.CallbackHandler] parse request failed, because %s", err.Error())
return a.redirectToFailedCallbackPage(c, errs.NewIncompleteOrIncorrectSubmissionError(err))
}
if oauth2CallbackReq.State == "" {
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2State)
}
if oauth2CallbackReq.Code == "" {
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2Code)
}
platform := ""
clientSessionId := ""
stateParts := strings.Split(oauth2CallbackReq.State, "|")
if len(stateParts) >= 2 {
platform = stateParts[0]
clientSessionId = stateParts[1]
} else {
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
if platform != "mobile" && platform != "desktop" {
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2LoginRequest)
}
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
if !found {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback)
}
remarkParts := strings.Split(remark, "|")
if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, remarkParts[2])
expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState)))
if oauth2CallbackReq.State != expectedState {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
a.RemoveSubmissionRemarkIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
}
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 token, because %s", err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrCannotRetrieveOAuth2Token))
}
oauth2UserInfo, err := oauth2.GetOAuth2UserInfo(c, oauth2Token)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 user info, because %s", err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrInvalidOAuth2Token))
}
if oauth2UserInfo == nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 user info, because user info is nil")
return a.redirectToFailedCallbackPage(c, errs.ErrCannotRetrieveUserInfo)
}
if oauth2UserInfo.UserName == "" || oauth2UserInfo.Email == "" {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 user info, userName: %s, email: %s", oauth2UserInfo.UserName, oauth2UserInfo.Email)
return a.redirectToFailedCallbackPage(c, errs.ErrCannotRetrieveUserInfo)
}
userExternalAuthType := oauth2.GetExternalUserAuthType()
var userExternalAuth *models.UserExternalAuth
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalEmail(c, oauth2UserInfo.Email, userExternalAuthType)
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalUserName(c, oauth2UserInfo.UserName, userExternalAuthType)
} else {
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalEmail(c, oauth2UserInfo.Email, userExternalAuthType)
}
if err != nil && !errors.Is(err, errs.ErrUserExternalAuthNotFound) {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user external auth, because %s", err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
}
var user *models.User
if err == nil { // user already bound to external auth, redirect to success page
user, err = a.users.GetUserById(c, userExternalAuth.Uid)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user by id %d, because %s", userExternalAuth.Uid, err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
}
} else if errors.Is(err, errs.ErrUserExternalAuthNotFound) { // user not bound to external auth, try to bind or register new user
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
user, err = a.users.GetUserByEmail(c, oauth2UserInfo.Email)
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
user, err = a.users.GetUserByUsername(c, oauth2UserInfo.UserName)
} else {
user, err = a.users.GetUserByEmail(c, oauth2UserInfo.Email)
}
if err != nil && !errors.Is(err, errs.ErrUserNotFound) {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user, because %s", err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
}
if user == nil && a.CurrentConfig().EnableUserRegister && a.CurrentConfig().OAuth2AutoRegister {
userName := strings.TrimSpace(oauth2UserInfo.UserName)
email := strings.TrimSpace(oauth2UserInfo.Email)
nickName := strings.TrimSpace(oauth2UserInfo.NickName)
languageCode := ""
currencyCode := "USD"
if _, exists := locales.AllLanguages[oauth2UserInfo.LanguageCode]; exists {
languageCode = oauth2UserInfo.LanguageCode
}
if _, exists := validators.AllCurrencyNames[oauth2UserInfo.CurrencyCode]; exists {
currencyCode = oauth2UserInfo.CurrencyCode
}
user = &models.User{
Username: userName,
Email: email,
Nickname: nickName,
Password: "",
Language: languageCode,
DefaultCurrency: currencyCode,
FirstDayOfWeek: oauth2UserInfo.FirstDayOfWeek,
FiscalYearStart: core.FISCAL_YEAR_START_DEFAULT,
TransactionEditScope: models.TRANSACTION_EDIT_SCOPE_ALL,
FeatureRestriction: a.CurrentConfig().DefaultFeatureRestrictions,
}
err = a.users.CreateUser(c, user)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create user \"%s\", because %s", user.Username, err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
}
log.Infof(c, "[oauth2_authentications.CallbackHandler] user \"%s\" has registered successfully, uid is %d", user.Username, user.Uid)
userExternalAuth := &models.UserExternalAuth{
Uid: user.Uid,
ExternalAuthType: userExternalAuthType,
ExternalUsername: oauth2UserInfo.UserName,
ExternalEmail: oauth2UserInfo.Email,
}
err = a.userExternalAuths.CreateUserExternalAuth(c, userExternalAuth)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create user external auth for user \"uid:%d\", because %s", user.Uid, err.Error())
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
}
log.Infof(c, "[oauth2_authentications.CallbackHandler] user external auth has been created for user \"uid:%d\"", user.Uid)
} else if user == nil {
return a.redirectToFailedCallbackPage(c, errs.ErrOAuth2AutoRegistrationNotEnabled)
}
}
if userExternalAuth == nil {
token, _, err := a.tokens.CreateOAuth2CallbackRequireVerifyToken(c, user)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create oauth 2.0 callback verify token for user \"uid:%d\", because %s", user.Uid, err.Error())
return a.redirectToFailedCallbackPage(c, errs.ErrTokenGenerating)
}
return a.redirectToVerifyCallbackPage(c, platform, userExternalAuthType, user.Username, token)
} else {
token, _, err := a.tokens.CreateOAuth2CallbackToken(c, user)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create oauth 2.0 callback token for user \"uid:%d\", because %s", user.Uid, err.Error())
return a.redirectToFailedCallbackPage(c, errs.ErrTokenGenerating)
}
return a.redirectToSuccessCallbackPage(c, platform, userExternalAuthType, token)
}
}
func (a *OAuth2AuthenticationApi) redirectToSuccessCallbackPage(c *core.WebContext, platform string, externalAuthType core.UserExternalAuthType, token string) (string, *errs.Error) {
return fmt.Sprintf(oauth2CallbackPageUrlSuccessFormat, a.CurrentConfig().RootUrl, platform, externalAuthType, url.QueryEscape(token)), nil
}
func (a *OAuth2AuthenticationApi) redirectToVerifyCallbackPage(c *core.WebContext, platform string, externalAuthType core.UserExternalAuthType, userName string, token string) (string, *errs.Error) {
return fmt.Sprintf(oauth2CallbackPageUrlNeedVerifyFormat, a.CurrentConfig().RootUrl, platform, externalAuthType, userName, url.QueryEscape(token)), nil
}
func (a *OAuth2AuthenticationApi) redirectToFailedCallbackPage(c *core.WebContext, err *errs.Error) (string, *errs.Error) {
return fmt.Sprintf(oauth2CallbackPageUrlFailedFormat, a.CurrentConfig().RootUrl, url.QueryEscape(utils.GetDisplayErrorMessage(err))), nil
}
+7 -3
View File
@@ -35,14 +35,18 @@ func (a *ServerSettingsApi) ServerSettingsJavascriptHandler(c *core.WebContext)
builder := &strings.Builder{}
builder.WriteString(ezbookkeepingServerSettingsJavascriptFileHeader)
a.appendBooleanSetting(builder, "r", config.EnableUserRegister)
a.appendBooleanSetting(builder, "f", config.EnableUserForgetPassword)
a.appendBooleanSetting(builder, "a", config.EnableInternalAuth)
a.appendBooleanSetting(builder, "o", config.EnableOAuth2Login)
a.appendBooleanSetting(builder, "r", config.EnableInternalAuth && config.EnableUserRegister)
a.appendBooleanSetting(builder, "f", config.EnableInternalAuth && config.EnableUserForgetPassword)
a.appendBooleanSetting(builder, "v", config.EnableUserVerifyEmail)
a.appendBooleanSetting(builder, "p", config.EnableTransactionPictures)
a.appendBooleanSetting(builder, "s", config.EnableScheduledTransaction)
a.appendBooleanSetting(builder, "e", config.EnableDataExport)
a.appendBooleanSetting(builder, "i", config.EnableDataImport)
a.appendStringSetting(builder, "op", config.OAuth2Provider)
if config.EnableMCPServer {
a.appendBooleanSetting(builder, "mcp", config.EnableMCPServer)
}
@@ -138,7 +142,7 @@ func (a *ServerSettingsApi) appendStringSetting(builder *strings.Builder, key st
builder.WriteString(";\n")
}
func (a *ServerSettingsApi) appendMultiLanguageTipSetting(builder *strings.Builder, key string, value settings.TipConfig) {
func (a *ServerSettingsApi) appendMultiLanguageTipSetting(builder *strings.Builder, key string, value settings.MultiLanguageContentConfig) {
builder.WriteString(ezbookkeepingServerSettingsGlobalVariableFullName)
builder.WriteString("[")
a.appendEncodedString(builder, key)
@@ -0,0 +1,110 @@
package oauth2
import (
"encoding/json"
"io"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
)
type nextcloudUserInfoResponse struct {
OCS *struct {
Meta *struct {
Status string `json:"status"`
StatusCode int `json:"statuscode"`
} `json:"meta"`
Data *struct {
ID string `json:"id"`
Email string `json:"email"`
DisplayName string `json:"display-name"`
} `json:"data"`
} `json:"ocs"`
}
// NextcloudOAuth2Provider represents Nextcloud OAuth 2.0 provider
type NextcloudOAuth2Provider struct {
baseUrl string
}
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider {
if baseUrl[len(baseUrl)-1] != '/' {
baseUrl += "/"
}
return &NextcloudOAuth2Provider{
baseUrl: baseUrl,
}
}
// GetAuthUrl returns the authentication url of the Nextcloud provider
func (p *NextcloudOAuth2Provider) GetAuthUrl() string {
return p.baseUrl + "apps/oauth2/authorize"
}
// GetTokenUrl returns the token url of the Nextcloud provider
func (p *NextcloudOAuth2Provider) GetTokenUrl() string {
return p.baseUrl + "apps/oauth2/api/v1/token"
}
// GetUserInfo returns the user info by the Nextcloud provider
func (p *NextcloudOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) {
url := p.baseUrl + "ocs/v2.php/cloud/user?format=json"
resp, err := oauth2Client.Get(url)
if err != nil {
log.Errorf(c, "[nextcloud_oauth2_provider.GetUserInfo] failed to get user info response, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
log.Debugf(c, "[nextcloud_oauth2_provider.GetUserInfo] response is %s", body)
if resp.StatusCode != 200 {
log.Errorf(c, "[nextcloud_oauth2_provider.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
return p.parseUserInfo(c, body)
}
// GetScopes returns the scopes required by the Nextcloud provider
func (p *NextcloudOAuth2Provider) GetScopes() []string {
return []string{"profile", "email"}
}
func (p *NextcloudOAuth2Provider) parseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
userInfoResp := &nextcloudUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
if err != nil {
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] failed to parse user info response body, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
if userInfoResp.OCS == nil || userInfoResp.OCS.Meta == nil || userInfoResp.OCS.Data == nil {
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] invalid user info response body")
return nil, errs.ErrCannotRetrieveUserInfo
}
if userInfoResp.OCS.Meta.StatusCode != 200 {
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] user info response status code is %d", userInfoResp.OCS.Meta.StatusCode)
return nil, errs.ErrCannotRetrieveUserInfo
}
if userInfoResp.OCS.Data.ID == "" {
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] user info id is empty")
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
UserName: userInfoResp.OCS.Data.ID,
Email: userInfoResp.OCS.Data.Email,
NickName: userInfoResp.OCS.Data.DisplayName,
}, nil
}
+105
View File
@@ -0,0 +1,105 @@
package oauth2
import (
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
"github.com/mayswind/ezbookkeeping/pkg/utils"
)
// OAuth2Container contains the current OAuth 2.0 authentication provider
type OAuth2Container struct {
oauth2Config *oauth2.Config
oauth2Provider OAuth2Provider
oauth2HttpClient *http.Client
externalUserAuthType core.UserExternalAuthType
}
// Initialize a OAuth 2.0 container singleton instance
var (
Container = &OAuth2Container{}
)
// InitializeOAuth2Provider initializes the current OAuth 2.0 provider according to the config
func InitializeOAuth2Provider(config *settings.Config) error {
if !config.EnableOAuth2Login {
return nil
}
if config.OAuth2ClientID == "" || config.OAuth2ClientSecret == "" || config.OAuth2UserIdentifier == "" || config.OAuth2Provider == "" {
return errs.ErrInvalidOAuth2Config
}
var oauth2Provider OAuth2Provider
var externalUserAuthType core.UserExternalAuthType
if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD
} else {
return errs.ErrInvalidOAuth2Provider
}
Container.oauth2Config = buildOAuth2Config(config, oauth2Provider)
Container.oauth2Provider = oauth2Provider
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
Container.externalUserAuthType = externalUserAuthType
return nil
}
// GetOAuth2AuthUrl returns the OAuth 2.0 authentication url
func GetOAuth2AuthUrl(c core.Context, state string) (string, error) {
if Container.oauth2Config == nil {
return "", errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.AuthCodeURL(state), nil
}
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) {
if Container.oauth2Config == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code)
}
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) {
if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
if token == nil {
return nil, errs.ErrInvalidOAuth2Token
}
oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token))
return Container.oauth2Provider.GetUserInfo(c, oauth2Client)
}
// GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider
func GetExternalUserAuthType() core.UserExternalAuthType {
return Container.externalUserAuthType
}
func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config {
redirectURL := config.RootUrl + "oauth2/callback"
return &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: oauth2Provider.GetAuthUrl(),
TokenURL: oauth2Provider.GetTokenUrl(),
},
RedirectURL: redirectURL,
Scopes: oauth2Provider.GetScopes(),
}
}
+31
View File
@@ -0,0 +1,31 @@
package oauth2
import (
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Context represents the context for OAuth 2.0 operations
type OAuth2Context struct {
core.Context
httpClient *http.Client
}
// Value returns the value associated with key
func (o *OAuth2Context) Value(key any) any {
if key == oauth2.HTTPClient {
return o.httpClient
}
return o.Context.Value(key)
}
func wrapOAuth2Context(ctx core.Context, httpClient *http.Client) core.Context {
return &OAuth2Context{
Context: ctx,
httpClient: httpClient,
}
}
+22
View File
@@ -0,0 +1,22 @@
package oauth2
import (
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetAuthUrl returns the authentication url of the provider
GetAuthUrl() string
// GetTokenUrl returns the token url of the provider
GetTokenUrl() string
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error)
// GetScopes returns the scopes required by the provider
GetScopes() []string
}
+13
View File
@@ -0,0 +1,13 @@
package oauth2
import "github.com/mayswind/ezbookkeeping/pkg/core"
// OAuth2UserInfo represents the user info retrieved from OAuth 2.0 provider
type OAuth2UserInfo struct {
UserName string
Email string
NickName string
LanguageCode string
CurrencyCode string
FirstDayOfWeek core.WeekDay
}
+3
View File
@@ -12,6 +12,9 @@ type CliHandlerFunc func(*CliContext) error
// MiddlewareHandlerFunc represents the middleware handler function
type MiddlewareHandlerFunc func(*WebContext)
// RedirectHandlerFunc represents the redirect handler function
type RedirectHandlerFunc func(*WebContext) (string, *errs.Error)
// ApiHandlerFunc represents the api handler function
type ApiHandlerFunc func(*WebContext) (any, *errs.Error)
+7 -5
View File
@@ -11,11 +11,13 @@ type TokenType byte
// Token types
const (
USER_TOKEN_TYPE_NORMAL TokenType = 1
USER_TOKEN_TYPE_REQUIRE_2FA TokenType = 2
USER_TOKEN_TYPE_EMAIL_VERIFY TokenType = 3
USER_TOKEN_TYPE_PASSWORD_RESET TokenType = 4
USER_TOKEN_TYPE_MCP TokenType = 5
USER_TOKEN_TYPE_NORMAL TokenType = 1
USER_TOKEN_TYPE_REQUIRE_2FA TokenType = 2
USER_TOKEN_TYPE_EMAIL_VERIFY TokenType = 3
USER_TOKEN_TYPE_PASSWORD_RESET TokenType = 4
USER_TOKEN_TYPE_MCP TokenType = 5
USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY TokenType = 6
USER_TOKEN_TYPE_OAUTH2_CALLBACK TokenType = 7
)
// UserTokenClaims represents user token
+18
View File
@@ -0,0 +1,18 @@
package core
// UserExternalAuthType represents the type of user external authentication
type UserExternalAuthType string
// User External Auth Type
const (
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD UserExternalAuthType = "nextcloud"
)
// IsValid checks if the UserExternalAuthType is valid
func (t UserExternalAuthType) IsValid() bool {
switch t {
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD:
return true
}
return false
}
@@ -6,6 +6,7 @@ import "time"
type DuplicateChecker interface {
GetSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) (bool, string)
SetSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string, remark string)
SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration)
RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string)
GetOrSetCronJobRunningInfo(jobName string, runningInfo string, runningInterval time.Duration) (bool, string)
RemoveCronJobRunningInfo(jobName string)
@@ -57,6 +57,15 @@ func (c *DuplicateCheckerContainer) SetSubmissionRemark(checkerType DuplicateChe
c.current.SetSubmissionRemark(checkerType, uid, identification, remark)
}
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark by the current duplicate checker with custom expiration time
func (c *DuplicateCheckerContainer) SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
if c.current == nil {
return
}
c.current.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
}
// RemoveSubmissionRemark removes the identification and remark by the current duplicate checker
func (c *DuplicateCheckerContainer) RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) {
if c.current == nil {
@@ -13,5 +13,6 @@ const (
DUPLICATE_CHECKER_TYPE_NEW_TEMPLATE DuplicateCheckerType = 5
DUPLICATE_CHECKER_TYPE_NEW_PICTURE DuplicateCheckerType = 6
DUPLICATE_CHECKER_TYPE_IMPORT_TRANSACTIONS DuplicateCheckerType = 7
DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT DuplicateCheckerType = 8
DUPLICATE_CHECKER_TYPE_FAILURE_CHECK DuplicateCheckerType = 255
)
@@ -42,6 +42,11 @@ func (c *InMemoryDuplicateChecker) SetSubmissionRemark(checkerType DuplicateChec
c.cache.Set(c.getCacheKey(checkerType, uid, identification), remark, cache.DefaultExpiration)
}
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark to in-memory cache with custom expiration time
func (c *InMemoryDuplicateChecker) SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
c.cache.Set(c.getCacheKey(checkerType, uid, identification), remark, expiration)
}
// RemoveSubmissionRemark removes the identification and remark in in-memory cache
func (c *InMemoryDuplicateChecker) RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) {
c.cache.Delete(c.getCacheKey(checkerType, uid, identification))
+2
View File
@@ -41,6 +41,8 @@ const (
NormalSubcategoryUserCustomExchangeRate = 13
NormalSubcategoryModelContextProtocol = 14
NormalSubcategoryLargeLanguageModel = 15
NormalSubcategoryUserExternalAuth = 16
NormalSubcategoryOAuth2 = 17
)
// Error represents the specific error returned to user
+10
View File
@@ -0,0 +1,10 @@
package errs
import (
"net/http"
)
// Error codes related to user external authentication
var (
ErrUserExternalAuthNotFound = NewNormalError(NormalSubcategoryUserExternalAuth, 0, http.StatusBadRequest, "user external auth is not found")
)
+19
View File
@@ -0,0 +1,19 @@
package errs
import (
"net/http"
)
// Error codes related to oauth 2.0
var (
ErrOAuth2NotEnabled = NewNormalError(NormalSubcategoryOAuth2, 0, http.StatusUnauthorized, "oauth 2.0 not enabled")
ErrOAuth2AutoRegistrationNotEnabled = NewNormalError(NormalSubcategoryOAuth2, 1, http.StatusUnauthorized, "oauth 2.0 auto registration not enabled")
ErrInvalidOAuth2LoginRequest = NewNormalError(NormalSubcategoryOAuth2, 2, http.StatusUnauthorized, "invalid oauth 2.0 login request")
ErrInvalidOAuth2Callback = NewNormalError(NormalSubcategoryOAuth2, 3, http.StatusUnauthorized, "invalid oauth 2.0 callback")
ErrMissingOAuth2State = NewNormalError(NormalSubcategoryOAuth2, 4, http.StatusUnauthorized, "missing state in oauth 2.0 callback")
ErrMissingOAuth2Code = NewNormalError(NormalSubcategoryOAuth2, 5, http.StatusUnauthorized, "missing code in oauth 2.0 callback")
ErrInvalidOAuth2State = NewNormalError(NormalSubcategoryOAuth2, 6, http.StatusUnauthorized, "invalid state in oauth 2.0 callback")
ErrCannotRetrieveOAuth2Token = NewNormalError(NormalSubcategoryOAuth2, 7, http.StatusUnauthorized, "cannot retrieve oauth 2.0 token")
ErrInvalidOAuth2Token = NewNormalError(NormalSubcategoryOAuth2, 8, http.StatusUnauthorized, "invalid oauth 2.0 token")
ErrCannotRetrieveUserInfo = NewNormalError(NormalSubcategoryOAuth2, 9, http.StatusUnauthorized, "cannot retrieve user info from oauth 2.0 provider")
)
+4
View File
@@ -26,4 +26,8 @@ var (
ErrInvalidIpAddressPattern = NewSystemError(SystemSubcategorySetting, 19, http.StatusInternalServerError, "invalid ip address pattern")
ErrInvalidLLMProvider = NewSystemError(SystemSubcategorySetting, 20, http.StatusInternalServerError, "invalid llm provider")
ErrInvalidLLMModelId = NewSystemError(SystemSubcategorySetting, 21, http.StatusInternalServerError, "invalid llm model id")
ErrInvalidOAuth2Config = NewSystemError(SystemSubcategorySetting, 22, http.StatusInternalServerError, "invalid oauth 2.0 config")
ErrInvalidOAuth2UserIdentifier = NewSystemError(SystemSubcategorySetting, 23, http.StatusInternalServerError, "invalid oauth 2.0 user identifier")
ErrInvalidOAuth2Provider = NewSystemError(SystemSubcategorySetting, 24, http.StatusInternalServerError, "invalid oauth 2.0 provider")
ErrInvalidOAuth2StateExpiredTime = NewSystemError(SystemSubcategorySetting, 25, http.StatusInternalServerError, "invalid oauth 2.0 state expired time")
)
+1
View File
@@ -38,4 +38,5 @@ var (
ErrUserAvatarExtensionInvalid = NewNormalError(NormalSubcategoryUser, 29, http.StatusNotFound, "user avatar file extension invalid")
ErrExceedMaxUserAvatarFileSize = NewNormalError(NormalSubcategoryUser, 30, http.StatusBadRequest, "exceed the maximum size of user avatar file")
ErrNotPermittedToPerformThisAction = NewNormalError(NormalSubcategoryUser, 31, http.StatusBadRequest, "not permitted to perform this action")
ErrCannotLoginByPassword = NewNormalError(NormalSubcategoryUser, 32, http.StatusBadRequest, "cannot login by password")
)
@@ -1,11 +1,9 @@
package exchangerates
import (
"crypto/tls"
"io"
"net/http"
"sort"
"time"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
@@ -28,23 +26,10 @@ type HttpExchangeRatesDataSource interface {
type CommonHttpExchangeRatesDataProvider struct {
ExchangeRatesDataProvider
dataSource HttpExchangeRatesDataSource
httpClient *http.Client
}
func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Context, uid int64, currentConfig *settings.Config) (*models.LatestExchangeRateResponse, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
utils.SetProxyUrl(transport, currentConfig.ExchangeRatesProxy)
if currentConfig.ExchangeRatesSkipTLSVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
client := &http.Client{
Transport: transport,
Timeout: time.Duration(currentConfig.ExchangeRatesRequestTimeout) * time.Millisecond,
}
requests, err := e.dataSource.BuildRequests()
if err != nil {
@@ -56,14 +41,7 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
for i := 0; i < len(requests); i++ {
req := requests[i]
if len(req.Header.Values("User-Agent")) < 1 {
req.Header.Set("User-Agent", settings.GetUserAgent())
} else if req.Header.Get("User-Agent") == "" {
req.Header.Del("User-Agent")
}
resp, err := client.Do(req)
resp, err := e.httpClient.Do(req)
if err != nil {
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to request latest exchange rate data for user \"uid:%d\", because %s", uid, err.Error())
@@ -76,7 +54,7 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
log.Debugf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] response#%d is %s", i, body)
if resp.StatusCode != 200 {
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to get latest exchange rate data response for user \"uid:%d\", because response code is not %d", uid, resp.StatusCode)
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to get latest exchange rate data response for user \"uid:%d\", because response code is %d", uid, resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -125,8 +103,9 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
return finalExchangeRateResponse, nil
}
func newCommonHttpExchangeRatesDataProvider(dataSource HttpExchangeRatesDataSource) *CommonHttpExchangeRatesDataProvider {
func newCommonHttpExchangeRatesDataProvider(config *settings.Config, dataSource HttpExchangeRatesDataSource) *CommonHttpExchangeRatesDataProvider {
return &CommonHttpExchangeRatesDataProvider{
dataSource: dataSource,
httpClient: utils.NewHttpClient(config.ExchangeRatesRequestTimeout, config.ExchangeRatesProxy, config.ExchangeRatesSkipTLSVerify, settings.GetUserAgent()),
}
}
@@ -20,55 +20,55 @@ var (
// InitializeExchangeRatesDataSource initializes the current exchange rates data source according to the config
func InitializeExchangeRatesDataSource(config *settings.Config) error {
if config.ExchangeRatesDataSource == settings.ReserveBankOfAustraliaDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&ReserveBankOfAustraliaDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &ReserveBankOfAustraliaDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.BankOfCanadaDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfCanadaDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfCanadaDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.CzechNationalBankDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&CzechNationalBankDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CzechNationalBankDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.DanmarksNationalbankDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&DanmarksNationalbankDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &DanmarksNationalbankDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.EuroCentralBankDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&EuroCentralBankDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &EuroCentralBankDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.NationalBankOfGeorgiaDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfGeorgiaDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfGeorgiaDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.CentralBankOfHungaryDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfHungaryDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfHungaryDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.BankOfIsraelDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfIsraelDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfIsraelDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.CentralBankOfMyanmarDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfMyanmarDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfMyanmarDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.NorgesBankDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&NorgesBankDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NorgesBankDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.NationalBankOfPolandDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfPolandDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfPolandDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.NationalBankOfRomaniaDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfRomaniaDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfRomaniaDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.BankOfRussiaDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfRussiaDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfRussiaDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.SwissNationalBankDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&SwissNationalBankDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &SwissNationalBankDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.NationalBankOfUkraineDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfUkraineDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfUkraineDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.CentralBankOfUzbekistanDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfUzbekistanDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfUzbekistanDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.InternationalMonetaryFundDataSource {
Container.current = newCommonHttpExchangeRatesDataProvider(&InternationalMonetaryFundDataSource{})
Container.current = newCommonHttpExchangeRatesDataProvider(config, &InternationalMonetaryFundDataSource{})
return nil
} else if config.ExchangeRatesDataSource == settings.UserCustomExchangeRatesDataSource {
Container.current = newUserCustomExchangeRatesDataProvider()
@@ -1,11 +1,9 @@
package common
import (
"crypto/tls"
"io"
"net/http"
"strings"
"time"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
@@ -28,7 +26,8 @@ type HttpLargeLanguageModelAdapter interface {
// CommonHttpLargeLanguageModelProvider defines the structure of common http large language model provider
type CommonHttpLargeLanguageModelProvider struct {
provider.LargeLanguageModelProvider
adapter HttpLargeLanguageModelAdapter
adapter HttpLargeLanguageModelAdapter
httpClient *http.Client
}
// GetJsonResponse returns the json response from common http large language model provider
@@ -51,20 +50,6 @@ func (p *CommonHttpLargeLanguageModelProvider) GetJsonResponse(c core.Context, u
}
func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context, uid int64, currentLLMConfig *settings.LLMConfig, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) (*data.LargeLanguageModelTextualResponse, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
utils.SetProxyUrl(transport, currentLLMConfig.LargeLanguageModelAPIProxy)
if currentLLMConfig.LargeLanguageModelAPISkipTLSVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
client := &http.Client{
Transport: transport,
Timeout: time.Duration(currentLLMConfig.LargeLanguageModelAPIRequestTimeout) * time.Millisecond,
}
httpRequest, err := p.adapter.BuildTextualRequest(c, uid, request, responseType)
if err != nil {
@@ -72,9 +57,7 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
return nil, errs.ErrFailedToRequestRemoteApi
}
httpRequest.Header.Set("User-Agent", settings.GetUserAgent())
resp, err := client.Do(httpRequest)
resp, err := p.httpClient.Do(httpRequest)
if err != nil {
log.Errorf(c, "[common_http_large_language_model_provider.getTextualResponse] failed to request large language model api for user \"uid:%d\", because %s", uid, err.Error())
@@ -95,8 +78,9 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
}
// NewCommonHttpLargeLanguageModelProvider creates a http adapter based large language model provider instance
func NewCommonHttpLargeLanguageModelProvider(adapter HttpLargeLanguageModelAdapter) *CommonHttpLargeLanguageModelProvider {
func NewCommonHttpLargeLanguageModelProvider(llmConfig *settings.LLMConfig, adapter HttpLargeLanguageModelAdapter) *CommonHttpLargeLanguageModelProvider {
return &CommonHttpLargeLanguageModelProvider{
adapter: adapter,
adapter: adapter,
httpClient: utils.NewHttpClient(llmConfig.LargeLanguageModelAPIRequestTimeout, llmConfig.LargeLanguageModelAPIProxy, llmConfig.LargeLanguageModelAPISkipTLSVerify, settings.GetUserAgent()),
}
}
@@ -160,7 +160,7 @@ func (p *GoogleAILargeLanguageModelAdapter) buildJsonRequestBody(c core.Context,
// NewGoogleAILargeLanguageModelProvider creates a new Google AI large language model provider instance
func NewGoogleAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
return common.NewCommonHttpLargeLanguageModelProvider(&GoogleAILargeLanguageModelAdapter{
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &GoogleAILargeLanguageModelAdapter{
GoogleAIAPIKey: llmConfig.GoogleAIAPIKey,
GoogleAIModelID: llmConfig.GoogleAIModelID,
})
@@ -159,7 +159,7 @@ func (p *OllamaLargeLanguageModelAdapter) getOllamaRequestUrl() string {
// NewOllamaLargeLanguageModelProvider creates a new Ollama large language model provider instance
func NewOllamaLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
return common.NewCommonHttpLargeLanguageModelProvider(&OllamaLargeLanguageModelAdapter{
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &OllamaLargeLanguageModelAdapter{
OllamaServerURL: llmConfig.OllamaServerURL,
OllamaModelID: llmConfig.OllamaModelID,
})
@@ -37,7 +37,7 @@ func (p *OpenAIOfficialChatCompletionsAPIProvider) GetModelID() string {
// NewOpenAILargeLanguageModelProvider creates a new OpenAI large language model provider instance
func NewOpenAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAIOfficialChatCompletionsAPIProvider{
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenAIOfficialChatCompletionsAPIProvider{
OpenAIAPIKey: llmConfig.OpenAIAPIKey,
OpenAIModelID: llmConfig.OpenAIModelID,
})
@@ -14,6 +14,7 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/llm/provider"
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
// OpenAIChatCompletionsAPIProvider defines the structure of OpenAI chat completions API provider
@@ -212,8 +213,8 @@ func (p *CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter) buildJsonReque
return requestBodyBytes, nil
}
func newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(apiProvider OpenAIChatCompletionsAPIProvider) provider.LargeLanguageModelProvider {
return common.NewCommonHttpLargeLanguageModelProvider(&CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
func newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig *settings.LLMConfig, apiProvider OpenAIChatCompletionsAPIProvider) provider.LargeLanguageModelProvider {
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
apiProvider: apiProvider,
})
}
@@ -51,7 +51,7 @@ func (p *OpenAICompatibleChatCompletionsAPIProvider) getFinalChatCompletionsRequ
// NewOpenAICompatibleLargeLanguageModelProvider creates a new OpenAI compatible large language model provider instance
func NewOpenAICompatibleLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAICompatibleChatCompletionsAPIProvider{
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenAICompatibleChatCompletionsAPIProvider{
OpenAICompatibleBaseURL: llmConfig.OpenAICompatibleBaseURL,
OpenAICompatibleAPIKey: llmConfig.OpenAICompatibleAPIKey,
OpenAICompatibleModelID: llmConfig.OpenAICompatibleModelID,
@@ -39,7 +39,7 @@ func (p *OpenRouterChatCompletionsAPIProvider) GetModelID() string {
// NewOpenRouterLargeLanguageModelProvider creates a new OpenRouter large language model provider instance
func NewOpenRouterLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenRouterChatCompletionsAPIProvider{
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenRouterChatCompletionsAPIProvider{
OpenRouterAPIKey: llmConfig.OpenRouterAPIKey,
OpenRouterModelID: llmConfig.OpenRouterModelID,
})
+19
View File
@@ -111,6 +111,25 @@ func JWTMCPAuthorization(c *core.WebContext) {
c.Next()
}
// JWTOAuth2CallbackAuthorization verifies whether current request is OAuth 2.0 callback
func JWTOAuth2CallbackAuthorization(c *core.WebContext) {
claims, err := getTokenClaims(c, TOKEN_SOURCE_TYPE_HEADER)
if err != nil {
utils.PrintJsonErrorResult(c, errs.ErrTokenExpired)
return
}
if claims.Type != core.USER_TOKEN_TYPE_OAUTH2_CALLBACK && claims.Type != core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY {
log.Warnf(c, "[authorization.JWTOAuth2CallbackAuthorization] user \"uid:%d\" token is not for oauth 2.0 callback request", claims.Uid)
utils.PrintJsonErrorResult(c, errs.ErrCurrentInvalidToken)
return
}
c.SetTokenClaims(claims)
c.Next()
}
func jwtAuthorization(c *core.WebContext, source TokenSourceType) {
claims, err := getTokenClaims(c, source)
+19
View File
@@ -0,0 +1,19 @@
package models
// OAuth2LoginRequest represents all parameters of OAuth 2.0 login request
type OAuth2LoginRequest struct {
Platform string `form:"platform" binding:"required"`
ClientSessionId string `form:"client_session_id" binding:"required"`
}
// OAuth2CallbackRequest represents all parameters of OAuth 2.0 callback request
type OAuth2CallbackRequest struct {
State string `form:"state"`
Code string `form:"code"`
}
// OAuth2CallbackLoginRequest represents all parameters of OAuth 2.0 callback login request
type OAuth2CallbackLoginRequest struct {
Provider string `json:"provider" binding:"required,notBlank"`
Password string `json:"password" binding:"omitempty,min=6,max=128"`
}
+17
View File
@@ -0,0 +1,17 @@
package models
import "github.com/mayswind/ezbookkeeping/pkg/core"
// UserExternalAuth represents user external auth data stored in database
type UserExternalAuth struct {
Uid int64 `xorm:"PK"`
ExternalAuthType core.UserExternalAuthType `xorm:"VARCHAR(32) PK UNIQUE(uqe_userexternalauth_authtype_username) UNIQUE(uqe_userexternalauth_authtype_email)"`
ExternalUsername string `xorm:"VARCHAR(32) UNIQUE(uqe_userexternalauth_authtype_username) NOT NULL"`
ExternalEmail string `xorm:"VARCHAR(100) UNIQUE(uqe_userexternalauth_authtype_email) NOT NULL"`
CreatedUnixTime int64
}
// UserExternalAuthRevokeRequest represents all parameters of user external auth revoke request
type UserExternalAuthRevokeRequest struct {
ExternalAuthType core.UserExternalAuthType `json:"externalAuthType" binding:"required,notBlank"`
}
+12
View File
@@ -133,6 +133,18 @@ func (s *TokenService) CreateMCPTokenViaCli(c *core.CliContext, user *models.Use
return token, tokenRecord, err
}
// CreateOAuth2CallbackRequireVerifyToken generates a new OAuth 2.0 callback token requiring user to verify and saves to database
func (s *TokenService) CreateOAuth2CallbackRequireVerifyToken(c *core.WebContext, user *models.User) (string, *core.UserTokenClaims, error) {
token, claims, _, err := s.createToken(c, user, core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY, s.getUserAgent(c), s.CurrentConfig().TemporaryTokenExpiredTimeDuration)
return token, claims, err
}
// CreateOAuth2CallbackToken generates a new OAuth 2.0 callback token and saves to database
func (s *TokenService) CreateOAuth2CallbackToken(c *core.WebContext, user *models.User) (string, *core.UserTokenClaims, error) {
token, claims, _, err := s.createToken(c, user, core.USER_TOKEN_TYPE_OAUTH2_CALLBACK, s.getUserAgent(c), s.CurrentConfig().TemporaryTokenExpiredTimeDuration)
return token, claims, err
}
// UpdateTokenLastSeen updates the last seen time of specified token
func (s *TokenService) UpdateTokenLastSeen(c core.Context, tokenRecord *models.TokenRecord) error {
if tokenRecord.Uid <= 0 {
+117
View File
@@ -0,0 +1,117 @@
package services
import (
"time"
"xorm.io/xorm"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/datastore"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/models"
)
// UserExternalAuthService represents user external auth service
type UserExternalAuthService struct {
ServiceUsingDB
}
// Initialize a user external auth service singleton instance
var (
UserExternalAuths = &UserExternalAuthService{
ServiceUsingDB: ServiceUsingDB{
container: datastore.Container,
},
}
)
// GetUserAllExternalAuthsByUid returns the user all external auth list according to user uid
func (s *UserExternalAuthService) GetUserAllExternalAuthsByUid(c core.Context, uid int64) ([]*models.UserExternalAuth, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
var userExternalAuths []*models.UserExternalAuth
err := s.UserDB().NewSession(c).Where("uid=?", uid).Find(&userExternalAuths)
return userExternalAuths, err
}
// GetUserExternalAuthByUid returns the user external auth record by uid
func (s *UserExternalAuthService) GetUserExternalAuthByUid(c core.Context, uid int64, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
userExternalAuth := &models.UserExternalAuth{}
has, err := s.UserDB().NewSession(c).Where("uid=? AND external_auth_type=?", uid, externalAuthType).Get(userExternalAuth)
if err != nil {
return nil, err
} else if !has {
return nil, errs.ErrUserExternalAuthNotFound
}
return userExternalAuth, err
}
// GetUserExternalAuthByExternalUserName returns the user external auth record by external username
func (s *UserExternalAuthService) GetUserExternalAuthByExternalUserName(c core.Context, externalUserName string, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
userExternalAuth := &models.UserExternalAuth{}
has, err := s.UserDB().NewSession(c).Where("external_auth_type=? AND external_username=?", externalAuthType, externalUserName).Get(userExternalAuth)
if err != nil {
return nil, err
} else if !has {
return nil, errs.ErrUserExternalAuthNotFound
}
return userExternalAuth, err
}
// GetUserExternalAuthByExternalEmail returns the user external auth record by external email
func (s *UserExternalAuthService) GetUserExternalAuthByExternalEmail(c core.Context, externalEmail string, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
userExternalAuth := &models.UserExternalAuth{}
has, err := s.UserDB().NewSession(c).Where("external_auth_type=? AND external_email=?", externalAuthType, externalEmail).Get(userExternalAuth)
if err != nil {
return nil, err
} else if !has {
return nil, errs.ErrUserExternalAuthNotFound
}
return userExternalAuth, err
}
// CreateUserExternalAuth creates a new user external auth record in database
func (s *UserExternalAuthService) CreateUserExternalAuth(c core.Context, userExternalAuth *models.UserExternalAuth) error {
if userExternalAuth.Uid <= 0 {
return errs.ErrUserIdInvalid
}
userExternalAuth.CreatedUnixTime = time.Now().Unix()
return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error {
_, err := sess.Insert(userExternalAuth)
return err
})
}
// DeleteUserExternalAuth deletes given user external auth record from database
func (s *UserExternalAuthService) DeleteUserExternalAuth(c core.Context, uid int64, externalAuthType core.UserExternalAuthType) error {
if uid <= 0 {
return errs.ErrUserIdInvalid
}
return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error {
deletedRows, err := sess.Where("uid=? AND external_auth_type=?", uid, externalAuthType).Delete(&models.UserExternalAuth{})
if err != nil {
return err
} else if deletedRows < 1 {
return errs.ErrUserExternalAuthNotFound
}
return nil
})
}
+77 -40
View File
@@ -85,6 +85,17 @@ const (
InMemoryDuplicateCheckerType string = "in_memory"
)
// OAuth 2.0 user identifier types
const (
OAuth2UserIdentifierEmail string = "email"
OAuth2UserIdentifierUsername string = "username"
)
// OAuth 2.0 rovider types
const (
OAuth2ProviderNextcloud string = "nextcloud"
)
// Map provider types
const (
OpenStreetMapProvider string = "openstreetmap"
@@ -164,6 +175,9 @@ const (
defaultMaxFailuresPerIpPerMinute uint32 = 5
defaultMaxFailuresPerUserPerMinute uint32 = 5
defaultOAuth2StateExpiredTime uint32 = 300 // 5 minutes
defaultOAuth2RequestTimeout uint32 = 10000 // 10 seconds
defaultTransactionPictureFileMaxSize uint32 = 10485760 // 10MB
defaultUserAvatarFileMaxSize uint32 = 1048576 // 1MB
@@ -240,15 +254,8 @@ type LLMConfig struct {
LargeLanguageModelAPISkipTLSVerify bool
}
// TipConfig represents a tip setting config
type TipConfig struct {
Enabled bool
DefaultContent string
MultiLanguageContent map[string]string
}
// NotificationConfig represents a notification setting config
type NotificationConfig struct {
// MultiLanguageContentConfig represents a multi-language content setting config
type MultiLanguageContentConfig struct {
Enabled bool
DefaultContent string
MultiLanguageContent map[string]string
@@ -351,9 +358,22 @@ type Config struct {
MaxFailuresPerUserPerMinute uint32
// Auth
EnableInternalAuth bool
EnableOAuth2Login bool
EnableTwoFactor bool
EnableUserForgetPassword bool
ForgetPasswordRequireVerifyEmail bool
OAuth2ClientID string
OAuth2ClientSecret string
OAuth2UserIdentifier string
OAuth2AutoRegister bool
OAuth2Provider string
OAuth2StateExpiredTime uint32
OAuth2StateExpiredTimeDuration time.Duration
OAuth2RequestTimeout uint32
OAuth2Proxy string
OAuth2SkipTLSVerify bool
OAuth2NextcloudBaseUrl string
// User
EnableUserRegister bool
@@ -372,12 +392,12 @@ type Config struct {
MaxImportFileSize uint32
// Tip
LoginPageTips TipConfig
LoginPageTips MultiLanguageContentConfig
// Notification
AfterRegisterNotification NotificationConfig
AfterLoginNotification NotificationConfig
AfterOpenNotification NotificationConfig
AfterRegisterNotification MultiLanguageContentConfig
AfterLoginNotification MultiLanguageContentConfig
AfterOpenNotification MultiLanguageContentConfig
// Map
MapProvider string
@@ -956,9 +976,47 @@ func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName
}
func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName string) error {
config.EnableInternalAuth = getConfigItemBoolValue(configFile, sectionName, "enable_internal_auth", true)
config.EnableOAuth2Login = getConfigItemBoolValue(configFile, sectionName, "enable_oauth2_auth", false)
config.EnableTwoFactor = getConfigItemBoolValue(configFile, sectionName, "enable_two_factor", true)
config.EnableUserForgetPassword = getConfigItemBoolValue(configFile, sectionName, "enable_forget_password", false)
config.ForgetPasswordRequireVerifyEmail = getConfigItemBoolValue(configFile, sectionName, "forget_password_require_email_verify", false)
config.OAuth2ClientID = getConfigItemStringValue(configFile, sectionName, "oauth2_client_id")
config.OAuth2ClientSecret = getConfigItemStringValue(configFile, sectionName, "oauth2_client_secret")
oauth2UserIdentifier := getConfigItemStringValue(configFile, sectionName, "oauth2_user_identifier")
if oauth2UserIdentifier == OAuth2UserIdentifierEmail {
config.OAuth2UserIdentifier = OAuth2UserIdentifierEmail
} else if oauth2UserIdentifier == OAuth2UserIdentifierUsername {
config.OAuth2UserIdentifier = OAuth2UserIdentifierUsername
} else {
return errs.ErrInvalidOAuth2UserIdentifier
}
config.OAuth2AutoRegister = getConfigItemBoolValue(configFile, sectionName, "oauth2_auto_register", true)
oauth2Provider := getConfigItemStringValue(configFile, sectionName, "oauth2_provider")
if oauth2Provider == OAuth2ProviderNextcloud {
config.OAuth2Provider = OAuth2ProviderNextcloud
} else {
return errs.ErrInvalidOAuth2Provider
}
config.OAuth2StateExpiredTime = getConfigItemUint32Value(configFile, sectionName, "oauth2_state_expired_time", defaultOAuth2StateExpiredTime)
if config.OAuth2StateExpiredTime < 60 {
return errs.ErrInvalidOAuth2StateExpiredTime
}
config.OAuth2StateExpiredTimeDuration = time.Duration(config.OAuth2StateExpiredTime) * time.Second
config.OAuth2Proxy = getConfigItemStringValue(configFile, sectionName, "oauth2_proxy", "system")
config.OAuth2RequestTimeout = getConfigItemUint32Value(configFile, sectionName, "oauth2_request_timeout", defaultOAuth2RequestTimeout)
config.OAuth2SkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "oauth2_skip_tls_verify", false)
config.OAuth2NextcloudBaseUrl = getConfigItemStringValue(configFile, sectionName, "nextcloud_base_url")
return nil
}
@@ -996,15 +1054,15 @@ func loadDataConfiguration(config *Config, configFile *ini.File, sectionName str
}
func loadTipConfiguration(config *Config, configFile *ini.File, sectionName string) error {
config.LoginPageTips = getTipConfiguration(configFile, sectionName, "enable_tips_in_login_page", "login_page_tips_content")
config.LoginPageTips = getMultiLanguageContentConfig(configFile, sectionName, "enable_tips_in_login_page", "login_page_tips_content")
return nil
}
func loadNotificationConfiguration(config *Config, configFile *ini.File, sectionName string) error {
config.AfterRegisterNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_register", "after_register_notification_content")
config.AfterLoginNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_login", "after_login_notification_content")
config.AfterOpenNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_open", "after_open_notification_content")
config.AfterRegisterNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_register", "after_register_notification_content")
config.AfterLoginNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_login", "after_login_notification_content")
config.AfterOpenNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_open", "after_open_notification_content")
return nil
}
@@ -1141,29 +1199,8 @@ func getFinalPath(workingPath, p string) (string, error) {
return p, err
}
func getTipConfiguration(configFile *ini.File, sectionName string, enableKey string, contentKey string) TipConfig {
config := TipConfig{
Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),
DefaultContent: getConfigItemStringValue(configFile, sectionName, contentKey, ""),
MultiLanguageContent: make(map[string]string),
}
for languageTag := range locales.AllLanguages {
multiLanguageContentKey := strings.ToLower(languageTag)
multiLanguageContentKey = strings.Replace(multiLanguageContentKey, "-", "_", -1)
multiLanguageContentKey = contentKey + "_" + multiLanguageContentKey
content := getConfigItemStringValue(configFile, sectionName, multiLanguageContentKey, "")
if content != "" {
config.MultiLanguageContent[languageTag] = content
}
}
return config
}
func getNotificationConfiguration(configFile *ini.File, sectionName string, enableKey string, contentKey string) NotificationConfig {
config := NotificationConfig{
func getMultiLanguageContentConfig(configFile *ini.File, sectionName string, enableKey string, contentKey string) MultiLanguageContentConfig {
config := MultiLanguageContentConfig{
Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),
DefaultContent: getConfigItemStringValue(configFile, sectionName, contentKey, ""),
MultiLanguageContent: make(map[string]string),
+4
View File
@@ -26,5 +26,9 @@ func (c *ConfigContainer) GetCurrentConfig() *Config {
}
func GetUserAgent() string {
if Version == "" {
return "ezBookkeeping"
}
return fmt.Sprintf("ezBookkeeping/%s", Version)
}
+1 -16
View File
@@ -2,12 +2,10 @@ package storage
import (
"bytes"
"crypto/tls"
"io"
"net/http"
"path/filepath"
"strings"
"time"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
@@ -26,22 +24,9 @@ type WebDAVObjectStorage struct {
// NewWebDAVObjectStorage returns a WebDAV object storage
func NewWebDAVObjectStorage(config *settings.Config, pathPrefix string) (*WebDAVObjectStorage, error) {
webDavConfig := config.WebDAVConfig
transport := http.DefaultTransport.(*http.Transport).Clone()
utils.SetProxyUrl(transport, webDavConfig.Proxy)
if webDavConfig.SkipTLSVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
client := &http.Client{
Transport: transport,
Timeout: time.Duration(webDavConfig.RequestTimeout) * time.Millisecond,
}
storage := &WebDAVObjectStorage{
httpClient: client,
httpClient: utils.NewHttpClient(webDavConfig.RequestTimeout, webDavConfig.Proxy, webDavConfig.SkipTLSVerify, settings.GetUserAgent()),
webDavConfig: webDavConfig,
rootPath: webDavConfig.RootPath,
}
+21 -57
View File
@@ -2,6 +2,7 @@ package utils
import (
"encoding/json"
"errors"
"net/http"
"reflect"
@@ -11,6 +12,22 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/errs"
)
// GetDisplayErrorMessage returns the display error message for given error
func GetDisplayErrorMessage(err *errs.Error) string {
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
var validationErrors validator.ValidationErrors
ok := errors.As(err.BaseError[0], &validationErrors)
if ok {
for _, err := range validationErrors {
return getValidationErrorText(err)
}
}
}
return err.Error()
}
// PrintJsonSuccessResult writes success response in json format to current http context
func PrintJsonSuccessResult(c *core.WebContext, result any) {
c.JSON(http.StatusOK, core.O{
@@ -32,23 +49,10 @@ func PrintDataSuccessResult(c *core.WebContext, contentType string, fileName str
func PrintJsonErrorResult(c *core.WebContext, err *errs.Error) {
c.SetResponseError(err)
errorMessage := err.Error()
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
if ok {
for _, err := range validationErrors {
errorMessage = getValidationErrorText(err)
break
}
}
}
result := core.O{
"success": false,
"errorCode": err.Code(),
"errorMessage": errorMessage,
"errorMessage": GetDisplayErrorMessage(err),
"path": c.Request.URL.Path,
}
@@ -68,19 +72,6 @@ func PrintJSONRPCSuccessResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCR
func PrintJSONRPCErrorResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest, err *errs.Error) {
c.SetResponseError(err)
errorMessage := err.Error()
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
if ok {
for _, err := range validationErrors {
errorMessage = getValidationErrorText(err)
break
}
}
}
var id any
if jsonRPCRequest != nil {
@@ -97,27 +88,13 @@ func PrintJSONRPCErrorResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCReq
jsonRPCError = core.JSONRPCInvalidParamsError
}
c.AbortWithStatusJSON(err.HttpStatusCode, core.NewJSONRPCErrorResponseWithCause(id, jsonRPCError, errorMessage))
c.AbortWithStatusJSON(err.HttpStatusCode, core.NewJSONRPCErrorResponseWithCause(id, jsonRPCError, GetDisplayErrorMessage(err)))
}
// PrintDataErrorResult writes error response in custom content type to current http context
func PrintDataErrorResult(c *core.WebContext, contentType string, err *errs.Error) {
c.SetResponseError(err)
errorMessage := err.Error()
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
if ok {
for _, err := range validationErrors {
errorMessage = getValidationErrorText(err)
break
}
}
}
c.Data(err.HttpStatusCode, contentType, []byte(errorMessage))
c.Data(err.HttpStatusCode, contentType, []byte(GetDisplayErrorMessage(err)))
c.Abort()
}
@@ -149,23 +126,10 @@ func WriteEventStreamJsonSuccessResult(c *core.WebContext, result any) {
func WriteEventStreamJsonErrorResult(c *core.WebContext, originalErr *errs.Error) {
c.SetResponseError(originalErr)
errorMessage := originalErr.Error()
if originalErr.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(originalErr.BaseError) > 0 {
validationErrors, ok := originalErr.BaseError[0].(validator.ValidationErrors)
if ok {
for _, err := range validationErrors {
errorMessage = getValidationErrorText(err)
break
}
}
}
result := core.O{
"success": false,
"errorCode": originalErr.Code(),
"errorMessage": errorMessage,
"errorMessage": GetDisplayErrorMessage(originalErr),
"path": c.Request.URL.Path,
}
+37
View File
@@ -1,10 +1,47 @@
package utils
import (
"crypto/tls"
"net/http"
"net/url"
"time"
)
type defaultTransport struct {
defaultUserAgent string
baseTransport http.RoundTripper
}
func (t *defaultTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Values("User-Agent")) < 1 {
req.Header.Set("User-Agent", t.defaultUserAgent)
} else if req.Header.Get("User-Agent") == "" {
req.Header.Del("User-Agent")
}
return t.baseTransport.RoundTrip(req)
}
// NewHttpClient creates and returns a new http client with specified settings
func NewHttpClient(requestTimeout uint32, proxy string, skipTLSVerify bool, defaultUserAgent string) *http.Client {
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
SetProxyUrl(baseTransport, proxy)
if skipTLSVerify {
baseTransport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
return &http.Client{
Transport: &defaultTransport{
defaultUserAgent: defaultUserAgent,
baseTransport: baseTransport,
},
Timeout: time.Duration(requestTimeout) * time.Millisecond,
}
}
// SetProxyUrl sets proxy url to http transport according to specified proxy setting
func SetProxyUrl(transport *http.Transport, proxy string) {
if proxy == "none" {