mirror of
https://github.com/mayswind/ezbookkeeping.git
synced 2026-05-19 17:24:26 +08:00
support Nextcloud OAuth 2.0 authentication
This commit is contained in:
+142
-3
@@ -22,6 +22,7 @@ type AuthorizationsApi struct {
|
||||
userAppCloudSettings *services.UserApplicationCloudSettingsService
|
||||
tokens *services.TokenService
|
||||
twoFactorAuthorizations *services.TwoFactorAuthorizationService
|
||||
userExternalAuths *services.UserExternalAuthService
|
||||
}
|
||||
|
||||
// Initialize a authorization api singleton instance
|
||||
@@ -48,11 +49,16 @@ var (
|
||||
userAppCloudSettings: services.UserApplicationCloudSettings,
|
||||
tokens: services.Tokens,
|
||||
twoFactorAuthorizations: services.TwoFactorAuthorizations,
|
||||
userExternalAuths: services.UserExternalAuths,
|
||||
}
|
||||
)
|
||||
|
||||
// AuthorizeHandler verifies and authorizes current login request
|
||||
func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
|
||||
if !a.CurrentConfig().EnableInternalAuth {
|
||||
return nil, errs.ErrCannotLoginByPassword
|
||||
}
|
||||
|
||||
var credential models.UserLoginRequest
|
||||
err := c.ShouldBindJSON(&credential)
|
||||
|
||||
@@ -151,7 +157,7 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Err
|
||||
applicationCloudSettingSlice = &userApplicationCloudSettings.Settings
|
||||
}
|
||||
|
||||
log.Infof(c, "[authorizations.AuthorizeHandler] user \"uid:%d\" has logined, token type is %d, token will be expired at %d", user.Uid, claims.Type, claims.ExpiresAt)
|
||||
log.Infof(c, "[authorizations.AuthorizeHandler] user \"uid:%d\" has logged in, token type is %d, token will be expired at %d", user.Uid, claims.Type, claims.ExpiresAt)
|
||||
|
||||
authResp := a.getAuthResponse(c, token, twoFactorEnable, user, applicationCloudSettingSlice)
|
||||
return authResp, nil
|
||||
@@ -159,6 +165,10 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.WebContext) (any, *errs.Err
|
||||
|
||||
// TwoFactorAuthorizeHandler verifies and authorizes current 2fa login by passcode
|
||||
func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
|
||||
if !a.CurrentConfig().EnableInternalAuth {
|
||||
return nil, errs.ErrCannotLoginByPassword
|
||||
}
|
||||
|
||||
var credential models.TwoFactorLoginRequest
|
||||
err := c.ShouldBindJSON(&credential)
|
||||
|
||||
@@ -198,7 +208,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any,
|
||||
user, err := a.users.GetUserById(c, uid)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error())
|
||||
log.Errorf(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
|
||||
return nil, errs.ErrUserNotFound
|
||||
}
|
||||
|
||||
@@ -246,6 +256,10 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.WebContext) (any,
|
||||
|
||||
// TwoFactorAuthorizeByRecoveryCodeHandler verifies and authorizes current 2fa login by recovery code
|
||||
func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebContext) (any, *errs.Error) {
|
||||
if !a.CurrentConfig().EnableInternalAuth {
|
||||
return nil, errs.ErrCannotLoginByPassword
|
||||
}
|
||||
|
||||
var credential models.TwoFactorRecoveryCodeLoginRequest
|
||||
err := c.ShouldBindJSON(&credential)
|
||||
|
||||
@@ -276,7 +290,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebC
|
||||
user, err := a.users.GetUserById(c, uid)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error())
|
||||
log.Errorf(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
|
||||
return nil, errs.ErrUserNotFound
|
||||
}
|
||||
|
||||
@@ -338,6 +352,131 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.WebC
|
||||
return authResp, nil
|
||||
}
|
||||
|
||||
// OAuth2CallbackAuthorizeHandler verifies and authorizes current OAuth 2.0 callback login
|
||||
func (a *AuthorizationsApi) OAuth2CallbackAuthorizeHandler(c *core.WebContext) (any, *errs.Error) {
|
||||
if !a.CurrentConfig().EnableOAuth2Login {
|
||||
return nil, errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
var credential models.OAuth2CallbackLoginRequest
|
||||
err := c.ShouldBindJSON(&credential)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] parse request failed, because %s", err.Error())
|
||||
return nil, errs.NewIncompleteOrIncorrectSubmissionError(err)
|
||||
}
|
||||
|
||||
userExternalAuthType := core.UserExternalAuthType(credential.Provider)
|
||||
|
||||
if !userExternalAuthType.IsValid() {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] provider \"%s\" is invalid", credential.Provider)
|
||||
return nil, errs.ErrInvalidOAuth2Provider
|
||||
}
|
||||
|
||||
uid := c.GetCurrentUid()
|
||||
err = a.CheckFailureCount(c, uid)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] cannot auth for user \"uid:%d\", because %s", uid, err.Error())
|
||||
return nil, errs.Or(err, errs.ErrFailureCountLimitReached)
|
||||
}
|
||||
|
||||
user, err := a.users.GetUserById(c, uid)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
|
||||
return nil, errs.ErrUserNotFound
|
||||
}
|
||||
|
||||
if user.Disabled {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" is disabled", user.Uid)
|
||||
return nil, errs.ErrUserIsDisabled
|
||||
}
|
||||
|
||||
if a.CurrentConfig().EnableUserForceVerifyEmail && !user.EmailVerified {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" has not verified email", user.Uid)
|
||||
return nil, errs.ErrEmailIsNotVerified
|
||||
}
|
||||
|
||||
oldTokenClaims := c.GetTokenClaims()
|
||||
|
||||
if oldTokenClaims.Type == core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY {
|
||||
if credential.Password == "" {
|
||||
return nil, errs.ErrPasswordIsEmpty
|
||||
}
|
||||
|
||||
if !a.users.IsPasswordEqualsUserPassword(credential.Password, user) {
|
||||
failureCheckErr := a.CheckAndIncreaseFailureCount(c, uid)
|
||||
|
||||
if failureCheckErr != nil {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] cannot login for user \"uid:%d\", because %s", user.Uid, failureCheckErr.Error())
|
||||
return nil, errs.Or(failureCheckErr, errs.ErrFailureCountLimitReached)
|
||||
}
|
||||
|
||||
return nil, errs.ErrUserPasswordWrong
|
||||
}
|
||||
|
||||
userExternalAuth := &models.UserExternalAuth{
|
||||
Uid: user.Uid,
|
||||
ExternalAuthType: userExternalAuthType,
|
||||
}
|
||||
|
||||
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
|
||||
userExternalAuth.ExternalEmail = user.Email
|
||||
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
|
||||
userExternalAuth.ExternalUsername = user.Username
|
||||
}
|
||||
|
||||
err = a.userExternalAuths.CreateUserExternalAuth(c, userExternalAuth)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to create user external auth for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
return nil, errs.Or(err, errs.ErrOperationFailed)
|
||||
}
|
||||
|
||||
log.Infof(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user external auth has been created for user \"uid:%d\"", user.Uid)
|
||||
} else if oldTokenClaims.Type == core.USER_TOKEN_TYPE_OAUTH2_CALLBACK {
|
||||
_, err = a.userExternalAuths.GetUserExternalAuthByUid(c, uid, userExternalAuthType)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get user external auth for user \"uid:%d\", because %s", uid, err.Error())
|
||||
return nil, errs.Or(err, errs.ErrUserExternalAuthNotFound)
|
||||
}
|
||||
} else {
|
||||
return nil, errs.ErrSystemError
|
||||
}
|
||||
|
||||
err = a.tokens.DeleteTokenByClaims(c, oldTokenClaims)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to revoke temporary token \"utid:%s\" for user \"uid:%d\", because %s", oldTokenClaims.UserTokenId, user.Uid, err.Error())
|
||||
}
|
||||
|
||||
token, claims, err := a.tokens.CreateToken(c, user)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
return nil, errs.ErrTokenGenerating
|
||||
}
|
||||
|
||||
c.SetTextualToken(token)
|
||||
c.SetTokenClaims(claims)
|
||||
|
||||
userApplicationCloudSettings, err := a.userAppCloudSettings.GetUserApplicationCloudSettingsByUid(c, user.Uid)
|
||||
var applicationCloudSettingSlice *models.ApplicationCloudSettingSlice = nil
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[authorizations.OAuth2CallbackAuthorizeHandler] failed to get latest user application cloud settings for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
} else if userApplicationCloudSettings != nil && len(userApplicationCloudSettings.Settings) > 0 {
|
||||
applicationCloudSettingSlice = &userApplicationCloudSettings.Settings
|
||||
}
|
||||
|
||||
log.Infof(c, "[authorizations.OAuth2CallbackAuthorizeHandler] user \"uid:%d\" has logged in, token will be expired at %d", user.Uid, claims.ExpiresAt)
|
||||
|
||||
authResp := a.getAuthResponse(c, token, false, user, applicationCloudSettingSlice)
|
||||
return authResp, nil
|
||||
}
|
||||
|
||||
func (a *AuthorizationsApi) getAuthResponse(c *core.WebContext, token string, need2FA bool, user *models.User, applicationCloudSettings *models.ApplicationCloudSettingSlice) *models.AuthResponse {
|
||||
return &models.AuthResponse{
|
||||
Token: token,
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/avatars"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
@@ -120,6 +121,13 @@ func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType dupli
|
||||
}
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkWithCustomExpirationIfEnable saves the identification and remark by the current duplicate checker with custom expiration time if the duplicate submission check is enabled
|
||||
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpirationIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveSubmissionRemarkIfEnable removes the identification and remark by the current duplicate checker if the duplicate submission check is enabled
|
||||
func (a *ApiUsingDuplicateChecker) RemoveSubmissionRemarkIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string) {
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/duplicatechecker"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/locales"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/models"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/services"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/validators"
|
||||
)
|
||||
|
||||
const oauth2CallbackPageUrlSuccessFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&token=%s"
|
||||
const oauth2CallbackPageUrlNeedVerifyFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&userName=%s&token=%s"
|
||||
const oauth2CallbackPageUrlFailedFormat = "%sdesktop/#/oauth2_callback?error=%s"
|
||||
|
||||
// OAuth2AuthenticationApi represents OAuth 2.0 authorization api
|
||||
type OAuth2AuthenticationApi struct {
|
||||
ApiUsingConfig
|
||||
ApiUsingDuplicateChecker
|
||||
users *services.UserService
|
||||
tokens *services.TokenService
|
||||
userExternalAuths *services.UserExternalAuthService
|
||||
}
|
||||
|
||||
// Initialize a OAuth 2.0 authentication api singleton instance
|
||||
var (
|
||||
OAuth2Authentications = &OAuth2AuthenticationApi{
|
||||
ApiUsingConfig: ApiUsingConfig{
|
||||
container: settings.Container,
|
||||
},
|
||||
ApiUsingDuplicateChecker: ApiUsingDuplicateChecker{
|
||||
ApiUsingConfig: ApiUsingConfig{
|
||||
container: settings.Container,
|
||||
},
|
||||
container: duplicatechecker.Container,
|
||||
},
|
||||
users: services.Users,
|
||||
tokens: services.Tokens,
|
||||
userExternalAuths: services.UserExternalAuths,
|
||||
}
|
||||
)
|
||||
|
||||
// LoginHandler handles user login request via OAuth 2.0
|
||||
func (a *OAuth2AuthenticationApi) LoginHandler(c *core.WebContext) (string, *errs.Error) {
|
||||
var oauth2LoginReq models.OAuth2LoginRequest
|
||||
err := c.ShouldBindQuery(&oauth2LoginReq)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[oauth2_authentications.LoginHandler] parse request failed, because %s", err.Error())
|
||||
return "", errs.NewIncompleteOrIncorrectSubmissionError(err)
|
||||
}
|
||||
|
||||
if oauth2LoginReq.Platform != "mobile" && oauth2LoginReq.Platform != "desktop" {
|
||||
return "", errs.ErrInvalidOAuth2LoginRequest
|
||||
}
|
||||
|
||||
state := fmt.Sprintf("%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId)
|
||||
remark := ""
|
||||
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
found := false
|
||||
found, remark = a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
|
||||
|
||||
if found {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId)
|
||||
return "", errs.ErrRepeatedRequest
|
||||
}
|
||||
|
||||
randomString, err := utils.GetRandomNumberOrLowercaseLetter(32)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error())
|
||||
return "", errs.ErrSystemError
|
||||
}
|
||||
|
||||
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, randomString)
|
||||
state = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
|
||||
}
|
||||
|
||||
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to get oauth 2.0 auth url, because %s", err.Error())
|
||||
return "", errs.Or(err, errs.ErrSystemError)
|
||||
}
|
||||
|
||||
a.SetSubmissionRemarkWithCustomExpirationIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
|
||||
|
||||
return redirectUrl, nil
|
||||
}
|
||||
|
||||
// CallbackHandler handles OAuth 2.0 callback request
|
||||
func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *errs.Error) {
|
||||
var oauth2CallbackReq models.OAuth2CallbackRequest
|
||||
err := c.ShouldBindQuery(&oauth2CallbackReq)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[oauth2_authentications.CallbackHandler] parse request failed, because %s", err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.NewIncompleteOrIncorrectSubmissionError(err))
|
||||
}
|
||||
|
||||
if oauth2CallbackReq.State == "" {
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2State)
|
||||
}
|
||||
|
||||
if oauth2CallbackReq.Code == "" {
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2Code)
|
||||
}
|
||||
|
||||
platform := ""
|
||||
clientSessionId := ""
|
||||
|
||||
stateParts := strings.Split(oauth2CallbackReq.State, "|")
|
||||
|
||||
if len(stateParts) >= 2 {
|
||||
platform = stateParts[0]
|
||||
clientSessionId = stateParts[1]
|
||||
} else {
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
|
||||
}
|
||||
|
||||
if platform != "mobile" && platform != "desktop" {
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2LoginRequest)
|
||||
}
|
||||
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
|
||||
|
||||
if !found {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId)
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback)
|
||||
}
|
||||
|
||||
remarkParts := strings.Split(remark, "|")
|
||||
|
||||
if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId)
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
|
||||
}
|
||||
|
||||
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, remarkParts[2])
|
||||
expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState)))
|
||||
|
||||
if oauth2CallbackReq.State != expectedState {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State)
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
|
||||
}
|
||||
|
||||
a.RemoveSubmissionRemarkIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
|
||||
}
|
||||
|
||||
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 token, because %s", err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrCannotRetrieveOAuth2Token))
|
||||
}
|
||||
|
||||
oauth2UserInfo, err := oauth2.GetOAuth2UserInfo(c, oauth2Token)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 user info, because %s", err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrInvalidOAuth2Token))
|
||||
}
|
||||
|
||||
if oauth2UserInfo == nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 user info, because user info is nil")
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrCannotRetrieveUserInfo)
|
||||
}
|
||||
|
||||
if oauth2UserInfo.UserName == "" || oauth2UserInfo.Email == "" {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 user info, userName: %s, email: %s", oauth2UserInfo.UserName, oauth2UserInfo.Email)
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrCannotRetrieveUserInfo)
|
||||
}
|
||||
|
||||
userExternalAuthType := oauth2.GetExternalUserAuthType()
|
||||
var userExternalAuth *models.UserExternalAuth
|
||||
|
||||
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
|
||||
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalEmail(c, oauth2UserInfo.Email, userExternalAuthType)
|
||||
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
|
||||
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalUserName(c, oauth2UserInfo.UserName, userExternalAuthType)
|
||||
} else {
|
||||
userExternalAuth, err = a.userExternalAuths.GetUserExternalAuthByExternalEmail(c, oauth2UserInfo.Email, userExternalAuthType)
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, errs.ErrUserExternalAuthNotFound) {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user external auth, because %s", err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
|
||||
}
|
||||
|
||||
var user *models.User
|
||||
|
||||
if err == nil { // user already bound to external auth, redirect to success page
|
||||
user, err = a.users.GetUserById(c, userExternalAuth.Uid)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user by id %d, because %s", userExternalAuth.Uid, err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
|
||||
}
|
||||
} else if errors.Is(err, errs.ErrUserExternalAuthNotFound) { // user not bound to external auth, try to bind or register new user
|
||||
if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierEmail {
|
||||
user, err = a.users.GetUserByEmail(c, oauth2UserInfo.Email)
|
||||
} else if a.CurrentConfig().OAuth2UserIdentifier == settings.OAuth2UserIdentifierUsername {
|
||||
user, err = a.users.GetUserByUsername(c, oauth2UserInfo.UserName)
|
||||
} else {
|
||||
user, err = a.users.GetUserByEmail(c, oauth2UserInfo.Email)
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, errs.ErrUserNotFound) {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to get user, because %s", err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
|
||||
}
|
||||
|
||||
if user == nil && a.CurrentConfig().EnableUserRegister && a.CurrentConfig().OAuth2AutoRegister {
|
||||
userName := strings.TrimSpace(oauth2UserInfo.UserName)
|
||||
email := strings.TrimSpace(oauth2UserInfo.Email)
|
||||
nickName := strings.TrimSpace(oauth2UserInfo.NickName)
|
||||
languageCode := ""
|
||||
currencyCode := "USD"
|
||||
|
||||
if _, exists := locales.AllLanguages[oauth2UserInfo.LanguageCode]; exists {
|
||||
languageCode = oauth2UserInfo.LanguageCode
|
||||
}
|
||||
|
||||
if _, exists := validators.AllCurrencyNames[oauth2UserInfo.CurrencyCode]; exists {
|
||||
currencyCode = oauth2UserInfo.CurrencyCode
|
||||
}
|
||||
|
||||
user = &models.User{
|
||||
Username: userName,
|
||||
Email: email,
|
||||
Nickname: nickName,
|
||||
Password: "",
|
||||
Language: languageCode,
|
||||
DefaultCurrency: currencyCode,
|
||||
FirstDayOfWeek: oauth2UserInfo.FirstDayOfWeek,
|
||||
FiscalYearStart: core.FISCAL_YEAR_START_DEFAULT,
|
||||
TransactionEditScope: models.TRANSACTION_EDIT_SCOPE_ALL,
|
||||
FeatureRestriction: a.CurrentConfig().DefaultFeatureRestrictions,
|
||||
}
|
||||
|
||||
err = a.users.CreateUser(c, user)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create user \"%s\", because %s", user.Username, err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
|
||||
}
|
||||
|
||||
log.Infof(c, "[oauth2_authentications.CallbackHandler] user \"%s\" has registered successfully, uid is %d", user.Username, user.Uid)
|
||||
|
||||
userExternalAuth := &models.UserExternalAuth{
|
||||
Uid: user.Uid,
|
||||
ExternalAuthType: userExternalAuthType,
|
||||
ExternalUsername: oauth2UserInfo.UserName,
|
||||
ExternalEmail: oauth2UserInfo.Email,
|
||||
}
|
||||
|
||||
err = a.userExternalAuths.CreateUserExternalAuth(c, userExternalAuth)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create user external auth for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.Or(err, errs.ErrOperationFailed))
|
||||
}
|
||||
|
||||
log.Infof(c, "[oauth2_authentications.CallbackHandler] user external auth has been created for user \"uid:%d\"", user.Uid)
|
||||
} else if user == nil {
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrOAuth2AutoRegistrationNotEnabled)
|
||||
}
|
||||
}
|
||||
|
||||
if userExternalAuth == nil {
|
||||
token, _, err := a.tokens.CreateOAuth2CallbackRequireVerifyToken(c, user)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create oauth 2.0 callback verify token for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrTokenGenerating)
|
||||
}
|
||||
|
||||
return a.redirectToVerifyCallbackPage(c, platform, userExternalAuthType, user.Username, token)
|
||||
} else {
|
||||
token, _, err := a.tokens.CreateOAuth2CallbackToken(c, user)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to create oauth 2.0 callback token for user \"uid:%d\", because %s", user.Uid, err.Error())
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrTokenGenerating)
|
||||
}
|
||||
|
||||
return a.redirectToSuccessCallbackPage(c, platform, userExternalAuthType, token)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *OAuth2AuthenticationApi) redirectToSuccessCallbackPage(c *core.WebContext, platform string, externalAuthType core.UserExternalAuthType, token string) (string, *errs.Error) {
|
||||
return fmt.Sprintf(oauth2CallbackPageUrlSuccessFormat, a.CurrentConfig().RootUrl, platform, externalAuthType, url.QueryEscape(token)), nil
|
||||
}
|
||||
|
||||
func (a *OAuth2AuthenticationApi) redirectToVerifyCallbackPage(c *core.WebContext, platform string, externalAuthType core.UserExternalAuthType, userName string, token string) (string, *errs.Error) {
|
||||
return fmt.Sprintf(oauth2CallbackPageUrlNeedVerifyFormat, a.CurrentConfig().RootUrl, platform, externalAuthType, userName, url.QueryEscape(token)), nil
|
||||
}
|
||||
|
||||
func (a *OAuth2AuthenticationApi) redirectToFailedCallbackPage(c *core.WebContext, err *errs.Error) (string, *errs.Error) {
|
||||
return fmt.Sprintf(oauth2CallbackPageUrlFailedFormat, a.CurrentConfig().RootUrl, url.QueryEscape(utils.GetDisplayErrorMessage(err))), nil
|
||||
}
|
||||
@@ -35,14 +35,18 @@ func (a *ServerSettingsApi) ServerSettingsJavascriptHandler(c *core.WebContext)
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(ezbookkeepingServerSettingsJavascriptFileHeader)
|
||||
|
||||
a.appendBooleanSetting(builder, "r", config.EnableUserRegister)
|
||||
a.appendBooleanSetting(builder, "f", config.EnableUserForgetPassword)
|
||||
a.appendBooleanSetting(builder, "a", config.EnableInternalAuth)
|
||||
a.appendBooleanSetting(builder, "o", config.EnableOAuth2Login)
|
||||
a.appendBooleanSetting(builder, "r", config.EnableInternalAuth && config.EnableUserRegister)
|
||||
a.appendBooleanSetting(builder, "f", config.EnableInternalAuth && config.EnableUserForgetPassword)
|
||||
a.appendBooleanSetting(builder, "v", config.EnableUserVerifyEmail)
|
||||
a.appendBooleanSetting(builder, "p", config.EnableTransactionPictures)
|
||||
a.appendBooleanSetting(builder, "s", config.EnableScheduledTransaction)
|
||||
a.appendBooleanSetting(builder, "e", config.EnableDataExport)
|
||||
a.appendBooleanSetting(builder, "i", config.EnableDataImport)
|
||||
|
||||
a.appendStringSetting(builder, "op", config.OAuth2Provider)
|
||||
|
||||
if config.EnableMCPServer {
|
||||
a.appendBooleanSetting(builder, "mcp", config.EnableMCPServer)
|
||||
}
|
||||
@@ -138,7 +142,7 @@ func (a *ServerSettingsApi) appendStringSetting(builder *strings.Builder, key st
|
||||
builder.WriteString(";\n")
|
||||
}
|
||||
|
||||
func (a *ServerSettingsApi) appendMultiLanguageTipSetting(builder *strings.Builder, key string, value settings.TipConfig) {
|
||||
func (a *ServerSettingsApi) appendMultiLanguageTipSetting(builder *strings.Builder, key string, value settings.MultiLanguageContentConfig) {
|
||||
builder.WriteString(ezbookkeepingServerSettingsGlobalVariableFullName)
|
||||
builder.WriteString("[")
|
||||
a.appendEncodedString(builder, key)
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
)
|
||||
|
||||
type nextcloudUserInfoResponse struct {
|
||||
OCS *struct {
|
||||
Meta *struct {
|
||||
Status string `json:"status"`
|
||||
StatusCode int `json:"statuscode"`
|
||||
} `json:"meta"`
|
||||
Data *struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display-name"`
|
||||
} `json:"data"`
|
||||
} `json:"ocs"`
|
||||
}
|
||||
|
||||
// NextcloudOAuth2Provider represents Nextcloud OAuth 2.0 provider
|
||||
type NextcloudOAuth2Provider struct {
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
|
||||
func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider {
|
||||
if baseUrl[len(baseUrl)-1] != '/' {
|
||||
baseUrl += "/"
|
||||
}
|
||||
|
||||
return &NextcloudOAuth2Provider{
|
||||
baseUrl: baseUrl,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the Nextcloud provider
|
||||
func (p *NextcloudOAuth2Provider) GetAuthUrl() string {
|
||||
return p.baseUrl + "apps/oauth2/authorize"
|
||||
}
|
||||
|
||||
// GetTokenUrl returns the token url of the Nextcloud provider
|
||||
func (p *NextcloudOAuth2Provider) GetTokenUrl() string {
|
||||
return p.baseUrl + "apps/oauth2/api/v1/token"
|
||||
}
|
||||
|
||||
// GetUserInfo returns the user info by the Nextcloud provider
|
||||
func (p *NextcloudOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) {
|
||||
url := p.baseUrl + "ocs/v2.php/cloud/user?format=json"
|
||||
resp, err := oauth2Client.Get(url)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[nextcloud_oauth2_provider.GetUserInfo] failed to get user info response, because %s", err.Error())
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
log.Debugf(c, "[nextcloud_oauth2_provider.GetUserInfo] response is %s", body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Errorf(c, "[nextcloud_oauth2_provider.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
return p.parseUserInfo(c, body)
|
||||
}
|
||||
|
||||
// GetScopes returns the scopes required by the Nextcloud provider
|
||||
func (p *NextcloudOAuth2Provider) GetScopes() []string {
|
||||
return []string{"profile", "email"}
|
||||
}
|
||||
|
||||
func (p *NextcloudOAuth2Provider) parseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
|
||||
userInfoResp := &nextcloudUserInfoResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] failed to parse user info response body, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS == nil || userInfoResp.OCS.Meta == nil || userInfoResp.OCS.Data == nil {
|
||||
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] invalid user info response body")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS.Meta.StatusCode != 200 {
|
||||
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] user info response status code is %d", userInfoResp.OCS.Meta.StatusCode)
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS.Data.ID == "" {
|
||||
log.Warnf(c, "[nextcloud_oauth2_provider.parseUserInfo] user info id is empty")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &OAuth2UserInfo{
|
||||
UserName: userInfoResp.OCS.Data.ID,
|
||||
Email: userInfoResp.OCS.Data.Email,
|
||||
NickName: userInfoResp.OCS.Data.DisplayName,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
||||
)
|
||||
|
||||
// OAuth2Container contains the current OAuth 2.0 authentication provider
|
||||
type OAuth2Container struct {
|
||||
oauth2Config *oauth2.Config
|
||||
oauth2Provider OAuth2Provider
|
||||
oauth2HttpClient *http.Client
|
||||
externalUserAuthType core.UserExternalAuthType
|
||||
}
|
||||
|
||||
// Initialize a OAuth 2.0 container singleton instance
|
||||
var (
|
||||
Container = &OAuth2Container{}
|
||||
)
|
||||
|
||||
// InitializeOAuth2Provider initializes the current OAuth 2.0 provider according to the config
|
||||
func InitializeOAuth2Provider(config *settings.Config) error {
|
||||
if !config.EnableOAuth2Login {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config.OAuth2ClientID == "" || config.OAuth2ClientSecret == "" || config.OAuth2UserIdentifier == "" || config.OAuth2Provider == "" {
|
||||
return errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
var oauth2Provider OAuth2Provider
|
||||
var externalUserAuthType core.UserExternalAuthType
|
||||
|
||||
if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
|
||||
oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl)
|
||||
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD
|
||||
} else {
|
||||
return errs.ErrInvalidOAuth2Provider
|
||||
}
|
||||
|
||||
Container.oauth2Config = buildOAuth2Config(config, oauth2Provider)
|
||||
Container.oauth2Provider = oauth2Provider
|
||||
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
|
||||
Container.externalUserAuthType = externalUserAuthType
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOAuth2AuthUrl returns the OAuth 2.0 authentication url
|
||||
func GetOAuth2AuthUrl(c core.Context, state string) (string, error) {
|
||||
if Container.oauth2Config == nil {
|
||||
return "", errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
return Container.oauth2Config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
|
||||
func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) {
|
||||
if Container.oauth2Config == nil || Container.oauth2HttpClient == nil {
|
||||
return nil, errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code)
|
||||
}
|
||||
|
||||
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
|
||||
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) {
|
||||
if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil {
|
||||
return nil, errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
if token == nil {
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token))
|
||||
return Container.oauth2Provider.GetUserInfo(c, oauth2Client)
|
||||
}
|
||||
|
||||
// GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider
|
||||
func GetExternalUserAuthType() core.UserExternalAuthType {
|
||||
return Container.externalUserAuthType
|
||||
}
|
||||
|
||||
func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config {
|
||||
redirectURL := config.RootUrl + "oauth2/callback"
|
||||
|
||||
return &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: oauth2Provider.GetAuthUrl(),
|
||||
TokenURL: oauth2Provider.GetTokenUrl(),
|
||||
},
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: oauth2Provider.GetScopes(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
)
|
||||
|
||||
// OAuth2Context represents the context for OAuth 2.0 operations
|
||||
type OAuth2Context struct {
|
||||
core.Context
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Value returns the value associated with key
|
||||
func (o *OAuth2Context) Value(key any) any {
|
||||
if key == oauth2.HTTPClient {
|
||||
return o.httpClient
|
||||
}
|
||||
|
||||
return o.Context.Value(key)
|
||||
}
|
||||
|
||||
func wrapOAuth2Context(ctx core.Context, httpClient *http.Client) core.Context {
|
||||
return &OAuth2Context{
|
||||
Context: ctx,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
)
|
||||
|
||||
// OAuth2Provider defines the structure of OAuth 2.0 provider
|
||||
type OAuth2Provider interface {
|
||||
// GetAuthUrl returns the authentication url of the provider
|
||||
GetAuthUrl() string
|
||||
|
||||
// GetTokenUrl returns the token url of the provider
|
||||
GetTokenUrl() string
|
||||
|
||||
// GetUserInfo returns the user info
|
||||
GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error)
|
||||
|
||||
// GetScopes returns the scopes required by the provider
|
||||
GetScopes() []string
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package oauth2
|
||||
|
||||
import "github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
|
||||
// OAuth2UserInfo represents the user info retrieved from OAuth 2.0 provider
|
||||
type OAuth2UserInfo struct {
|
||||
UserName string
|
||||
Email string
|
||||
NickName string
|
||||
LanguageCode string
|
||||
CurrencyCode string
|
||||
FirstDayOfWeek core.WeekDay
|
||||
}
|
||||
@@ -12,6 +12,9 @@ type CliHandlerFunc func(*CliContext) error
|
||||
// MiddlewareHandlerFunc represents the middleware handler function
|
||||
type MiddlewareHandlerFunc func(*WebContext)
|
||||
|
||||
// RedirectHandlerFunc represents the redirect handler function
|
||||
type RedirectHandlerFunc func(*WebContext) (string, *errs.Error)
|
||||
|
||||
// ApiHandlerFunc represents the api handler function
|
||||
type ApiHandlerFunc func(*WebContext) (any, *errs.Error)
|
||||
|
||||
|
||||
@@ -11,11 +11,13 @@ type TokenType byte
|
||||
|
||||
// Token types
|
||||
const (
|
||||
USER_TOKEN_TYPE_NORMAL TokenType = 1
|
||||
USER_TOKEN_TYPE_REQUIRE_2FA TokenType = 2
|
||||
USER_TOKEN_TYPE_EMAIL_VERIFY TokenType = 3
|
||||
USER_TOKEN_TYPE_PASSWORD_RESET TokenType = 4
|
||||
USER_TOKEN_TYPE_MCP TokenType = 5
|
||||
USER_TOKEN_TYPE_NORMAL TokenType = 1
|
||||
USER_TOKEN_TYPE_REQUIRE_2FA TokenType = 2
|
||||
USER_TOKEN_TYPE_EMAIL_VERIFY TokenType = 3
|
||||
USER_TOKEN_TYPE_PASSWORD_RESET TokenType = 4
|
||||
USER_TOKEN_TYPE_MCP TokenType = 5
|
||||
USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY TokenType = 6
|
||||
USER_TOKEN_TYPE_OAUTH2_CALLBACK TokenType = 7
|
||||
)
|
||||
|
||||
// UserTokenClaims represents user token
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package core
|
||||
|
||||
// UserExternalAuthType represents the type of user external authentication
|
||||
type UserExternalAuthType string
|
||||
|
||||
// User External Auth Type
|
||||
const (
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD UserExternalAuthType = "nextcloud"
|
||||
)
|
||||
|
||||
// IsValid checks if the UserExternalAuthType is valid
|
||||
func (t UserExternalAuthType) IsValid() bool {
|
||||
switch t {
|
||||
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import "time"
|
||||
type DuplicateChecker interface {
|
||||
GetSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) (bool, string)
|
||||
SetSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string, remark string)
|
||||
SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration)
|
||||
RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string)
|
||||
GetOrSetCronJobRunningInfo(jobName string, runningInfo string, runningInterval time.Duration) (bool, string)
|
||||
RemoveCronJobRunningInfo(jobName string)
|
||||
|
||||
@@ -57,6 +57,15 @@ func (c *DuplicateCheckerContainer) SetSubmissionRemark(checkerType DuplicateChe
|
||||
c.current.SetSubmissionRemark(checkerType, uid, identification, remark)
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark by the current duplicate checker with custom expiration time
|
||||
func (c *DuplicateCheckerContainer) SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
|
||||
if c.current == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.current.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
|
||||
}
|
||||
|
||||
// RemoveSubmissionRemark removes the identification and remark by the current duplicate checker
|
||||
func (c *DuplicateCheckerContainer) RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) {
|
||||
if c.current == nil {
|
||||
|
||||
@@ -13,5 +13,6 @@ const (
|
||||
DUPLICATE_CHECKER_TYPE_NEW_TEMPLATE DuplicateCheckerType = 5
|
||||
DUPLICATE_CHECKER_TYPE_NEW_PICTURE DuplicateCheckerType = 6
|
||||
DUPLICATE_CHECKER_TYPE_IMPORT_TRANSACTIONS DuplicateCheckerType = 7
|
||||
DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT DuplicateCheckerType = 8
|
||||
DUPLICATE_CHECKER_TYPE_FAILURE_CHECK DuplicateCheckerType = 255
|
||||
)
|
||||
|
||||
@@ -42,6 +42,11 @@ func (c *InMemoryDuplicateChecker) SetSubmissionRemark(checkerType DuplicateChec
|
||||
c.cache.Set(c.getCacheKey(checkerType, uid, identification), remark, cache.DefaultExpiration)
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark to in-memory cache with custom expiration time
|
||||
func (c *InMemoryDuplicateChecker) SetSubmissionRemarkWithCustomExpiration(checkerType DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
|
||||
c.cache.Set(c.getCacheKey(checkerType, uid, identification), remark, expiration)
|
||||
}
|
||||
|
||||
// RemoveSubmissionRemark removes the identification and remark in in-memory cache
|
||||
func (c *InMemoryDuplicateChecker) RemoveSubmissionRemark(checkerType DuplicateCheckerType, uid int64, identification string) {
|
||||
c.cache.Delete(c.getCacheKey(checkerType, uid, identification))
|
||||
|
||||
@@ -41,6 +41,8 @@ const (
|
||||
NormalSubcategoryUserCustomExchangeRate = 13
|
||||
NormalSubcategoryModelContextProtocol = 14
|
||||
NormalSubcategoryLargeLanguageModel = 15
|
||||
NormalSubcategoryUserExternalAuth = 16
|
||||
NormalSubcategoryOAuth2 = 17
|
||||
)
|
||||
|
||||
// Error represents the specific error returned to user
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Error codes related to user external authentication
|
||||
var (
|
||||
ErrUserExternalAuthNotFound = NewNormalError(NormalSubcategoryUserExternalAuth, 0, http.StatusBadRequest, "user external auth is not found")
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Error codes related to oauth 2.0
|
||||
var (
|
||||
ErrOAuth2NotEnabled = NewNormalError(NormalSubcategoryOAuth2, 0, http.StatusUnauthorized, "oauth 2.0 not enabled")
|
||||
ErrOAuth2AutoRegistrationNotEnabled = NewNormalError(NormalSubcategoryOAuth2, 1, http.StatusUnauthorized, "oauth 2.0 auto registration not enabled")
|
||||
ErrInvalidOAuth2LoginRequest = NewNormalError(NormalSubcategoryOAuth2, 2, http.StatusUnauthorized, "invalid oauth 2.0 login request")
|
||||
ErrInvalidOAuth2Callback = NewNormalError(NormalSubcategoryOAuth2, 3, http.StatusUnauthorized, "invalid oauth 2.0 callback")
|
||||
ErrMissingOAuth2State = NewNormalError(NormalSubcategoryOAuth2, 4, http.StatusUnauthorized, "missing state in oauth 2.0 callback")
|
||||
ErrMissingOAuth2Code = NewNormalError(NormalSubcategoryOAuth2, 5, http.StatusUnauthorized, "missing code in oauth 2.0 callback")
|
||||
ErrInvalidOAuth2State = NewNormalError(NormalSubcategoryOAuth2, 6, http.StatusUnauthorized, "invalid state in oauth 2.0 callback")
|
||||
ErrCannotRetrieveOAuth2Token = NewNormalError(NormalSubcategoryOAuth2, 7, http.StatusUnauthorized, "cannot retrieve oauth 2.0 token")
|
||||
ErrInvalidOAuth2Token = NewNormalError(NormalSubcategoryOAuth2, 8, http.StatusUnauthorized, "invalid oauth 2.0 token")
|
||||
ErrCannotRetrieveUserInfo = NewNormalError(NormalSubcategoryOAuth2, 9, http.StatusUnauthorized, "cannot retrieve user info from oauth 2.0 provider")
|
||||
)
|
||||
@@ -26,4 +26,8 @@ var (
|
||||
ErrInvalidIpAddressPattern = NewSystemError(SystemSubcategorySetting, 19, http.StatusInternalServerError, "invalid ip address pattern")
|
||||
ErrInvalidLLMProvider = NewSystemError(SystemSubcategorySetting, 20, http.StatusInternalServerError, "invalid llm provider")
|
||||
ErrInvalidLLMModelId = NewSystemError(SystemSubcategorySetting, 21, http.StatusInternalServerError, "invalid llm model id")
|
||||
ErrInvalidOAuth2Config = NewSystemError(SystemSubcategorySetting, 22, http.StatusInternalServerError, "invalid oauth 2.0 config")
|
||||
ErrInvalidOAuth2UserIdentifier = NewSystemError(SystemSubcategorySetting, 23, http.StatusInternalServerError, "invalid oauth 2.0 user identifier")
|
||||
ErrInvalidOAuth2Provider = NewSystemError(SystemSubcategorySetting, 24, http.StatusInternalServerError, "invalid oauth 2.0 provider")
|
||||
ErrInvalidOAuth2StateExpiredTime = NewSystemError(SystemSubcategorySetting, 25, http.StatusInternalServerError, "invalid oauth 2.0 state expired time")
|
||||
)
|
||||
|
||||
@@ -38,4 +38,5 @@ var (
|
||||
ErrUserAvatarExtensionInvalid = NewNormalError(NormalSubcategoryUser, 29, http.StatusNotFound, "user avatar file extension invalid")
|
||||
ErrExceedMaxUserAvatarFileSize = NewNormalError(NormalSubcategoryUser, 30, http.StatusBadRequest, "exceed the maximum size of user avatar file")
|
||||
ErrNotPermittedToPerformThisAction = NewNormalError(NormalSubcategoryUser, 31, http.StatusBadRequest, "not permitted to perform this action")
|
||||
ErrCannotLoginByPassword = NewNormalError(NormalSubcategoryUser, 32, http.StatusBadRequest, "cannot login by password")
|
||||
)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package exchangerates
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
@@ -28,23 +26,10 @@ type HttpExchangeRatesDataSource interface {
|
||||
type CommonHttpExchangeRatesDataProvider struct {
|
||||
ExchangeRatesDataProvider
|
||||
dataSource HttpExchangeRatesDataSource
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Context, uid int64, currentConfig *settings.Config) (*models.LatestExchangeRateResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
utils.SetProxyUrl(transport, currentConfig.ExchangeRatesProxy)
|
||||
|
||||
if currentConfig.ExchangeRatesSkipTLSVerify {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(currentConfig.ExchangeRatesRequestTimeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
requests, err := e.dataSource.BuildRequests()
|
||||
|
||||
if err != nil {
|
||||
@@ -56,14 +41,7 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
|
||||
|
||||
for i := 0; i < len(requests); i++ {
|
||||
req := requests[i]
|
||||
|
||||
if len(req.Header.Values("User-Agent")) < 1 {
|
||||
req.Header.Set("User-Agent", settings.GetUserAgent())
|
||||
} else if req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Del("User-Agent")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := e.httpClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to request latest exchange rate data for user \"uid:%d\", because %s", uid, err.Error())
|
||||
@@ -76,7 +54,7 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
|
||||
log.Debugf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] response#%d is %s", i, body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to get latest exchange rate data response for user \"uid:%d\", because response code is not %d", uid, resp.StatusCode)
|
||||
log.Errorf(c, "[common_http_exchange_rates_data_provider.GetLatestExchangeRates] failed to get latest exchange rate data response for user \"uid:%d\", because response code is %d", uid, resp.StatusCode)
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
@@ -125,8 +103,9 @@ func (e *CommonHttpExchangeRatesDataProvider) GetLatestExchangeRates(c core.Cont
|
||||
return finalExchangeRateResponse, nil
|
||||
}
|
||||
|
||||
func newCommonHttpExchangeRatesDataProvider(dataSource HttpExchangeRatesDataSource) *CommonHttpExchangeRatesDataProvider {
|
||||
func newCommonHttpExchangeRatesDataProvider(config *settings.Config, dataSource HttpExchangeRatesDataSource) *CommonHttpExchangeRatesDataProvider {
|
||||
return &CommonHttpExchangeRatesDataProvider{
|
||||
dataSource: dataSource,
|
||||
httpClient: utils.NewHttpClient(config.ExchangeRatesRequestTimeout, config.ExchangeRatesProxy, config.ExchangeRatesSkipTLSVerify, settings.GetUserAgent()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,55 +20,55 @@ var (
|
||||
// InitializeExchangeRatesDataSource initializes the current exchange rates data source according to the config
|
||||
func InitializeExchangeRatesDataSource(config *settings.Config) error {
|
||||
if config.ExchangeRatesDataSource == settings.ReserveBankOfAustraliaDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&ReserveBankOfAustraliaDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &ReserveBankOfAustraliaDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.BankOfCanadaDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfCanadaDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfCanadaDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.CzechNationalBankDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&CzechNationalBankDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CzechNationalBankDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.DanmarksNationalbankDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&DanmarksNationalbankDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &DanmarksNationalbankDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.EuroCentralBankDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&EuroCentralBankDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &EuroCentralBankDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.NationalBankOfGeorgiaDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfGeorgiaDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfGeorgiaDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.CentralBankOfHungaryDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfHungaryDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfHungaryDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.BankOfIsraelDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfIsraelDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfIsraelDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.CentralBankOfMyanmarDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfMyanmarDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfMyanmarDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.NorgesBankDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&NorgesBankDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NorgesBankDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.NationalBankOfPolandDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfPolandDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfPolandDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.NationalBankOfRomaniaDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfRomaniaDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfRomaniaDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.BankOfRussiaDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&BankOfRussiaDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &BankOfRussiaDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.SwissNationalBankDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&SwissNationalBankDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &SwissNationalBankDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.NationalBankOfUkraineDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&NationalBankOfUkraineDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &NationalBankOfUkraineDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.CentralBankOfUzbekistanDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&CentralBankOfUzbekistanDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &CentralBankOfUzbekistanDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.InternationalMonetaryFundDataSource {
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(&InternationalMonetaryFundDataSource{})
|
||||
Container.current = newCommonHttpExchangeRatesDataProvider(config, &InternationalMonetaryFundDataSource{})
|
||||
return nil
|
||||
} else if config.ExchangeRatesDataSource == settings.UserCustomExchangeRatesDataSource {
|
||||
Container.current = newUserCustomExchangeRatesDataProvider()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
@@ -28,7 +26,8 @@ type HttpLargeLanguageModelAdapter interface {
|
||||
// CommonHttpLargeLanguageModelProvider defines the structure of common http large language model provider
|
||||
type CommonHttpLargeLanguageModelProvider struct {
|
||||
provider.LargeLanguageModelProvider
|
||||
adapter HttpLargeLanguageModelAdapter
|
||||
adapter HttpLargeLanguageModelAdapter
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// GetJsonResponse returns the json response from common http large language model provider
|
||||
@@ -51,20 +50,6 @@ func (p *CommonHttpLargeLanguageModelProvider) GetJsonResponse(c core.Context, u
|
||||
}
|
||||
|
||||
func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context, uid int64, currentLLMConfig *settings.LLMConfig, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) (*data.LargeLanguageModelTextualResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
utils.SetProxyUrl(transport, currentLLMConfig.LargeLanguageModelAPIProxy)
|
||||
|
||||
if currentLLMConfig.LargeLanguageModelAPISkipTLSVerify {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(currentLLMConfig.LargeLanguageModelAPIRequestTimeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
httpRequest, err := p.adapter.BuildTextualRequest(c, uid, request, responseType)
|
||||
|
||||
if err != nil {
|
||||
@@ -72,9 +57,7 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
httpRequest.Header.Set("User-Agent", settings.GetUserAgent())
|
||||
|
||||
resp, err := client.Do(httpRequest)
|
||||
resp, err := p.httpClient.Do(httpRequest)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[common_http_large_language_model_provider.getTextualResponse] failed to request large language model api for user \"uid:%d\", because %s", uid, err.Error())
|
||||
@@ -95,8 +78,9 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
|
||||
}
|
||||
|
||||
// NewCommonHttpLargeLanguageModelProvider creates a http adapter based large language model provider instance
|
||||
func NewCommonHttpLargeLanguageModelProvider(adapter HttpLargeLanguageModelAdapter) *CommonHttpLargeLanguageModelProvider {
|
||||
func NewCommonHttpLargeLanguageModelProvider(llmConfig *settings.LLMConfig, adapter HttpLargeLanguageModelAdapter) *CommonHttpLargeLanguageModelProvider {
|
||||
return &CommonHttpLargeLanguageModelProvider{
|
||||
adapter: adapter,
|
||||
adapter: adapter,
|
||||
httpClient: utils.NewHttpClient(llmConfig.LargeLanguageModelAPIRequestTimeout, llmConfig.LargeLanguageModelAPIProxy, llmConfig.LargeLanguageModelAPISkipTLSVerify, settings.GetUserAgent()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func (p *GoogleAILargeLanguageModelAdapter) buildJsonRequestBody(c core.Context,
|
||||
|
||||
// NewGoogleAILargeLanguageModelProvider creates a new Google AI large language model provider instance
|
||||
func NewGoogleAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(&GoogleAILargeLanguageModelAdapter{
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIAPIKey: llmConfig.GoogleAIAPIKey,
|
||||
GoogleAIModelID: llmConfig.GoogleAIModelID,
|
||||
})
|
||||
|
||||
@@ -159,7 +159,7 @@ func (p *OllamaLargeLanguageModelAdapter) getOllamaRequestUrl() string {
|
||||
|
||||
// NewOllamaLargeLanguageModelProvider creates a new Ollama large language model provider instance
|
||||
func NewOllamaLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(&OllamaLargeLanguageModelAdapter{
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &OllamaLargeLanguageModelAdapter{
|
||||
OllamaServerURL: llmConfig.OllamaServerURL,
|
||||
OllamaModelID: llmConfig.OllamaModelID,
|
||||
})
|
||||
|
||||
@@ -37,7 +37,7 @@ func (p *OpenAIOfficialChatCompletionsAPIProvider) GetModelID() string {
|
||||
|
||||
// NewOpenAILargeLanguageModelProvider creates a new OpenAI large language model provider instance
|
||||
func NewOpenAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAIOfficialChatCompletionsAPIProvider{
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenAIOfficialChatCompletionsAPIProvider{
|
||||
OpenAIAPIKey: llmConfig.OpenAIAPIKey,
|
||||
OpenAIModelID: llmConfig.OpenAIModelID,
|
||||
})
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
// OpenAIChatCompletionsAPIProvider defines the structure of OpenAI chat completions API provider
|
||||
@@ -212,8 +213,8 @@ func (p *CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter) buildJsonReque
|
||||
return requestBodyBytes, nil
|
||||
}
|
||||
|
||||
func newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(apiProvider OpenAIChatCompletionsAPIProvider) provider.LargeLanguageModelProvider {
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(&CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||
func newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig *settings.LLMConfig, apiProvider OpenAIChatCompletionsAPIProvider) provider.LargeLanguageModelProvider {
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||
apiProvider: apiProvider,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (p *OpenAICompatibleChatCompletionsAPIProvider) getFinalChatCompletionsRequ
|
||||
|
||||
// NewOpenAICompatibleLargeLanguageModelProvider creates a new OpenAI compatible large language model provider instance
|
||||
func NewOpenAICompatibleLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAICompatibleChatCompletionsAPIProvider{
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenAICompatibleChatCompletionsAPIProvider{
|
||||
OpenAICompatibleBaseURL: llmConfig.OpenAICompatibleBaseURL,
|
||||
OpenAICompatibleAPIKey: llmConfig.OpenAICompatibleAPIKey,
|
||||
OpenAICompatibleModelID: llmConfig.OpenAICompatibleModelID,
|
||||
|
||||
@@ -39,7 +39,7 @@ func (p *OpenRouterChatCompletionsAPIProvider) GetModelID() string {
|
||||
|
||||
// NewOpenRouterLargeLanguageModelProvider creates a new OpenRouter large language model provider instance
|
||||
func NewOpenRouterLargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenRouterChatCompletionsAPIProvider{
|
||||
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(llmConfig, &OpenRouterChatCompletionsAPIProvider{
|
||||
OpenRouterAPIKey: llmConfig.OpenRouterAPIKey,
|
||||
OpenRouterModelID: llmConfig.OpenRouterModelID,
|
||||
})
|
||||
|
||||
@@ -111,6 +111,25 @@ func JWTMCPAuthorization(c *core.WebContext) {
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// JWTOAuth2CallbackAuthorization verifies whether current request is OAuth 2.0 callback
|
||||
func JWTOAuth2CallbackAuthorization(c *core.WebContext) {
|
||||
claims, err := getTokenClaims(c, TOKEN_SOURCE_TYPE_HEADER)
|
||||
|
||||
if err != nil {
|
||||
utils.PrintJsonErrorResult(c, errs.ErrTokenExpired)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Type != core.USER_TOKEN_TYPE_OAUTH2_CALLBACK && claims.Type != core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY {
|
||||
log.Warnf(c, "[authorization.JWTOAuth2CallbackAuthorization] user \"uid:%d\" token is not for oauth 2.0 callback request", claims.Uid)
|
||||
utils.PrintJsonErrorResult(c, errs.ErrCurrentInvalidToken)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetTokenClaims(claims)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func jwtAuthorization(c *core.WebContext, source TokenSourceType) {
|
||||
claims, err := getTokenClaims(c, source)
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package models
|
||||
|
||||
// OAuth2LoginRequest represents all parameters of OAuth 2.0 login request
|
||||
type OAuth2LoginRequest struct {
|
||||
Platform string `form:"platform" binding:"required"`
|
||||
ClientSessionId string `form:"client_session_id" binding:"required"`
|
||||
}
|
||||
|
||||
// OAuth2CallbackRequest represents all parameters of OAuth 2.0 callback request
|
||||
type OAuth2CallbackRequest struct {
|
||||
State string `form:"state"`
|
||||
Code string `form:"code"`
|
||||
}
|
||||
|
||||
// OAuth2CallbackLoginRequest represents all parameters of OAuth 2.0 callback login request
|
||||
type OAuth2CallbackLoginRequest struct {
|
||||
Provider string `json:"provider" binding:"required,notBlank"`
|
||||
Password string `json:"password" binding:"omitempty,min=6,max=128"`
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package models
|
||||
|
||||
import "github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
|
||||
// UserExternalAuth represents user external auth data stored in database
|
||||
type UserExternalAuth struct {
|
||||
Uid int64 `xorm:"PK"`
|
||||
ExternalAuthType core.UserExternalAuthType `xorm:"VARCHAR(32) PK UNIQUE(uqe_userexternalauth_authtype_username) UNIQUE(uqe_userexternalauth_authtype_email)"`
|
||||
ExternalUsername string `xorm:"VARCHAR(32) UNIQUE(uqe_userexternalauth_authtype_username) NOT NULL"`
|
||||
ExternalEmail string `xorm:"VARCHAR(100) UNIQUE(uqe_userexternalauth_authtype_email) NOT NULL"`
|
||||
CreatedUnixTime int64
|
||||
}
|
||||
|
||||
// UserExternalAuthRevokeRequest represents all parameters of user external auth revoke request
|
||||
type UserExternalAuthRevokeRequest struct {
|
||||
ExternalAuthType core.UserExternalAuthType `json:"externalAuthType" binding:"required,notBlank"`
|
||||
}
|
||||
@@ -133,6 +133,18 @@ func (s *TokenService) CreateMCPTokenViaCli(c *core.CliContext, user *models.Use
|
||||
return token, tokenRecord, err
|
||||
}
|
||||
|
||||
// CreateOAuth2CallbackRequireVerifyToken generates a new OAuth 2.0 callback token requiring user to verify and saves to database
|
||||
func (s *TokenService) CreateOAuth2CallbackRequireVerifyToken(c *core.WebContext, user *models.User) (string, *core.UserTokenClaims, error) {
|
||||
token, claims, _, err := s.createToken(c, user, core.USER_TOKEN_TYPE_OAUTH2_CALLBACK_REQUIRE_VERIFY, s.getUserAgent(c), s.CurrentConfig().TemporaryTokenExpiredTimeDuration)
|
||||
return token, claims, err
|
||||
}
|
||||
|
||||
// CreateOAuth2CallbackToken generates a new OAuth 2.0 callback token and saves to database
|
||||
func (s *TokenService) CreateOAuth2CallbackToken(c *core.WebContext, user *models.User) (string, *core.UserTokenClaims, error) {
|
||||
token, claims, _, err := s.createToken(c, user, core.USER_TOKEN_TYPE_OAUTH2_CALLBACK, s.getUserAgent(c), s.CurrentConfig().TemporaryTokenExpiredTimeDuration)
|
||||
return token, claims, err
|
||||
}
|
||||
|
||||
// UpdateTokenLastSeen updates the last seen time of specified token
|
||||
func (s *TokenService) UpdateTokenLastSeen(c core.Context, tokenRecord *models.TokenRecord) error {
|
||||
if tokenRecord.Uid <= 0 {
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/datastore"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/models"
|
||||
)
|
||||
|
||||
// UserExternalAuthService represents user external auth service
|
||||
type UserExternalAuthService struct {
|
||||
ServiceUsingDB
|
||||
}
|
||||
|
||||
// Initialize a user external auth service singleton instance
|
||||
var (
|
||||
UserExternalAuths = &UserExternalAuthService{
|
||||
ServiceUsingDB: ServiceUsingDB{
|
||||
container: datastore.Container,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// GetUserAllExternalAuthsByUid returns the user all external auth list according to user uid
|
||||
func (s *UserExternalAuthService) GetUserAllExternalAuthsByUid(c core.Context, uid int64) ([]*models.UserExternalAuth, error) {
|
||||
if uid <= 0 {
|
||||
return nil, errs.ErrUserIdInvalid
|
||||
}
|
||||
|
||||
var userExternalAuths []*models.UserExternalAuth
|
||||
err := s.UserDB().NewSession(c).Where("uid=?", uid).Find(&userExternalAuths)
|
||||
|
||||
return userExternalAuths, err
|
||||
}
|
||||
|
||||
// GetUserExternalAuthByUid returns the user external auth record by uid
|
||||
func (s *UserExternalAuthService) GetUserExternalAuthByUid(c core.Context, uid int64, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
|
||||
if uid <= 0 {
|
||||
return nil, errs.ErrUserIdInvalid
|
||||
}
|
||||
|
||||
userExternalAuth := &models.UserExternalAuth{}
|
||||
has, err := s.UserDB().NewSession(c).Where("uid=? AND external_auth_type=?", uid, externalAuthType).Get(userExternalAuth)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, errs.ErrUserExternalAuthNotFound
|
||||
}
|
||||
|
||||
return userExternalAuth, err
|
||||
}
|
||||
|
||||
// GetUserExternalAuthByExternalUserName returns the user external auth record by external username
|
||||
func (s *UserExternalAuthService) GetUserExternalAuthByExternalUserName(c core.Context, externalUserName string, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
|
||||
userExternalAuth := &models.UserExternalAuth{}
|
||||
has, err := s.UserDB().NewSession(c).Where("external_auth_type=? AND external_username=?", externalAuthType, externalUserName).Get(userExternalAuth)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, errs.ErrUserExternalAuthNotFound
|
||||
}
|
||||
|
||||
return userExternalAuth, err
|
||||
}
|
||||
|
||||
// GetUserExternalAuthByExternalEmail returns the user external auth record by external email
|
||||
func (s *UserExternalAuthService) GetUserExternalAuthByExternalEmail(c core.Context, externalEmail string, externalAuthType core.UserExternalAuthType) (*models.UserExternalAuth, error) {
|
||||
userExternalAuth := &models.UserExternalAuth{}
|
||||
has, err := s.UserDB().NewSession(c).Where("external_auth_type=? AND external_email=?", externalAuthType, externalEmail).Get(userExternalAuth)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, errs.ErrUserExternalAuthNotFound
|
||||
}
|
||||
|
||||
return userExternalAuth, err
|
||||
}
|
||||
|
||||
// CreateUserExternalAuth creates a new user external auth record in database
|
||||
func (s *UserExternalAuthService) CreateUserExternalAuth(c core.Context, userExternalAuth *models.UserExternalAuth) error {
|
||||
if userExternalAuth.Uid <= 0 {
|
||||
return errs.ErrUserIdInvalid
|
||||
}
|
||||
|
||||
userExternalAuth.CreatedUnixTime = time.Now().Unix()
|
||||
|
||||
return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error {
|
||||
_, err := sess.Insert(userExternalAuth)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteUserExternalAuth deletes given user external auth record from database
|
||||
func (s *UserExternalAuthService) DeleteUserExternalAuth(c core.Context, uid int64, externalAuthType core.UserExternalAuthType) error {
|
||||
if uid <= 0 {
|
||||
return errs.ErrUserIdInvalid
|
||||
}
|
||||
|
||||
return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error {
|
||||
deletedRows, err := sess.Where("uid=? AND external_auth_type=?", uid, externalAuthType).Delete(&models.UserExternalAuth{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if deletedRows < 1 {
|
||||
return errs.ErrUserExternalAuthNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
+77
-40
@@ -85,6 +85,17 @@ const (
|
||||
InMemoryDuplicateCheckerType string = "in_memory"
|
||||
)
|
||||
|
||||
// OAuth 2.0 user identifier types
|
||||
const (
|
||||
OAuth2UserIdentifierEmail string = "email"
|
||||
OAuth2UserIdentifierUsername string = "username"
|
||||
)
|
||||
|
||||
// OAuth 2.0 rovider types
|
||||
const (
|
||||
OAuth2ProviderNextcloud string = "nextcloud"
|
||||
)
|
||||
|
||||
// Map provider types
|
||||
const (
|
||||
OpenStreetMapProvider string = "openstreetmap"
|
||||
@@ -164,6 +175,9 @@ const (
|
||||
defaultMaxFailuresPerIpPerMinute uint32 = 5
|
||||
defaultMaxFailuresPerUserPerMinute uint32 = 5
|
||||
|
||||
defaultOAuth2StateExpiredTime uint32 = 300 // 5 minutes
|
||||
defaultOAuth2RequestTimeout uint32 = 10000 // 10 seconds
|
||||
|
||||
defaultTransactionPictureFileMaxSize uint32 = 10485760 // 10MB
|
||||
defaultUserAvatarFileMaxSize uint32 = 1048576 // 1MB
|
||||
|
||||
@@ -240,15 +254,8 @@ type LLMConfig struct {
|
||||
LargeLanguageModelAPISkipTLSVerify bool
|
||||
}
|
||||
|
||||
// TipConfig represents a tip setting config
|
||||
type TipConfig struct {
|
||||
Enabled bool
|
||||
DefaultContent string
|
||||
MultiLanguageContent map[string]string
|
||||
}
|
||||
|
||||
// NotificationConfig represents a notification setting config
|
||||
type NotificationConfig struct {
|
||||
// MultiLanguageContentConfig represents a multi-language content setting config
|
||||
type MultiLanguageContentConfig struct {
|
||||
Enabled bool
|
||||
DefaultContent string
|
||||
MultiLanguageContent map[string]string
|
||||
@@ -351,9 +358,22 @@ type Config struct {
|
||||
MaxFailuresPerUserPerMinute uint32
|
||||
|
||||
// Auth
|
||||
EnableInternalAuth bool
|
||||
EnableOAuth2Login bool
|
||||
EnableTwoFactor bool
|
||||
EnableUserForgetPassword bool
|
||||
ForgetPasswordRequireVerifyEmail bool
|
||||
OAuth2ClientID string
|
||||
OAuth2ClientSecret string
|
||||
OAuth2UserIdentifier string
|
||||
OAuth2AutoRegister bool
|
||||
OAuth2Provider string
|
||||
OAuth2StateExpiredTime uint32
|
||||
OAuth2StateExpiredTimeDuration time.Duration
|
||||
OAuth2RequestTimeout uint32
|
||||
OAuth2Proxy string
|
||||
OAuth2SkipTLSVerify bool
|
||||
OAuth2NextcloudBaseUrl string
|
||||
|
||||
// User
|
||||
EnableUserRegister bool
|
||||
@@ -372,12 +392,12 @@ type Config struct {
|
||||
MaxImportFileSize uint32
|
||||
|
||||
// Tip
|
||||
LoginPageTips TipConfig
|
||||
LoginPageTips MultiLanguageContentConfig
|
||||
|
||||
// Notification
|
||||
AfterRegisterNotification NotificationConfig
|
||||
AfterLoginNotification NotificationConfig
|
||||
AfterOpenNotification NotificationConfig
|
||||
AfterRegisterNotification MultiLanguageContentConfig
|
||||
AfterLoginNotification MultiLanguageContentConfig
|
||||
AfterOpenNotification MultiLanguageContentConfig
|
||||
|
||||
// Map
|
||||
MapProvider string
|
||||
@@ -956,9 +976,47 @@ func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName
|
||||
}
|
||||
|
||||
func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName string) error {
|
||||
config.EnableInternalAuth = getConfigItemBoolValue(configFile, sectionName, "enable_internal_auth", true)
|
||||
config.EnableOAuth2Login = getConfigItemBoolValue(configFile, sectionName, "enable_oauth2_auth", false)
|
||||
config.EnableTwoFactor = getConfigItemBoolValue(configFile, sectionName, "enable_two_factor", true)
|
||||
config.EnableUserForgetPassword = getConfigItemBoolValue(configFile, sectionName, "enable_forget_password", false)
|
||||
config.ForgetPasswordRequireVerifyEmail = getConfigItemBoolValue(configFile, sectionName, "forget_password_require_email_verify", false)
|
||||
config.OAuth2ClientID = getConfigItemStringValue(configFile, sectionName, "oauth2_client_id")
|
||||
config.OAuth2ClientSecret = getConfigItemStringValue(configFile, sectionName, "oauth2_client_secret")
|
||||
|
||||
oauth2UserIdentifier := getConfigItemStringValue(configFile, sectionName, "oauth2_user_identifier")
|
||||
|
||||
if oauth2UserIdentifier == OAuth2UserIdentifierEmail {
|
||||
config.OAuth2UserIdentifier = OAuth2UserIdentifierEmail
|
||||
} else if oauth2UserIdentifier == OAuth2UserIdentifierUsername {
|
||||
config.OAuth2UserIdentifier = OAuth2UserIdentifierUsername
|
||||
} else {
|
||||
return errs.ErrInvalidOAuth2UserIdentifier
|
||||
}
|
||||
|
||||
config.OAuth2AutoRegister = getConfigItemBoolValue(configFile, sectionName, "oauth2_auto_register", true)
|
||||
|
||||
oauth2Provider := getConfigItemStringValue(configFile, sectionName, "oauth2_provider")
|
||||
|
||||
if oauth2Provider == OAuth2ProviderNextcloud {
|
||||
config.OAuth2Provider = OAuth2ProviderNextcloud
|
||||
} else {
|
||||
return errs.ErrInvalidOAuth2Provider
|
||||
}
|
||||
|
||||
config.OAuth2StateExpiredTime = getConfigItemUint32Value(configFile, sectionName, "oauth2_state_expired_time", defaultOAuth2StateExpiredTime)
|
||||
|
||||
if config.OAuth2StateExpiredTime < 60 {
|
||||
return errs.ErrInvalidOAuth2StateExpiredTime
|
||||
}
|
||||
|
||||
config.OAuth2StateExpiredTimeDuration = time.Duration(config.OAuth2StateExpiredTime) * time.Second
|
||||
|
||||
config.OAuth2Proxy = getConfigItemStringValue(configFile, sectionName, "oauth2_proxy", "system")
|
||||
config.OAuth2RequestTimeout = getConfigItemUint32Value(configFile, sectionName, "oauth2_request_timeout", defaultOAuth2RequestTimeout)
|
||||
config.OAuth2SkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "oauth2_skip_tls_verify", false)
|
||||
|
||||
config.OAuth2NextcloudBaseUrl = getConfigItemStringValue(configFile, sectionName, "nextcloud_base_url")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -996,15 +1054,15 @@ func loadDataConfiguration(config *Config, configFile *ini.File, sectionName str
|
||||
}
|
||||
|
||||
func loadTipConfiguration(config *Config, configFile *ini.File, sectionName string) error {
|
||||
config.LoginPageTips = getTipConfiguration(configFile, sectionName, "enable_tips_in_login_page", "login_page_tips_content")
|
||||
config.LoginPageTips = getMultiLanguageContentConfig(configFile, sectionName, "enable_tips_in_login_page", "login_page_tips_content")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadNotificationConfiguration(config *Config, configFile *ini.File, sectionName string) error {
|
||||
config.AfterRegisterNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_register", "after_register_notification_content")
|
||||
config.AfterLoginNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_login", "after_login_notification_content")
|
||||
config.AfterOpenNotification = getNotificationConfiguration(configFile, sectionName, "enable_notification_after_open", "after_open_notification_content")
|
||||
config.AfterRegisterNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_register", "after_register_notification_content")
|
||||
config.AfterLoginNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_login", "after_login_notification_content")
|
||||
config.AfterOpenNotification = getMultiLanguageContentConfig(configFile, sectionName, "enable_notification_after_open", "after_open_notification_content")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1141,29 +1199,8 @@ func getFinalPath(workingPath, p string) (string, error) {
|
||||
return p, err
|
||||
}
|
||||
|
||||
func getTipConfiguration(configFile *ini.File, sectionName string, enableKey string, contentKey string) TipConfig {
|
||||
config := TipConfig{
|
||||
Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),
|
||||
DefaultContent: getConfigItemStringValue(configFile, sectionName, contentKey, ""),
|
||||
MultiLanguageContent: make(map[string]string),
|
||||
}
|
||||
|
||||
for languageTag := range locales.AllLanguages {
|
||||
multiLanguageContentKey := strings.ToLower(languageTag)
|
||||
multiLanguageContentKey = strings.Replace(multiLanguageContentKey, "-", "_", -1)
|
||||
multiLanguageContentKey = contentKey + "_" + multiLanguageContentKey
|
||||
content := getConfigItemStringValue(configFile, sectionName, multiLanguageContentKey, "")
|
||||
|
||||
if content != "" {
|
||||
config.MultiLanguageContent[languageTag] = content
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func getNotificationConfiguration(configFile *ini.File, sectionName string, enableKey string, contentKey string) NotificationConfig {
|
||||
config := NotificationConfig{
|
||||
func getMultiLanguageContentConfig(configFile *ini.File, sectionName string, enableKey string, contentKey string) MultiLanguageContentConfig {
|
||||
config := MultiLanguageContentConfig{
|
||||
Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),
|
||||
DefaultContent: getConfigItemStringValue(configFile, sectionName, contentKey, ""),
|
||||
MultiLanguageContent: make(map[string]string),
|
||||
|
||||
@@ -26,5 +26,9 @@ func (c *ConfigContainer) GetCurrentConfig() *Config {
|
||||
}
|
||||
|
||||
func GetUserAgent() string {
|
||||
if Version == "" {
|
||||
return "ezBookkeeping"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ezBookkeeping/%s", Version)
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
@@ -26,22 +24,9 @@ type WebDAVObjectStorage struct {
|
||||
// NewWebDAVObjectStorage returns a WebDAV object storage
|
||||
func NewWebDAVObjectStorage(config *settings.Config, pathPrefix string) (*WebDAVObjectStorage, error) {
|
||||
webDavConfig := config.WebDAVConfig
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
utils.SetProxyUrl(transport, webDavConfig.Proxy)
|
||||
|
||||
if webDavConfig.SkipTLSVerify {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(webDavConfig.RequestTimeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
storage := &WebDAVObjectStorage{
|
||||
httpClient: client,
|
||||
httpClient: utils.NewHttpClient(webDavConfig.RequestTimeout, webDavConfig.Proxy, webDavConfig.SkipTLSVerify, settings.GetUserAgent()),
|
||||
webDavConfig: webDavConfig,
|
||||
rootPath: webDavConfig.RootPath,
|
||||
}
|
||||
|
||||
+21
-57
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
@@ -11,6 +12,22 @@ import (
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
)
|
||||
|
||||
// GetDisplayErrorMessage returns the display error message for given error
|
||||
func GetDisplayErrorMessage(err *errs.Error) string {
|
||||
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
|
||||
var validationErrors validator.ValidationErrors
|
||||
ok := errors.As(err.BaseError[0], &validationErrors)
|
||||
|
||||
if ok {
|
||||
for _, err := range validationErrors {
|
||||
return getValidationErrorText(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// PrintJsonSuccessResult writes success response in json format to current http context
|
||||
func PrintJsonSuccessResult(c *core.WebContext, result any) {
|
||||
c.JSON(http.StatusOK, core.O{
|
||||
@@ -32,23 +49,10 @@ func PrintDataSuccessResult(c *core.WebContext, contentType string, fileName str
|
||||
func PrintJsonErrorResult(c *core.WebContext, err *errs.Error) {
|
||||
c.SetResponseError(err)
|
||||
|
||||
errorMessage := err.Error()
|
||||
|
||||
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
|
||||
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
|
||||
|
||||
if ok {
|
||||
for _, err := range validationErrors {
|
||||
errorMessage = getValidationErrorText(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := core.O{
|
||||
"success": false,
|
||||
"errorCode": err.Code(),
|
||||
"errorMessage": errorMessage,
|
||||
"errorMessage": GetDisplayErrorMessage(err),
|
||||
"path": c.Request.URL.Path,
|
||||
}
|
||||
|
||||
@@ -68,19 +72,6 @@ func PrintJSONRPCSuccessResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCR
|
||||
func PrintJSONRPCErrorResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest, err *errs.Error) {
|
||||
c.SetResponseError(err)
|
||||
|
||||
errorMessage := err.Error()
|
||||
|
||||
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
|
||||
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
|
||||
|
||||
if ok {
|
||||
for _, err := range validationErrors {
|
||||
errorMessage = getValidationErrorText(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var id any
|
||||
|
||||
if jsonRPCRequest != nil {
|
||||
@@ -97,27 +88,13 @@ func PrintJSONRPCErrorResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCReq
|
||||
jsonRPCError = core.JSONRPCInvalidParamsError
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(err.HttpStatusCode, core.NewJSONRPCErrorResponseWithCause(id, jsonRPCError, errorMessage))
|
||||
c.AbortWithStatusJSON(err.HttpStatusCode, core.NewJSONRPCErrorResponseWithCause(id, jsonRPCError, GetDisplayErrorMessage(err)))
|
||||
}
|
||||
|
||||
// PrintDataErrorResult writes error response in custom content type to current http context
|
||||
func PrintDataErrorResult(c *core.WebContext, contentType string, err *errs.Error) {
|
||||
c.SetResponseError(err)
|
||||
|
||||
errorMessage := err.Error()
|
||||
|
||||
if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 {
|
||||
validationErrors, ok := err.BaseError[0].(validator.ValidationErrors)
|
||||
|
||||
if ok {
|
||||
for _, err := range validationErrors {
|
||||
errorMessage = getValidationErrorText(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(err.HttpStatusCode, contentType, []byte(errorMessage))
|
||||
c.Data(err.HttpStatusCode, contentType, []byte(GetDisplayErrorMessage(err)))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
@@ -149,23 +126,10 @@ func WriteEventStreamJsonSuccessResult(c *core.WebContext, result any) {
|
||||
func WriteEventStreamJsonErrorResult(c *core.WebContext, originalErr *errs.Error) {
|
||||
c.SetResponseError(originalErr)
|
||||
|
||||
errorMessage := originalErr.Error()
|
||||
|
||||
if originalErr.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(originalErr.BaseError) > 0 {
|
||||
validationErrors, ok := originalErr.BaseError[0].(validator.ValidationErrors)
|
||||
|
||||
if ok {
|
||||
for _, err := range validationErrors {
|
||||
errorMessage = getValidationErrorText(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := core.O{
|
||||
"success": false,
|
||||
"errorCode": originalErr.Code(),
|
||||
"errorMessage": errorMessage,
|
||||
"errorMessage": GetDisplayErrorMessage(originalErr),
|
||||
"path": c.Request.URL.Path,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,47 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type defaultTransport struct {
|
||||
defaultUserAgent string
|
||||
baseTransport http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *defaultTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if len(req.Header.Values("User-Agent")) < 1 {
|
||||
req.Header.Set("User-Agent", t.defaultUserAgent)
|
||||
} else if req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Del("User-Agent")
|
||||
}
|
||||
|
||||
return t.baseTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewHttpClient creates and returns a new http client with specified settings
|
||||
func NewHttpClient(requestTimeout uint32, proxy string, skipTLSVerify bool, defaultUserAgent string) *http.Client {
|
||||
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
SetProxyUrl(baseTransport, proxy)
|
||||
|
||||
if skipTLSVerify {
|
||||
baseTransport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &defaultTransport{
|
||||
defaultUserAgent: defaultUserAgent,
|
||||
baseTransport: baseTransport,
|
||||
},
|
||||
Timeout: time.Duration(requestTimeout) * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// SetProxyUrl sets proxy url to http transport according to specified proxy setting
|
||||
func SetProxyUrl(transport *http.Transport, proxy string) {
|
||||
if proxy == "none" {
|
||||
|
||||
Reference in New Issue
Block a user