diff --git a/conf/ezbookkeeping.ini b/conf/ezbookkeeping.ini index b4765f69..889b19d0 100644 --- a/conf/ezbookkeeping.ini +++ b/conf/ezbookkeeping.ini @@ -285,6 +285,9 @@ enable_forget_password = true # For "internal" authentication only, set to true to require email must be verified when use forget password forget_password_require_email_verify = false +# For "oauth2" authentication only, OAuth 2.0 provider, supports "oidc", "nextcloud", "gitea" and "github" currently +oauth2_provider = + # For "oauth2" authentication only, OAuth 2.0 client ID oauth2_client_id = @@ -297,9 +300,6 @@ oauth2_user_identifier = email # For "oauth2" authentication only, if the user returned by OAuth 2.0 is not registered, automatically create a new user (requires "enable_register" to be set to true) oauth2_auto_register = true -# For "oauth2" authentication only, OAuth 2.0 provider, supports "nextcloud", "gitea" and "github" currently -oauth2_provider = - # For "oauth2" authentication only, OAuth 2.0 state expired seconds (60 - 4294967295), default is 300 (5 minutes) oauth2_state_expired_time = 300 @@ -313,6 +313,17 @@ oauth2_proxy = system # For "oauth2" authentication only, set to true to skip tls verification when request OAuth 2.0 api oauth2_skip_tls_verify = false +# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, OIDC provider base url. Make sure the ".well-known" directory is available under this path. For example, if it's set to "https://auth.example.com/", the discovery URL should be "https://auth.example.com/.well-known/openid-configuration". +oidc_provider_base_url = + +# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, set to true to replace the text in the "Log in with Connect ID" button with the below custom display name +enable_oidc_display_name = false + +# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, the custom display name to replace the text in the "Log in with Connect ID" button, it supports multi-language configuration +# Add an underscore and a language tag after the setting key to configure the display name in that language, the same below +# For example, oidc_custom_display_name_zh_hans means the display name in Chinese (Simplified) +oidc_custom_display_name = + # For "oauth2" authentication and "nextcloud" OAuth 2.0 provider only, Nextcloud base url, e.g. "https://cloud.example.org/" nextcloud_base_url = diff --git a/go.mod b/go.mod index 0cc4070c..9ffbdd06 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( github.com/boombuler/barcode v1.1.0 + github.com/coreos/go-oidc/v3 v3.16.0 github.com/extrame/xls v0.0.2-0.20200426124601-4a6cf263071b github.com/gin-contrib/cache v1.4.1 github.com/gin-contrib/gzip v1.2.3 @@ -25,6 +26,7 @@ require ( github.com/xuri/excelize/v2 v2.9.0 golang.org/x/crypto v0.41.0 golang.org/x/net v0.43.0 + golang.org/x/oauth2 v0.31.0 golang.org/x/text v0.28.0 gopkg.in/ini.v1 v1.67.0 gopkg.in/mail.v2 v2.3.1 @@ -52,6 +54,7 @@ require ( github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-ini/ini v1.67.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/goccy/go-json v0.10.5 // indirect @@ -91,7 +94,6 @@ require ( github.com/xuri/nfp v0.0.1 // indirect golang.org/x/arch v0.18.0 // indirect golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect - golang.org/x/oauth2 v0.31.0 // indirect golang.org/x/sys v0.35.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect diff --git a/go.sum b/go.sum index 655b60a7..4d55fd75 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLI github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/coreos/go-oidc/v3 v3.16.0 h1:qRQUCFstKpXwmEjDQTIbyY/5jF00+asXzSkmkoa/mow= +github.com/coreos/go-oidc/v3 v3.16.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -54,6 +56,8 @@ github.com/go-co-op/gocron/v2 v2.16.5 h1:j228Jxk7bb9CF8LKR3gS+bK3rcjRUINjlVI+ZMp github.com/go-co-op/gocron/v2 v2.16.5/go.mod h1:zAfC/GFQ668qHxOVl/D68Jh5Ce7sDqX6TJnSQyRkRBc= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= diff --git a/pkg/api/base.go b/pkg/api/base.go index 3cf49f1b..914fdc9c 100644 --- a/pkg/api/base.go +++ b/pkg/api/base.go @@ -114,6 +114,11 @@ func (a *ApiUsingDuplicateChecker) GetSubmissionRemark(checkerType duplicatechec return a.container.GetSubmissionRemark(checkerType, uid, identification) } +// SetSubmissionRemarkWithCustomExpiration saves the identification and remark by the current duplicate checker with custom expiration time +func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpiration(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) { + a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration) +} + // SetSubmissionRemarkIfEnable saves the identification and remark by the current duplicate checker if the duplicate submission check is enabled func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string) { if a.CurrentConfig().EnableDuplicateSubmissionsCheck { @@ -121,11 +126,9 @@ func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType dupli } } -// SetSubmissionRemarkWithCustomExpirationIfEnable saves the identification and remark by the current duplicate checker with custom expiration time if the duplicate submission check is enabled -func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpirationIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) { - if a.CurrentConfig().EnableDuplicateSubmissionsCheck { - a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration) - } +// RemoveSubmissionRemark removes the identification and remark by the current duplicate checker +func (a *ApiUsingDuplicateChecker) RemoveSubmissionRemark(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string) { + a.container.RemoveSubmissionRemark(checkerType, uid, identification) } // RemoveSubmissionRemarkIfEnable removes the identification and remark by the current duplicate checker if the duplicate submission check is enabled diff --git a/pkg/api/oauth2_authentications.go b/pkg/api/oauth2_authentications.go index c4f10a0e..065b045d 100644 --- a/pkg/api/oauth2_authentications.go +++ b/pkg/api/oauth2_authentications.go @@ -23,6 +23,7 @@ import ( const oauth2CallbackPageUrlSuccessFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&token=%s" const oauth2CallbackPageUrlNeedVerifyFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&userName=%s&token=%s" const oauth2CallbackPageUrlFailedFormat = "%sdesktop/#/oauth2_callback?errorCode=%d&errorMessage=%s" +const oauth2CallbackPageUrlErrorMessageFormat = "%sdesktop/#/oauth2_callback?errorMessage=%s" // OAuth2AuthenticationApi represents OAuth 2.0 authorization api type OAuth2AuthenticationApi struct { @@ -65,37 +66,31 @@ func (a *OAuth2AuthenticationApi) LoginHandler(c *core.WebContext) (string, *err return "", errs.ErrInvalidOAuth2LoginRequest } - state := fmt.Sprintf("%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId) - remark := "" + found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId) - if a.CurrentConfig().EnableDuplicateSubmissionsCheck { - found := false - found, remark = a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId) - - if found { - log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId) - return "", errs.ErrRepeatedRequest - } - - randomString, err := utils.GetRandomNumberOrLowercaseLetter(32) - - if err != nil { - log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error()) - return "", errs.ErrSystemError - } - - remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, randomString) - state = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark))) + if found { + log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId) + return "", errs.ErrRepeatedRequest } - redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state) + verifier, err := utils.GetRandomNumberOrLowercaseLetter(64) + + if err != nil { + log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error()) + return "", errs.ErrSystemError + } + + remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, verifier) + state := fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark))) + + redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state, verifier) if err != nil { log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to get oauth 2.0 auth url, because %s", err.Error()) return "", errs.Or(err, errs.ErrSystemError) } - a.SetSubmissionRemarkWithCustomExpirationIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration) + a.SetSubmissionRemarkWithCustomExpiration(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration) return redirectUrl, nil } @@ -115,6 +110,11 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, * } if oauth2CallbackReq.Code == "" { + if oauth2CallbackReq.ErrorDescription != "" { + log.Errorf(c, "[oauth2_authentications.CallbackHandler] oauth 2.0 provider returned error: %s, description: %s", oauth2CallbackReq.Error, oauth2CallbackReq.ErrorDescription) + return a.redirectToErrorMessageCallbackPage(c, oauth2CallbackReq.ErrorDescription) + } + return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2Code) } @@ -134,33 +134,32 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, * return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2LoginRequest) } - if a.CurrentConfig().EnableDuplicateSubmissionsCheck { - found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId) + found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId) - if !found { - log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId) - return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback) - } - - remarkParts := strings.Split(remark, "|") - - if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId { - log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId) - return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State) - } - - expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, remarkParts[2]) - expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState))) - - if oauth2CallbackReq.State != expectedState { - log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State) - return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State) - } - - a.RemoveSubmissionRemarkIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId) + if !found { + log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId) + return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback) } - oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code) + remarkParts := strings.Split(remark, "|") + + if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId { + log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId) + return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State) + } + + verifier := remarkParts[2] + expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, verifier) + expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState))) + + if oauth2CallbackReq.State != expectedState { + log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State) + return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State) + } + + a.RemoveSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId) + + oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code, verifier) if err != nil { log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 token, because %s", err.Error()) @@ -347,3 +346,7 @@ func (a *OAuth2AuthenticationApi) redirectToVerifyCallbackPage(c *core.WebContex func (a *OAuth2AuthenticationApi) redirectToFailedCallbackPage(c *core.WebContext, err *errs.Error) (string, *errs.Error) { return fmt.Sprintf(oauth2CallbackPageUrlFailedFormat, a.CurrentConfig().RootUrl, err.Code(), url.QueryEscape(utils.GetDisplayErrorMessage(err))), nil } + +func (a *OAuth2AuthenticationApi) redirectToErrorMessageCallbackPage(c *core.WebContext, message string) (string, *errs.Error) { + return fmt.Sprintf(oauth2CallbackPageUrlErrorMessageFormat, a.CurrentConfig().RootUrl, url.QueryEscape(message)), nil +} diff --git a/pkg/api/server_settings.go b/pkg/api/server_settings.go index 9722d27d..60e76cbb 100644 --- a/pkg/api/server_settings.go +++ b/pkg/api/server_settings.go @@ -47,6 +47,10 @@ func (a *ServerSettingsApi) ServerSettingsJavascriptHandler(c *core.WebContext) a.appendStringSetting(builder, "op", config.OAuth2Provider) + if config.OAuth2Provider == settings.OAuth2ProviderOIDC && config.OAuth2OIDCCustomDisplayNameConfig.Enabled { + a.appendMultiLanguageTipSetting(builder, "ocn", config.OAuth2OIDCCustomDisplayNameConfig) + } + if config.EnableMCPServer { a.appendBooleanSetting(builder, "mcp", config.EnableMCPServer) } diff --git a/pkg/auth/oauth2/oauth2_user_info.go b/pkg/auth/oauth2/data/oauth2_user_info.go similarity index 95% rename from pkg/auth/oauth2/oauth2_user_info.go rename to pkg/auth/oauth2/data/oauth2_user_info.go index 450dbad1..b2bb15d7 100644 --- a/pkg/auth/oauth2/oauth2_user_info.go +++ b/pkg/auth/oauth2/data/oauth2_user_info.go @@ -1,4 +1,4 @@ -package oauth2 +package data import "github.com/mayswind/ezbookkeeping/pkg/core" diff --git a/pkg/auth/oauth2/oauth2_authentication.go b/pkg/auth/oauth2/oauth2_authentication.go index 94457b75..94fd4e6b 100644 --- a/pkg/auth/oauth2/oauth2_authentication.go +++ b/pkg/auth/oauth2/oauth2_authentication.go @@ -1,10 +1,18 @@ package oauth2 import ( + "crypto/sha256" + "encoding/base64" "net/http" "golang.org/x/oauth2" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/gitea" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/github" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/nextcloud" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/oidc" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/settings" @@ -13,8 +21,7 @@ import ( // OAuth2Container contains the current OAuth 2.0 authentication provider type OAuth2Container struct { - oauth2Config *oauth2.Config - oauth2Provider OAuth2Provider + current provider.OAuth2Provider oauth2HttpClient *http.Client externalUserAuthType core.UserExternalAuthType } @@ -34,24 +41,32 @@ func InitializeOAuth2Provider(config *settings.Config) error { return errs.ErrInvalidOAuth2Config } - var oauth2Provider OAuth2Provider + var err error + var oauth2Provider provider.OAuth2Provider var externalUserAuthType core.UserExternalAuthType + redirectUrl := config.RootUrl + "oauth2/callback" - if config.OAuth2Provider == settings.OAuth2ProviderNextcloud { - oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl) + if config.OAuth2Provider == settings.OAuth2ProviderOIDC { + oauth2Provider, err = oidc.NewOIDCProvider(config, redirectUrl) + externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC + } else if config.OAuth2Provider == settings.OAuth2ProviderNextcloud { + oauth2Provider, err = nextcloud.NewNextcloudOAuth2Provider(config, redirectUrl) externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD } else if config.OAuth2Provider == settings.OAuth2ProviderGitea { - oauth2Provider = NewGiteaOAuth2Provider(config.OAuth2GiteaBaseUrl) + oauth2Provider, err = gitea.NewGiteaOAuth2Provider(config, redirectUrl) externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA } else if config.OAuth2Provider == settings.OAuth2ProviderGithub { - oauth2Provider = NewGithubOAuth2Provider() + oauth2Provider, err = github.NewGithubOAuth2Provider(config, redirectUrl) externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB } else { return errs.ErrInvalidOAuth2Provider } - Container.oauth2Config = buildOAuth2Config(config, oauth2Provider) - Container.oauth2Provider = oauth2Provider + if err != nil { + return err + } + + Container.current = oauth2Provider Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent()) Container.externalUserAuthType = externalUserAuthType @@ -59,26 +74,29 @@ func InitializeOAuth2Provider(config *settings.Config) error { } // GetOAuth2AuthUrl returns the OAuth 2.0 authentication url -func GetOAuth2AuthUrl(c core.Context, state string) (string, error) { - if Container.oauth2Config == nil { +func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, error) { + if Container.current == nil { return "", errs.ErrOAuth2NotEnabled } - return Container.oauth2Config.AuthCodeURL(state), nil + sha256Hash := sha256.New() + sha256Hash.Write([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil)) + return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge) } // GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token -func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) { - if Container.oauth2Config == nil || Container.oauth2HttpClient == nil { +func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) { + if Container.current == nil || Container.oauth2HttpClient == nil { return nil, errs.ErrOAuth2NotEnabled } - return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code) + return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier) } // GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token -func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) { - if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil { +func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*data.OAuth2UserInfo, error) { + if Container.current == nil || Container.oauth2HttpClient == nil { return nil, errs.ErrOAuth2NotEnabled } @@ -86,26 +104,10 @@ func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, er return nil, errs.ErrInvalidOAuth2Token } - oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token)) - return Container.oauth2Provider.GetUserInfo(c, oauth2Client) + return Container.current.GetUserInfo(wrapOAuth2Context(c, Container.oauth2HttpClient), token) } // GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider func GetExternalUserAuthType() core.UserExternalAuthType { return Container.externalUserAuthType } - -func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config { - redirectURL := config.RootUrl + "oauth2/callback" - - return &oauth2.Config{ - ClientID: config.OAuth2ClientID, - ClientSecret: config.OAuth2ClientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: oauth2Provider.GetAuthUrl(), - TokenURL: oauth2Provider.GetTokenUrl(), - }, - RedirectURL: redirectURL, - Scopes: oauth2Provider.GetScopes(), - } -} diff --git a/pkg/auth/oauth2/oauth2_provider.go b/pkg/auth/oauth2/oauth2_provider.go deleted file mode 100644 index 55e72dc2..00000000 --- a/pkg/auth/oauth2/oauth2_provider.go +++ /dev/null @@ -1,22 +0,0 @@ -package oauth2 - -import ( - "net/http" - - "github.com/mayswind/ezbookkeeping/pkg/core" -) - -// OAuth2Provider defines the structure of OAuth 2.0 provider -type OAuth2Provider interface { - // GetAuthUrl returns the authentication url of the provider - GetAuthUrl() string - - // GetTokenUrl returns the token url of the provider - GetTokenUrl() string - - // GetUserInfo returns the user info - GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) - - // GetScopes returns the scopes required by the provider - GetScopes() []string -} diff --git a/pkg/auth/oauth2/common_oauth2_provider.go b/pkg/auth/oauth2/provider/common/common_oauth2_provider.go similarity index 51% rename from pkg/auth/oauth2/common_oauth2_provider.go rename to pkg/auth/oauth2/provider/common/common_oauth2_provider.go index 3fbc1136..aaf331ff 100644 --- a/pkg/auth/oauth2/common_oauth2_provider.go +++ b/pkg/auth/oauth2/provider/common/common_oauth2_provider.go @@ -1,18 +1,24 @@ -package oauth2 +package common import ( "io" "net/http" + "golang.org/x/oauth2" + + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) // CommonOAuth2Provider represents common OAuth 2.0 provider type CommonOAuth2Provider struct { - OAuth2Provider - dataSource CommonOAuth2DataSource + provider.OAuth2Provider + oauth2Config *oauth2.Config + dataSource CommonOAuth2DataSource } // CommonOAuth2DataSource defines the structure of OAuth 2.0 data source @@ -30,26 +36,21 @@ type CommonOAuth2DataSource interface { GetScopes() []string // ParseUserInfo returns the user info by parsing the response body - ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) + ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) } -// GetAuthUrl returns the authentication url of the common OAuth 2.0 provider -func (p *CommonOAuth2Provider) GetAuthUrl() string { - return p.dataSource.GetAuthUrl() +// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider +func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) { + return p.oauth2Config.AuthCodeURL(state), nil } -// GetTokenUrl returns the token url of the common OAuth 2.0 provider -func (p *CommonOAuth2Provider) GetTokenUrl() string { - return p.dataSource.GetTokenUrl() -} - -// GetScopes returns the scopes required by the common OAuth 2.0 provider -func (p *CommonOAuth2Provider) GetScopes() []string { - return p.dataSource.GetScopes() +// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider +func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) { + return p.oauth2Config.Exchange(c, code) } // GetUserInfo returns the user info by the common OAuth 2.0 provider -func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) { +func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) { req, err := p.dataSource.GetUserInfoRequest() if err != nil { @@ -57,6 +58,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl return nil, errs.ErrFailedToRequestRemoteApi } + oauth2Client := oauth2.NewClient(c, oauth2.StaticTokenSource(oauth2Token)) resp, err := oauth2Client.Do(req) if err != nil { @@ -76,3 +78,27 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl return p.dataSource.ParseUserInfo(c, body) } + +// GetDataSource returns the data source of the common OAuth 2.0 provider +func (p *CommonOAuth2Provider) GetDataSource() CommonOAuth2DataSource { + return p.dataSource +} + +// NewCommonOAuth2Provider returns a new common OAuth 2.0 provider +func NewCommonOAuth2Provider(config *settings.Config, redirectUrl string, dataSource CommonOAuth2DataSource) *CommonOAuth2Provider { + oauth2Config := &oauth2.Config{ + ClientID: config.OAuth2ClientID, + ClientSecret: config.OAuth2ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: dataSource.GetAuthUrl(), + TokenURL: dataSource.GetTokenUrl(), + }, + RedirectURL: redirectUrl, + Scopes: dataSource.GetScopes(), + } + + return &CommonOAuth2Provider{ + oauth2Config: oauth2Config, + dataSource: dataSource, + } +} diff --git a/pkg/auth/oauth2/gitea_oauth2_datasource.go b/pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource.go similarity index 76% rename from pkg/auth/oauth2/gitea_oauth2_datasource.go rename to pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource.go index 256eb708..42268908 100644 --- a/pkg/auth/oauth2/gitea_oauth2_datasource.go +++ b/pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource.go @@ -1,12 +1,16 @@ -package oauth2 +package gitea import ( "encoding/json" "net/http" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) type giteaUserInfoResponse struct { @@ -17,7 +21,7 @@ type giteaUserInfoResponse struct { // GiteaOAuth2DataSource represents Gitea OAuth 2.0 data source type GiteaOAuth2DataSource struct { - CommonOAuth2DataSource + common.CommonOAuth2DataSource baseUrl string } @@ -52,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string { } // ParseUserInfo returns the user info by parsing the response body -func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) { +func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) { userInfoResp := &giteaUserInfoResponse{} err := json.Unmarshal(body, &userInfoResp) @@ -66,7 +70,7 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu return nil, errs.ErrCannotRetrieveUserInfo } - return &OAuth2UserInfo{ + return &data.OAuth2UserInfo{ UserName: userInfoResp.Login, Email: userInfoResp.Email, NickName: userInfoResp.FullName, @@ -74,14 +78,18 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu } // NewGiteaOAuth2Provider creates a new Gitea OAuth 2.0 provider instance -func NewGiteaOAuth2Provider(baseUrl string) OAuth2Provider { +func NewGiteaOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) { + if len(config.OAuth2GiteaBaseUrl) < 1 { + return nil, errs.ErrInvalidOAuth2Config + } + + baseUrl := config.OAuth2GiteaBaseUrl + if baseUrl[len(baseUrl)-1] != '/' { baseUrl += "/" } - return &CommonOAuth2Provider{ - dataSource: &GiteaOAuth2DataSource{ - baseUrl: baseUrl, - }, - } + return common.NewCommonOAuth2Provider(config, redirectUrl, &GiteaOAuth2DataSource{ + baseUrl: baseUrl, + }), nil } diff --git a/pkg/auth/oauth2/gitea_oauth2_datasource_test.go b/pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource_test.go similarity index 64% rename from pkg/auth/oauth2/gitea_oauth2_datasource_test.go rename to pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource_test.go index b45fc3f0..e46060a9 100644 --- a/pkg/auth/oauth2/gitea_oauth2_datasource_test.go +++ b/pkg/auth/oauth2/provider/gitea/gitea_oauth2_datasource_test.go @@ -1,22 +1,33 @@ -package oauth2 +package gitea import ( "testing" "github.com/stretchr/testify/assert" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) func TestNewGiteaOAuth2Provider(t *testing.T) { - datasource := NewGiteaOAuth2Provider("https://example.com/") - assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl()) - assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl()) + provider, err := NewGiteaOAuth2Provider(&settings.Config{ + OAuth2GiteaBaseUrl: "https://example.com/", + }, "") + assert.Nil(t, err) + assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl()) + assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl()) - datasource = NewGiteaOAuth2Provider("https://example.com") - assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl()) - assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl()) + provider, err = NewGiteaOAuth2Provider(&settings.Config{ + OAuth2GiteaBaseUrl: "https://example.com", + }, "") + assert.Nil(t, err) + assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl()) + assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl()) + + provider, err = NewGiteaOAuth2Provider(&settings.Config{}, "") + assert.Equal(t, errs.ErrInvalidOAuth2Config, err) } func TestGiteaOAuth2Datasource_GetUserInfoRequest(t *testing.T) { diff --git a/pkg/auth/oauth2/github_oauth2_datasource.go b/pkg/auth/oauth2/provider/github/github_oauth2_datasource.go similarity index 80% rename from pkg/auth/oauth2/github_oauth2_datasource.go rename to pkg/auth/oauth2/provider/github/github_oauth2_datasource.go index dce88418..d16e8026 100644 --- a/pkg/auth/oauth2/github_oauth2_datasource.go +++ b/pkg/auth/oauth2/provider/github/github_oauth2_datasource.go @@ -1,12 +1,16 @@ -package oauth2 +package github import ( "encoding/json" "net/http" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) type githubUserProfileResponse struct { @@ -17,7 +21,7 @@ type githubUserProfileResponse struct { // GithubOAuth2DataSource represents Github OAuth 2.0 data source type GithubOAuth2DataSource struct { - CommonOAuth2DataSource + common.CommonOAuth2DataSource } // GetAuthUrl returns the authentication url of the Github data source @@ -51,7 +55,7 @@ func (p *GithubOAuth2DataSource) GetScopes() []string { } // ParseUserInfo returns the user info by parsing the response body -func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) { +func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) { userInfoResp := &githubUserProfileResponse{} err := json.Unmarshal(body, &userInfoResp) @@ -65,7 +69,7 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA return nil, errs.ErrCannotRetrieveUserInfo } - return &OAuth2UserInfo{ + return &data.OAuth2UserInfo{ UserName: userInfoResp.Login, Email: userInfoResp.Email, NickName: userInfoResp.Name, @@ -73,8 +77,6 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA } // NewGithubOAuth2Provider creates a new Github OAuth 2.0 provider instance -func NewGithubOAuth2Provider() OAuth2Provider { - return &CommonOAuth2Provider{ - dataSource: &GithubOAuth2DataSource{}, - } +func NewGithubOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) { + return common.NewCommonOAuth2Provider(config, redirectUrl, &GithubOAuth2DataSource{}), nil } diff --git a/pkg/auth/oauth2/github_oauth2_datasource_test.go b/pkg/auth/oauth2/provider/github/github_oauth2_datasource_test.go similarity index 99% rename from pkg/auth/oauth2/github_oauth2_datasource_test.go rename to pkg/auth/oauth2/provider/github/github_oauth2_datasource_test.go index 7192ad66..6b962b79 100644 --- a/pkg/auth/oauth2/github_oauth2_datasource_test.go +++ b/pkg/auth/oauth2/provider/github/github_oauth2_datasource_test.go @@ -1,4 +1,4 @@ -package oauth2 +package github import ( "testing" diff --git a/pkg/auth/oauth2/nextcloud_oauth2_datasource.go b/pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource.go similarity index 81% rename from pkg/auth/oauth2/nextcloud_oauth2_datasource.go rename to pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource.go index 73110129..6924835b 100644 --- a/pkg/auth/oauth2/nextcloud_oauth2_datasource.go +++ b/pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource.go @@ -1,12 +1,16 @@ -package oauth2 +package nextcloud import ( "encoding/json" "net/http" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) type nextcloudUserInfoResponse struct { @@ -25,7 +29,7 @@ type nextcloudUserInfoResponse struct { // NextcloudOAuth2DataSource represents Nextcloud OAuth 2.0 data source type NextcloudOAuth2DataSource struct { - CommonOAuth2DataSource + common.CommonOAuth2DataSource baseUrl string } @@ -61,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string { } // ParseUserInfo returns the user info by parsing the response body -func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) { +func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) { userInfoResp := &nextcloudUserInfoResponse{} err := json.Unmarshal(body, &userInfoResp) @@ -85,7 +89,7 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) ( return nil, errs.ErrCannotRetrieveUserInfo } - return &OAuth2UserInfo{ + return &data.OAuth2UserInfo{ UserName: userInfoResp.OCS.Data.ID, Email: userInfoResp.OCS.Data.Email, NickName: userInfoResp.OCS.Data.DisplayName, @@ -93,14 +97,18 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) ( } // NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance -func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider { +func NewNextcloudOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) { + if len(config.OAuth2NextcloudBaseUrl) < 1 { + return nil, errs.ErrInvalidOAuth2Config + } + + baseUrl := config.OAuth2NextcloudBaseUrl + if baseUrl[len(baseUrl)-1] != '/' { baseUrl += "/" } - return &CommonOAuth2Provider{ - dataSource: &NextcloudOAuth2DataSource{ - baseUrl: baseUrl, - }, - } + return common.NewCommonOAuth2Provider(config, redirectUrl, &NextcloudOAuth2DataSource{ + baseUrl: baseUrl, + }), nil } diff --git a/pkg/auth/oauth2/nextcloud_oauth2_datasource_test.go b/pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource_test.go similarity index 75% rename from pkg/auth/oauth2/nextcloud_oauth2_datasource_test.go rename to pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource_test.go index 84112f10..17ae6485 100644 --- a/pkg/auth/oauth2/nextcloud_oauth2_datasource_test.go +++ b/pkg/auth/oauth2/provider/nextcloud/nextcloud_oauth2_datasource_test.go @@ -1,22 +1,33 @@ -package oauth2 +package nextcloud import ( "testing" "github.com/stretchr/testify/assert" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) func TestNewNextcloudOAuth2Provider(t *testing.T) { - datasource := NewNextcloudOAuth2Provider("https://example.com/") - assert.Equal(t, "https://example.com/apps/oauth2/authorize", datasource.GetAuthUrl()) - assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", datasource.GetTokenUrl()) + provider, err := NewNextcloudOAuth2Provider(&settings.Config{ + OAuth2NextcloudBaseUrl: "https://example.com/", + }, "") + assert.Nil(t, err) + assert.Equal(t, "https://example.com/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl()) + assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl()) - datasource = NewNextcloudOAuth2Provider("https://example.com/index.php") - assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", datasource.GetAuthUrl()) - assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", datasource.GetTokenUrl()) + provider, err = NewNextcloudOAuth2Provider(&settings.Config{ + OAuth2NextcloudBaseUrl: "https://example.com/index.php", + }, "") + assert.Nil(t, err) + assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl()) + assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl()) + + provider, err = NewNextcloudOAuth2Provider(&settings.Config{}, "") + assert.Equal(t, errs.ErrInvalidOAuth2Config, err) } func TestNextcloudOAuth2Datasource_GetUserInfoRequest(t *testing.T) { diff --git a/pkg/auth/oauth2/provider/oauth2_provider.go b/pkg/auth/oauth2/provider/oauth2_provider.go new file mode 100644 index 00000000..85e470a4 --- /dev/null +++ b/pkg/auth/oauth2/provider/oauth2_provider.go @@ -0,0 +1,20 @@ +package provider + +import ( + "golang.org/x/oauth2" + + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/core" +) + +// OAuth2Provider defines the structure of OAuth 2.0 provider +type OAuth2Provider interface { + // GetOAuth2AuthUrl returns the authentication url of the provider + GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) + + // GetOAuth2Token returns the OAuth 2.0 token of the provider + GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) + + // GetUserInfo returns the user info + GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) +} diff --git a/pkg/auth/oauth2/provider/oidc/oidc_provider.go b/pkg/auth/oauth2/provider/oidc/oidc_provider.go new file mode 100644 index 00000000..4f841a49 --- /dev/null +++ b/pkg/auth/oauth2/provider/oidc/oidc_provider.go @@ -0,0 +1,173 @@ +package oidc + +import ( + "strings" + + "golang.org/x/oauth2" + + "github.com/coreos/go-oidc/v3/oidc" + + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data" + "github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider" + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" +) + +// OIDCClaims represents OIDC claims +type OIDCClaims struct { + PreferredUserName string `json:"preferred_username"` + Name string `json:"name"` + Email string `json:"email"` +} + +// OIDCProvider represents OIDC provider +type OIDCProvider struct { + provider.OAuth2Provider + oidcBaseUrl string + redirectUrl string + oauth2ClientID string + oauth2ClientSecret string + oauth2Config *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier +} + +// GetOAuth2AuthUrl returns the authentication url of the OIDC provider +func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) { + oauth2Config, err := p.getOAuth2Config(c) + + if err != nil { + return "", err + } + + return oauth2Config.AuthCodeURL(state, + oauth2.SetAuthURLParam("code_challenge", challenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil +} + +// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider +func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) { + oauth2Config, err := p.getOAuth2Config(c) + + if err != nil { + return nil, err + } + + return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier)) +} + +// GetUserInfo returns the user info by the OIDC provider +func (p *OIDCProvider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) { + _, err := p.getOAuth2Config(c) + + if err != nil { + return nil, err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + + if !ok { + log.Errorf(c, "[oidc_provider.GetUserInfo] missing \"id_token\" field in oauth 2.0 token") + return nil, errs.ErrInvalidOAuth2Token + } + + idToken, err := p.oidcVerifier.Verify(c, rawIDToken) + + if err != nil { + log.Errorf(c, "[oidc_provider.GetUserInfo] failed to verify \"id_token\" field in oauth 2.0 token, because %s", err.Error()) + return nil, errs.ErrInvalidOAuth2Token + } + + var claims OIDCClaims + err = idToken.Claims(&claims) + + if err != nil { + log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse claims in oauth 2.0 token, because %s", err.Error()) + return nil, errs.ErrInvalidOAuth2Token + } + + userName := claims.PreferredUserName + email := claims.Email + nickName := claims.Name + + if userName == "" || email == "" || nickName == "" { + userInfo, err := p.oidcProvider.UserInfo(c, oauth2.StaticTokenSource(oauth2Token)) + + if err != nil { + log.Errorf(c, "[oidc_provider.GetUserInfo] failed to get user info, because %s", err.Error()) + return nil, errs.ErrCannotRetrieveUserInfo + } + + err = userInfo.Claims(&claims) + + if err != nil { + log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse user info, because %s", err.Error()) + return nil, errs.ErrCannotRetrieveUserInfo + } + + if userName == "" { + userName = claims.PreferredUserName + } + + if email == "" { + email = claims.Email + } + + if nickName == "" { + nickName = claims.Name + } + } + + return &data.OAuth2UserInfo{ + UserName: userName, + Email: email, + NickName: nickName, + }, nil +} + +func (p *OIDCProvider) getOAuth2Config(c core.Context) (*oauth2.Config, error) { + if p.oauth2Config != nil { + return p.oauth2Config, nil + } + + oidcProvider, err := oidc.NewProvider(c, p.oidcBaseUrl) + + if err != nil { + log.Errorf(c, "[oidc_provider.getOAuth2Config] failed to create oidc provider, because %s", err.Error()) + return nil, err + } + + oidcVerifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.oauth2ClientID}) + + oauth2Config := &oauth2.Config{ + ClientID: p.oauth2ClientID, + ClientSecret: p.oauth2ClientSecret, + Endpoint: oidcProvider.Endpoint(), + RedirectURL: p.redirectUrl, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + + p.oauth2Config = oauth2Config + p.oidcProvider = oidcProvider + p.oidcVerifier = oidcVerifier + return oauth2Config, nil +} + +// NewOIDCProvider returns a new OIDC provider +func NewOIDCProvider(config *settings.Config, redirectUrl string) (*OIDCProvider, error) { + if len(config.OAuth2OIDCProviderBaseUrl) < 1 { + return nil, errs.ErrInvalidOAuth2Config + } + + baseUrl := strings.TrimSuffix(config.OAuth2OIDCProviderBaseUrl, "/") + + return &OIDCProvider{ + oidcBaseUrl: baseUrl, + redirectUrl: redirectUrl, + oauth2ClientID: config.OAuth2ClientID, + oauth2ClientSecret: config.OAuth2ClientSecret, + oauth2Config: nil, + }, nil +} diff --git a/pkg/core/user_external_auth_type.go b/pkg/core/user_external_auth_type.go index 76500f49..9c7057d9 100644 --- a/pkg/core/user_external_auth_type.go +++ b/pkg/core/user_external_auth_type.go @@ -5,6 +5,7 @@ type UserExternalAuthType string // User External Auth Type const ( + USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC UserExternalAuthType = "oidc" USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD UserExternalAuthType = "nextcloud" USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA UserExternalAuthType = "gitea" USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB UserExternalAuthType = "github" @@ -13,7 +14,8 @@ const ( // IsValid checks if the UserExternalAuthType is valid func (t UserExternalAuthType) IsValid() bool { switch t { - case USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD, + case USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC, + USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD, USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA, USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB: return true diff --git a/pkg/models/oauth2.go b/pkg/models/oauth2.go index ef2c1173..6954bc36 100644 --- a/pkg/models/oauth2.go +++ b/pkg/models/oauth2.go @@ -8,8 +8,10 @@ type OAuth2LoginRequest struct { // OAuth2CallbackRequest represents all parameters of OAuth 2.0 callback request type OAuth2CallbackRequest struct { - State string `form:"state"` - Code string `form:"code"` + State string `form:"state"` + Code string `form:"code"` + Error string `form:"error"` + ErrorDescription string `form:"error_description"` } // OAuth2CallbackLoginRequest represents all parameters of OAuth 2.0 callback login request diff --git a/pkg/settings/setting.go b/pkg/settings/setting.go index fc42fe70..2c6f873a 100644 --- a/pkg/settings/setting.go +++ b/pkg/settings/setting.go @@ -93,6 +93,7 @@ const ( // OAuth 2.0 provider types const ( + OAuth2ProviderOIDC string = "oidc" OAuth2ProviderNextcloud string = "nextcloud" OAuth2ProviderGitea string = "gitea" OAuth2ProviderGithub string = "github" @@ -360,23 +361,25 @@ type Config struct { MaxFailuresPerUserPerMinute uint32 // Auth - EnableInternalAuth bool - EnableOAuth2Login bool - EnableTwoFactor bool - EnableUserForgetPassword bool - ForgetPasswordRequireVerifyEmail bool - OAuth2ClientID string - OAuth2ClientSecret string - OAuth2UserIdentifier string - OAuth2AutoRegister bool - OAuth2Provider string - OAuth2StateExpiredTime uint32 - OAuth2StateExpiredTimeDuration time.Duration - OAuth2RequestTimeout uint32 - OAuth2Proxy string - OAuth2SkipTLSVerify bool - OAuth2NextcloudBaseUrl string - OAuth2GiteaBaseUrl string + EnableInternalAuth bool + EnableOAuth2Login bool + EnableTwoFactor bool + EnableUserForgetPassword bool + ForgetPasswordRequireVerifyEmail bool + OAuth2ClientID string + OAuth2ClientSecret string + OAuth2UserIdentifier string + OAuth2AutoRegister bool + OAuth2Provider string + OAuth2StateExpiredTime uint32 + OAuth2StateExpiredTimeDuration time.Duration + OAuth2RequestTimeout uint32 + OAuth2Proxy string + OAuth2SkipTLSVerify bool + OAuth2OIDCProviderBaseUrl string + OAuth2OIDCCustomDisplayNameConfig MultiLanguageContentConfig + OAuth2NextcloudBaseUrl string + OAuth2GiteaBaseUrl string // User EnableUserRegister bool @@ -1003,6 +1006,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str if oauth2Provider == "" { config.OAuth2Provider = "" + } else if oauth2Provider == OAuth2ProviderOIDC { + config.OAuth2Provider = OAuth2ProviderOIDC } else if oauth2Provider == OAuth2ProviderNextcloud { config.OAuth2Provider = OAuth2ProviderNextcloud } else if oauth2Provider == OAuth2ProviderGitea { @@ -1025,6 +1030,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str config.OAuth2RequestTimeout = getConfigItemUint32Value(configFile, sectionName, "oauth2_request_timeout", defaultOAuth2RequestTimeout) config.OAuth2SkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "oauth2_skip_tls_verify", false) + config.OAuth2OIDCProviderBaseUrl = getConfigItemStringValue(configFile, sectionName, "oidc_provider_base_url") + config.OAuth2OIDCCustomDisplayNameConfig = getMultiLanguageContentConfig(configFile, sectionName, "enable_oidc_display_name", "oidc_custom_display_name") config.OAuth2NextcloudBaseUrl = getConfigItemStringValue(configFile, sectionName, "nextcloud_base_url") config.OAuth2GiteaBaseUrl = getConfigItemStringValue(configFile, sectionName, "gitea_base_url") diff --git a/src/core/api.ts b/src/core/api.ts index 81c6727c..9087f721 100644 --- a/src/core/api.ts +++ b/src/core/api.ts @@ -9,3 +9,14 @@ export interface ErrorResponse { readonly errorMessage: string; readonly path: string; } + +export function buildErrorResponse(errorCode: number, errorMessage: string): ErrorResponse { + const errorResponse: ErrorResponse = { + success: false, + errorCode: errorCode, + errorMessage: errorMessage, + path: '' + }; + + return errorResponse; +} diff --git a/src/views/desktop/OAuth2CallbackPage.vue b/src/views/desktop/OAuth2CallbackPage.vue index 6b9bba50..cef380e3 100644 --- a/src/views/desktop/OAuth2CallbackPage.vue +++ b/src/views/desktop/OAuth2CallbackPage.vue @@ -24,9 +24,10 @@

{{ oauth2LoginDisplayName }}

-

{{ tt('Logging in...') }}

-

{{ tt('format.misc.oauth2bindTip', { providerName: oauth2ProviderDisplayName, userName: userName }) }}

+

{{ tt('Logging in...') }}

+

{{ tt('format.misc.oauth2bindTip', { providerName: oauth2ProviderDisplayName, userName: userName }) }}

{{ te({ error }) }}

+

{{ errorMessage }}

{{ tt('An error occurred') }}

@@ -106,7 +107,7 @@ import { useLoginPageBase } from '@/views/base/LoginPageBase.ts'; import { useRootStore } from '@/stores/index.ts'; import { ThemeType } from '@/core/theme.ts'; -import { type ErrorResponse } from '@/core/api.ts'; +import { type ErrorResponse, buildErrorResponse } from '@/core/api.ts'; import { APPLICATION_LOGO_PATH } from '@/consts/asset.ts'; import { KnownErrorCode } from '@/consts/api.ts'; @@ -158,14 +159,7 @@ const oauth2LoginDisplayName = computed(() => getLocalizedOAuth2LoginTex const error = computed(() => { if (props.errorCode && props.errorMessage) { - const errorResponse: ErrorResponse = { - success: false, - errorCode: parseInt(props.errorCode), - errorMessage: props.errorMessage, - path: '' - }; - - return errorResponse; + return buildErrorResponse(parseInt(props.errorCode), props.errorMessage); } else { return undefined; } diff --git a/third-party-dependencies.json b/third-party-dependencies.json index 2f3f4ec9..2f249a2a 100644 --- a/third-party-dependencies.json +++ b/third-party-dependencies.json @@ -105,6 +105,12 @@ "url": "https://golang.org/x/oauth2", "licenseUrl": "https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.31.0:LICENSE" }, + { + "name": "go-oidc", + "copyright": "Copyright 2014 CoreOS, Inc", + "url": "https://github.com/coreos/go-oidc", + "licenseUrl": "https://github.com/coreos/go-oidc/blob/v3.16.0/LICENSE" + }, { "name": "Gomail", "copyright": "Copyright (c) 2014 Alexandre Cesaro",