support OIDC authentication (#242)
This commit is contained in:
+14
-3
@@ -285,6 +285,9 @@ enable_forget_password = true
|
||||
# For "internal" authentication only, set to true to require email must be verified when use forget password
|
||||
forget_password_require_email_verify = false
|
||||
|
||||
# For "oauth2" authentication only, OAuth 2.0 provider, supports "oidc", "nextcloud", "gitea" and "github" currently
|
||||
oauth2_provider =
|
||||
|
||||
# For "oauth2" authentication only, OAuth 2.0 client ID
|
||||
oauth2_client_id =
|
||||
|
||||
@@ -297,9 +300,6 @@ oauth2_user_identifier = email
|
||||
# For "oauth2" authentication only, if the user returned by OAuth 2.0 is not registered, automatically create a new user (requires "enable_register" to be set to true)
|
||||
oauth2_auto_register = true
|
||||
|
||||
# For "oauth2" authentication only, OAuth 2.0 provider, supports "nextcloud", "gitea" and "github" currently
|
||||
oauth2_provider =
|
||||
|
||||
# For "oauth2" authentication only, OAuth 2.0 state expired seconds (60 - 4294967295), default is 300 (5 minutes)
|
||||
oauth2_state_expired_time = 300
|
||||
|
||||
@@ -313,6 +313,17 @@ oauth2_proxy = system
|
||||
# For "oauth2" authentication only, set to true to skip tls verification when request OAuth 2.0 api
|
||||
oauth2_skip_tls_verify = false
|
||||
|
||||
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, OIDC provider base url. Make sure the ".well-known" directory is available under this path. For example, if it's set to "https://auth.example.com/", the discovery URL should be "https://auth.example.com/.well-known/openid-configuration".
|
||||
oidc_provider_base_url =
|
||||
|
||||
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, set to true to replace the text in the "Log in with Connect ID" button with the below custom display name
|
||||
enable_oidc_display_name = false
|
||||
|
||||
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, the custom display name to replace the text in the "Log in with Connect ID" button, it supports multi-language configuration
|
||||
# Add an underscore and a language tag after the setting key to configure the display name in that language, the same below
|
||||
# For example, oidc_custom_display_name_zh_hans means the display name in Chinese (Simplified)
|
||||
oidc_custom_display_name =
|
||||
|
||||
# For "oauth2" authentication and "nextcloud" OAuth 2.0 provider only, Nextcloud base url, e.g. "https://cloud.example.org/"
|
||||
nextcloud_base_url =
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ go 1.25
|
||||
|
||||
require (
|
||||
github.com/boombuler/barcode v1.1.0
|
||||
github.com/coreos/go-oidc/v3 v3.16.0
|
||||
github.com/extrame/xls v0.0.2-0.20200426124601-4a6cf263071b
|
||||
github.com/gin-contrib/cache v1.4.1
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
@@ -25,6 +26,7 @@ require (
|
||||
github.com/xuri/excelize/v2 v2.9.0
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/net v0.43.0
|
||||
golang.org/x/oauth2 v0.31.0
|
||||
golang.org/x/text v0.28.0
|
||||
gopkg.in/ini.v1 v1.67.0
|
||||
gopkg.in/mail.v2 v2.3.1
|
||||
@@ -52,6 +54,7 @@ require (
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
@@ -91,7 +94,6 @@ require (
|
||||
github.com/xuri/nfp v0.0.1 // indirect
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect
|
||||
golang.org/x/oauth2 v0.31.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
||||
|
||||
@@ -27,6 +27,8 @@ github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLI
|
||||
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/coreos/go-oidc/v3 v3.16.0 h1:qRQUCFstKpXwmEjDQTIbyY/5jF00+asXzSkmkoa/mow=
|
||||
github.com/coreos/go-oidc/v3 v3.16.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -54,6 +56,8 @@ github.com/go-co-op/gocron/v2 v2.16.5 h1:j228Jxk7bb9CF8LKR3gS+bK3rcjRUINjlVI+ZMp
|
||||
github.com/go-co-op/gocron/v2 v2.16.5/go.mod h1:zAfC/GFQ668qHxOVl/D68Jh5Ce7sDqX6TJnSQyRkRBc=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
|
||||
+8
-5
@@ -114,6 +114,11 @@ func (a *ApiUsingDuplicateChecker) GetSubmissionRemark(checkerType duplicatechec
|
||||
return a.container.GetSubmissionRemark(checkerType, uid, identification)
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark by the current duplicate checker with custom expiration time
|
||||
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpiration(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
|
||||
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkIfEnable saves the identification and remark by the current duplicate checker if the duplicate submission check is enabled
|
||||
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string) {
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
@@ -121,11 +126,9 @@ func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType dupli
|
||||
}
|
||||
}
|
||||
|
||||
// SetSubmissionRemarkWithCustomExpirationIfEnable saves the identification and remark by the current duplicate checker with custom expiration time if the duplicate submission check is enabled
|
||||
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpirationIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
|
||||
}
|
||||
// RemoveSubmissionRemark removes the identification and remark by the current duplicate checker
|
||||
func (a *ApiUsingDuplicateChecker) RemoveSubmissionRemark(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string) {
|
||||
a.container.RemoveSubmissionRemark(checkerType, uid, identification)
|
||||
}
|
||||
|
||||
// RemoveSubmissionRemarkIfEnable removes the identification and remark by the current duplicate checker if the duplicate submission check is enabled
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
const oauth2CallbackPageUrlSuccessFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&token=%s"
|
||||
const oauth2CallbackPageUrlNeedVerifyFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&userName=%s&token=%s"
|
||||
const oauth2CallbackPageUrlFailedFormat = "%sdesktop/#/oauth2_callback?errorCode=%d&errorMessage=%s"
|
||||
const oauth2CallbackPageUrlErrorMessageFormat = "%sdesktop/#/oauth2_callback?errorMessage=%s"
|
||||
|
||||
// OAuth2AuthenticationApi represents OAuth 2.0 authorization api
|
||||
type OAuth2AuthenticationApi struct {
|
||||
@@ -65,37 +66,31 @@ func (a *OAuth2AuthenticationApi) LoginHandler(c *core.WebContext) (string, *err
|
||||
return "", errs.ErrInvalidOAuth2LoginRequest
|
||||
}
|
||||
|
||||
state := fmt.Sprintf("%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId)
|
||||
remark := ""
|
||||
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
found := false
|
||||
found, remark = a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
|
||||
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
|
||||
|
||||
if found {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId)
|
||||
return "", errs.ErrRepeatedRequest
|
||||
}
|
||||
|
||||
randomString, err := utils.GetRandomNumberOrLowercaseLetter(32)
|
||||
verifier, err := utils.GetRandomNumberOrLowercaseLetter(64)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error())
|
||||
return "", errs.ErrSystemError
|
||||
}
|
||||
|
||||
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, randomString)
|
||||
state = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
|
||||
}
|
||||
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, verifier)
|
||||
state := fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
|
||||
|
||||
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state)
|
||||
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state, verifier)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to get oauth 2.0 auth url, because %s", err.Error())
|
||||
return "", errs.Or(err, errs.ErrSystemError)
|
||||
}
|
||||
|
||||
a.SetSubmissionRemarkWithCustomExpirationIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
|
||||
a.SetSubmissionRemarkWithCustomExpiration(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
|
||||
|
||||
return redirectUrl, nil
|
||||
}
|
||||
@@ -115,6 +110,11 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
|
||||
}
|
||||
|
||||
if oauth2CallbackReq.Code == "" {
|
||||
if oauth2CallbackReq.ErrorDescription != "" {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] oauth 2.0 provider returned error: %s, description: %s", oauth2CallbackReq.Error, oauth2CallbackReq.ErrorDescription)
|
||||
return a.redirectToErrorMessageCallbackPage(c, oauth2CallbackReq.ErrorDescription)
|
||||
}
|
||||
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2Code)
|
||||
}
|
||||
|
||||
@@ -134,7 +134,6 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2LoginRequest)
|
||||
}
|
||||
|
||||
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
|
||||
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
|
||||
|
||||
if !found {
|
||||
@@ -149,7 +148,8 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
|
||||
}
|
||||
|
||||
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, remarkParts[2])
|
||||
verifier := remarkParts[2]
|
||||
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, verifier)
|
||||
expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState)))
|
||||
|
||||
if oauth2CallbackReq.State != expectedState {
|
||||
@@ -157,10 +157,9 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
|
||||
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
|
||||
}
|
||||
|
||||
a.RemoveSubmissionRemarkIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
|
||||
}
|
||||
a.RemoveSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
|
||||
|
||||
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code)
|
||||
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code, verifier)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 token, because %s", err.Error())
|
||||
@@ -347,3 +346,7 @@ func (a *OAuth2AuthenticationApi) redirectToVerifyCallbackPage(c *core.WebContex
|
||||
func (a *OAuth2AuthenticationApi) redirectToFailedCallbackPage(c *core.WebContext, err *errs.Error) (string, *errs.Error) {
|
||||
return fmt.Sprintf(oauth2CallbackPageUrlFailedFormat, a.CurrentConfig().RootUrl, err.Code(), url.QueryEscape(utils.GetDisplayErrorMessage(err))), nil
|
||||
}
|
||||
|
||||
func (a *OAuth2AuthenticationApi) redirectToErrorMessageCallbackPage(c *core.WebContext, message string) (string, *errs.Error) {
|
||||
return fmt.Sprintf(oauth2CallbackPageUrlErrorMessageFormat, a.CurrentConfig().RootUrl, url.QueryEscape(message)), nil
|
||||
}
|
||||
|
||||
@@ -47,6 +47,10 @@ func (a *ServerSettingsApi) ServerSettingsJavascriptHandler(c *core.WebContext)
|
||||
|
||||
a.appendStringSetting(builder, "op", config.OAuth2Provider)
|
||||
|
||||
if config.OAuth2Provider == settings.OAuth2ProviderOIDC && config.OAuth2OIDCCustomDisplayNameConfig.Enabled {
|
||||
a.appendMultiLanguageTipSetting(builder, "ocn", config.OAuth2OIDCCustomDisplayNameConfig)
|
||||
}
|
||||
|
||||
if config.EnableMCPServer {
|
||||
a.appendBooleanSetting(builder, "mcp", config.EnableMCPServer)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package oauth2
|
||||
package data
|
||||
|
||||
import "github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/gitea"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/github"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/nextcloud"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/oidc"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
@@ -13,8 +21,7 @@ import (
|
||||
|
||||
// OAuth2Container contains the current OAuth 2.0 authentication provider
|
||||
type OAuth2Container struct {
|
||||
oauth2Config *oauth2.Config
|
||||
oauth2Provider OAuth2Provider
|
||||
current provider.OAuth2Provider
|
||||
oauth2HttpClient *http.Client
|
||||
externalUserAuthType core.UserExternalAuthType
|
||||
}
|
||||
@@ -34,24 +41,32 @@ func InitializeOAuth2Provider(config *settings.Config) error {
|
||||
return errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
var oauth2Provider OAuth2Provider
|
||||
var err error
|
||||
var oauth2Provider provider.OAuth2Provider
|
||||
var externalUserAuthType core.UserExternalAuthType
|
||||
redirectUrl := config.RootUrl + "oauth2/callback"
|
||||
|
||||
if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
|
||||
oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl)
|
||||
if config.OAuth2Provider == settings.OAuth2ProviderOIDC {
|
||||
oauth2Provider, err = oidc.NewOIDCProvider(config, redirectUrl)
|
||||
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC
|
||||
} else if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
|
||||
oauth2Provider, err = nextcloud.NewNextcloudOAuth2Provider(config, redirectUrl)
|
||||
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD
|
||||
} else if config.OAuth2Provider == settings.OAuth2ProviderGitea {
|
||||
oauth2Provider = NewGiteaOAuth2Provider(config.OAuth2GiteaBaseUrl)
|
||||
oauth2Provider, err = gitea.NewGiteaOAuth2Provider(config, redirectUrl)
|
||||
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA
|
||||
} else if config.OAuth2Provider == settings.OAuth2ProviderGithub {
|
||||
oauth2Provider = NewGithubOAuth2Provider()
|
||||
oauth2Provider, err = github.NewGithubOAuth2Provider(config, redirectUrl)
|
||||
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB
|
||||
} else {
|
||||
return errs.ErrInvalidOAuth2Provider
|
||||
}
|
||||
|
||||
Container.oauth2Config = buildOAuth2Config(config, oauth2Provider)
|
||||
Container.oauth2Provider = oauth2Provider
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Container.current = oauth2Provider
|
||||
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
|
||||
Container.externalUserAuthType = externalUserAuthType
|
||||
|
||||
@@ -59,26 +74,29 @@ func InitializeOAuth2Provider(config *settings.Config) error {
|
||||
}
|
||||
|
||||
// GetOAuth2AuthUrl returns the OAuth 2.0 authentication url
|
||||
func GetOAuth2AuthUrl(c core.Context, state string) (string, error) {
|
||||
if Container.oauth2Config == nil {
|
||||
func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, error) {
|
||||
if Container.current == nil {
|
||||
return "", errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
return Container.oauth2Config.AuthCodeURL(state), nil
|
||||
sha256Hash := sha256.New()
|
||||
sha256Hash.Write([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil))
|
||||
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge)
|
||||
}
|
||||
|
||||
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
|
||||
func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) {
|
||||
if Container.oauth2Config == nil || Container.oauth2HttpClient == nil {
|
||||
func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
|
||||
if Container.current == nil || Container.oauth2HttpClient == nil {
|
||||
return nil, errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code)
|
||||
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier)
|
||||
}
|
||||
|
||||
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
|
||||
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) {
|
||||
if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil {
|
||||
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*data.OAuth2UserInfo, error) {
|
||||
if Container.current == nil || Container.oauth2HttpClient == nil {
|
||||
return nil, errs.ErrOAuth2NotEnabled
|
||||
}
|
||||
|
||||
@@ -86,26 +104,10 @@ func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, er
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token))
|
||||
return Container.oauth2Provider.GetUserInfo(c, oauth2Client)
|
||||
return Container.current.GetUserInfo(wrapOAuth2Context(c, Container.oauth2HttpClient), token)
|
||||
}
|
||||
|
||||
// GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider
|
||||
func GetExternalUserAuthType() core.UserExternalAuthType {
|
||||
return Container.externalUserAuthType
|
||||
}
|
||||
|
||||
func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config {
|
||||
redirectURL := config.RootUrl + "oauth2/callback"
|
||||
|
||||
return &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: oauth2Provider.GetAuthUrl(),
|
||||
TokenURL: oauth2Provider.GetTokenUrl(),
|
||||
},
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: oauth2Provider.GetScopes(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
)
|
||||
|
||||
// OAuth2Provider defines the structure of OAuth 2.0 provider
|
||||
type OAuth2Provider interface {
|
||||
// GetAuthUrl returns the authentication url of the provider
|
||||
GetAuthUrl() string
|
||||
|
||||
// GetTokenUrl returns the token url of the provider
|
||||
GetTokenUrl() string
|
||||
|
||||
// GetUserInfo returns the user info
|
||||
GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error)
|
||||
|
||||
// GetScopes returns the scopes required by the provider
|
||||
GetScopes() []string
|
||||
}
|
||||
+41
-15
@@ -1,17 +1,23 @@
|
||||
package oauth2
|
||||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
// CommonOAuth2Provider represents common OAuth 2.0 provider
|
||||
type CommonOAuth2Provider struct {
|
||||
OAuth2Provider
|
||||
provider.OAuth2Provider
|
||||
oauth2Config *oauth2.Config
|
||||
dataSource CommonOAuth2DataSource
|
||||
}
|
||||
|
||||
@@ -30,26 +36,21 @@ type CommonOAuth2DataSource interface {
|
||||
GetScopes() []string
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error)
|
||||
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetAuthUrl() string {
|
||||
return p.dataSource.GetAuthUrl()
|
||||
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
|
||||
return p.oauth2Config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// GetTokenUrl returns the token url of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetTokenUrl() string {
|
||||
return p.dataSource.GetTokenUrl()
|
||||
}
|
||||
|
||||
// GetScopes returns the scopes required by the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetScopes() []string {
|
||||
return p.dataSource.GetScopes()
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
|
||||
return p.oauth2Config.Exchange(c, code)
|
||||
}
|
||||
|
||||
// GetUserInfo returns the user info by the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) {
|
||||
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
|
||||
req, err := p.dataSource.GetUserInfoRequest()
|
||||
|
||||
if err != nil {
|
||||
@@ -57,6 +58,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
oauth2Client := oauth2.NewClient(c, oauth2.StaticTokenSource(oauth2Token))
|
||||
resp, err := oauth2Client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
@@ -76,3 +78,27 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
|
||||
|
||||
return p.dataSource.ParseUserInfo(c, body)
|
||||
}
|
||||
|
||||
// GetDataSource returns the data source of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetDataSource() CommonOAuth2DataSource {
|
||||
return p.dataSource
|
||||
}
|
||||
|
||||
// NewCommonOAuth2Provider returns a new common OAuth 2.0 provider
|
||||
func NewCommonOAuth2Provider(config *settings.Config, redirectUrl string, dataSource CommonOAuth2DataSource) *CommonOAuth2Provider {
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: dataSource.GetAuthUrl(),
|
||||
TokenURL: dataSource.GetTokenUrl(),
|
||||
},
|
||||
RedirectURL: redirectUrl,
|
||||
Scopes: dataSource.GetScopes(),
|
||||
}
|
||||
|
||||
return &CommonOAuth2Provider{
|
||||
oauth2Config: oauth2Config,
|
||||
dataSource: dataSource,
|
||||
}
|
||||
}
|
||||
+17
-9
@@ -1,12 +1,16 @@
|
||||
package oauth2
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type giteaUserInfoResponse struct {
|
||||
@@ -17,7 +21,7 @@ type giteaUserInfoResponse struct {
|
||||
|
||||
// GiteaOAuth2DataSource represents Gitea OAuth 2.0 data source
|
||||
type GiteaOAuth2DataSource struct {
|
||||
CommonOAuth2DataSource
|
||||
common.CommonOAuth2DataSource
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
@@ -52,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string {
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
|
||||
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &giteaUserInfoResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
@@ -66,7 +70,7 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &OAuth2UserInfo{
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.Login,
|
||||
Email: userInfoResp.Email,
|
||||
NickName: userInfoResp.FullName,
|
||||
@@ -74,14 +78,18 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
|
||||
}
|
||||
|
||||
// NewGiteaOAuth2Provider creates a new Gitea OAuth 2.0 provider instance
|
||||
func NewGiteaOAuth2Provider(baseUrl string) OAuth2Provider {
|
||||
func NewGiteaOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
if len(config.OAuth2GiteaBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := config.OAuth2GiteaBaseUrl
|
||||
|
||||
if baseUrl[len(baseUrl)-1] != '/' {
|
||||
baseUrl += "/"
|
||||
}
|
||||
|
||||
return &CommonOAuth2Provider{
|
||||
dataSource: &GiteaOAuth2DataSource{
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &GiteaOAuth2DataSource{
|
||||
baseUrl: baseUrl,
|
||||
},
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
+18
-7
@@ -1,22 +1,33 @@
|
||||
package oauth2
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
func TestNewGiteaOAuth2Provider(t *testing.T) {
|
||||
datasource := NewGiteaOAuth2Provider("https://example.com/")
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
|
||||
provider, err := NewGiteaOAuth2Provider(&settings.Config{
|
||||
OAuth2GiteaBaseUrl: "https://example.com/",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
datasource = NewGiteaOAuth2Provider("https://example.com")
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
|
||||
provider, err = NewGiteaOAuth2Provider(&settings.Config{
|
||||
OAuth2GiteaBaseUrl: "https://example.com",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewGiteaOAuth2Provider(&settings.Config{}, "")
|
||||
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
|
||||
}
|
||||
|
||||
func TestGiteaOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
|
||||
+10
-8
@@ -1,12 +1,16 @@
|
||||
package oauth2
|
||||
package github
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type githubUserProfileResponse struct {
|
||||
@@ -17,7 +21,7 @@ type githubUserProfileResponse struct {
|
||||
|
||||
// GithubOAuth2DataSource represents Github OAuth 2.0 data source
|
||||
type GithubOAuth2DataSource struct {
|
||||
CommonOAuth2DataSource
|
||||
common.CommonOAuth2DataSource
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the Github data source
|
||||
@@ -51,7 +55,7 @@ func (p *GithubOAuth2DataSource) GetScopes() []string {
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
|
||||
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &githubUserProfileResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
@@ -65,7 +69,7 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &OAuth2UserInfo{
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.Login,
|
||||
Email: userInfoResp.Email,
|
||||
NickName: userInfoResp.Name,
|
||||
@@ -73,8 +77,6 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
|
||||
}
|
||||
|
||||
// NewGithubOAuth2Provider creates a new Github OAuth 2.0 provider instance
|
||||
func NewGithubOAuth2Provider() OAuth2Provider {
|
||||
return &CommonOAuth2Provider{
|
||||
dataSource: &GithubOAuth2DataSource{},
|
||||
}
|
||||
func NewGithubOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &GithubOAuth2DataSource{}), nil
|
||||
}
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
package oauth2
|
||||
package github
|
||||
|
||||
import (
|
||||
"testing"
|
||||
+17
-9
@@ -1,12 +1,16 @@
|
||||
package oauth2
|
||||
package nextcloud
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type nextcloudUserInfoResponse struct {
|
||||
@@ -25,7 +29,7 @@ type nextcloudUserInfoResponse struct {
|
||||
|
||||
// NextcloudOAuth2DataSource represents Nextcloud OAuth 2.0 data source
|
||||
type NextcloudOAuth2DataSource struct {
|
||||
CommonOAuth2DataSource
|
||||
common.CommonOAuth2DataSource
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
@@ -61,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string {
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
|
||||
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &nextcloudUserInfoResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
@@ -85,7 +89,7 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &OAuth2UserInfo{
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.OCS.Data.ID,
|
||||
Email: userInfoResp.OCS.Data.Email,
|
||||
NickName: userInfoResp.OCS.Data.DisplayName,
|
||||
@@ -93,14 +97,18 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
|
||||
}
|
||||
|
||||
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
|
||||
func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider {
|
||||
func NewNextcloudOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
if len(config.OAuth2NextcloudBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := config.OAuth2NextcloudBaseUrl
|
||||
|
||||
if baseUrl[len(baseUrl)-1] != '/' {
|
||||
baseUrl += "/"
|
||||
}
|
||||
|
||||
return &CommonOAuth2Provider{
|
||||
dataSource: &NextcloudOAuth2DataSource{
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &NextcloudOAuth2DataSource{
|
||||
baseUrl: baseUrl,
|
||||
},
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
+18
-7
@@ -1,22 +1,33 @@
|
||||
package oauth2
|
||||
package nextcloud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
func TestNewNextcloudOAuth2Provider(t *testing.T) {
|
||||
datasource := NewNextcloudOAuth2Provider("https://example.com/")
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/authorize", datasource.GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
|
||||
provider, err := NewNextcloudOAuth2Provider(&settings.Config{
|
||||
OAuth2NextcloudBaseUrl: "https://example.com/",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
datasource = NewNextcloudOAuth2Provider("https://example.com/index.php")
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", datasource.GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
|
||||
provider, err = NewNextcloudOAuth2Provider(&settings.Config{
|
||||
OAuth2NextcloudBaseUrl: "https://example.com/index.php",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewNextcloudOAuth2Provider(&settings.Config{}, "")
|
||||
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
|
||||
@@ -0,0 +1,20 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
)
|
||||
|
||||
// OAuth2Provider defines the structure of OAuth 2.0 provider
|
||||
type OAuth2Provider interface {
|
||||
// GetOAuth2AuthUrl returns the authentication url of the provider
|
||||
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
|
||||
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the provider
|
||||
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
|
||||
|
||||
// GetUserInfo returns the user info
|
||||
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
// OIDCClaims represents OIDC claims
|
||||
type OIDCClaims struct {
|
||||
PreferredUserName string `json:"preferred_username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// OIDCProvider represents OIDC provider
|
||||
type OIDCProvider struct {
|
||||
provider.OAuth2Provider
|
||||
oidcBaseUrl string
|
||||
redirectUrl string
|
||||
oauth2ClientID string
|
||||
oauth2ClientSecret string
|
||||
oauth2Config *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
|
||||
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
|
||||
oauth2Config, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return oauth2Config.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
|
||||
}
|
||||
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
|
||||
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
|
||||
oauth2Config, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
|
||||
}
|
||||
|
||||
// GetUserInfo returns the user info by the OIDC provider
|
||||
func (p *OIDCProvider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
|
||||
_, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
|
||||
if !ok {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] missing \"id_token\" field in oauth 2.0 token")
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
idToken, err := p.oidcVerifier.Verify(c, rawIDToken)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to verify \"id_token\" field in oauth 2.0 token, because %s", err.Error())
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
var claims OIDCClaims
|
||||
err = idToken.Claims(&claims)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse claims in oauth 2.0 token, because %s", err.Error())
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
userName := claims.PreferredUserName
|
||||
email := claims.Email
|
||||
nickName := claims.Name
|
||||
|
||||
if userName == "" || email == "" || nickName == "" {
|
||||
userInfo, err := p.oidcProvider.UserInfo(c, oauth2.StaticTokenSource(oauth2Token))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to get user info, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
err = userInfo.Claims(&claims)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse user info, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userName == "" {
|
||||
userName = claims.PreferredUserName
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
email = claims.Email
|
||||
}
|
||||
|
||||
if nickName == "" {
|
||||
nickName = claims.Name
|
||||
}
|
||||
}
|
||||
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userName,
|
||||
Email: email,
|
||||
NickName: nickName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) getOAuth2Config(c core.Context) (*oauth2.Config, error) {
|
||||
if p.oauth2Config != nil {
|
||||
return p.oauth2Config, nil
|
||||
}
|
||||
|
||||
oidcProvider, err := oidc.NewProvider(c, p.oidcBaseUrl)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.getOAuth2Config] failed to create oidc provider, because %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oidcVerifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.oauth2ClientID})
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: p.oauth2ClientID,
|
||||
ClientSecret: p.oauth2ClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: p.redirectUrl,
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
p.oauth2Config = oauth2Config
|
||||
p.oidcProvider = oidcProvider
|
||||
p.oidcVerifier = oidcVerifier
|
||||
return oauth2Config, nil
|
||||
}
|
||||
|
||||
// NewOIDCProvider returns a new OIDC provider
|
||||
func NewOIDCProvider(config *settings.Config, redirectUrl string) (*OIDCProvider, error) {
|
||||
if len(config.OAuth2OIDCProviderBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := strings.TrimSuffix(config.OAuth2OIDCProviderBaseUrl, "/")
|
||||
|
||||
return &OIDCProvider{
|
||||
oidcBaseUrl: baseUrl,
|
||||
redirectUrl: redirectUrl,
|
||||
oauth2ClientID: config.OAuth2ClientID,
|
||||
oauth2ClientSecret: config.OAuth2ClientSecret,
|
||||
oauth2Config: nil,
|
||||
}, nil
|
||||
}
|
||||
@@ -5,6 +5,7 @@ type UserExternalAuthType string
|
||||
|
||||
// User External Auth Type
|
||||
const (
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC UserExternalAuthType = "oidc"
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD UserExternalAuthType = "nextcloud"
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA UserExternalAuthType = "gitea"
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB UserExternalAuthType = "github"
|
||||
@@ -13,7 +14,8 @@ const (
|
||||
// IsValid checks if the UserExternalAuthType is valid
|
||||
func (t UserExternalAuthType) IsValid() bool {
|
||||
switch t {
|
||||
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD,
|
||||
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC,
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD,
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA,
|
||||
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB:
|
||||
return true
|
||||
|
||||
@@ -10,6 +10,8 @@ type OAuth2LoginRequest struct {
|
||||
type OAuth2CallbackRequest struct {
|
||||
State string `form:"state"`
|
||||
Code string `form:"code"`
|
||||
Error string `form:"error"`
|
||||
ErrorDescription string `form:"error_description"`
|
||||
}
|
||||
|
||||
// OAuth2CallbackLoginRequest represents all parameters of OAuth 2.0 callback login request
|
||||
|
||||
@@ -93,6 +93,7 @@ const (
|
||||
|
||||
// OAuth 2.0 provider types
|
||||
const (
|
||||
OAuth2ProviderOIDC string = "oidc"
|
||||
OAuth2ProviderNextcloud string = "nextcloud"
|
||||
OAuth2ProviderGitea string = "gitea"
|
||||
OAuth2ProviderGithub string = "github"
|
||||
@@ -375,6 +376,8 @@ type Config struct {
|
||||
OAuth2RequestTimeout uint32
|
||||
OAuth2Proxy string
|
||||
OAuth2SkipTLSVerify bool
|
||||
OAuth2OIDCProviderBaseUrl string
|
||||
OAuth2OIDCCustomDisplayNameConfig MultiLanguageContentConfig
|
||||
OAuth2NextcloudBaseUrl string
|
||||
OAuth2GiteaBaseUrl string
|
||||
|
||||
@@ -1003,6 +1006,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
|
||||
|
||||
if oauth2Provider == "" {
|
||||
config.OAuth2Provider = ""
|
||||
} else if oauth2Provider == OAuth2ProviderOIDC {
|
||||
config.OAuth2Provider = OAuth2ProviderOIDC
|
||||
} else if oauth2Provider == OAuth2ProviderNextcloud {
|
||||
config.OAuth2Provider = OAuth2ProviderNextcloud
|
||||
} else if oauth2Provider == OAuth2ProviderGitea {
|
||||
@@ -1025,6 +1030,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
|
||||
config.OAuth2RequestTimeout = getConfigItemUint32Value(configFile, sectionName, "oauth2_request_timeout", defaultOAuth2RequestTimeout)
|
||||
config.OAuth2SkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "oauth2_skip_tls_verify", false)
|
||||
|
||||
config.OAuth2OIDCProviderBaseUrl = getConfigItemStringValue(configFile, sectionName, "oidc_provider_base_url")
|
||||
config.OAuth2OIDCCustomDisplayNameConfig = getMultiLanguageContentConfig(configFile, sectionName, "enable_oidc_display_name", "oidc_custom_display_name")
|
||||
config.OAuth2NextcloudBaseUrl = getConfigItemStringValue(configFile, sectionName, "nextcloud_base_url")
|
||||
config.OAuth2GiteaBaseUrl = getConfigItemStringValue(configFile, sectionName, "gitea_base_url")
|
||||
|
||||
|
||||
@@ -9,3 +9,14 @@ export interface ErrorResponse {
|
||||
readonly errorMessage: string;
|
||||
readonly path: string;
|
||||
}
|
||||
|
||||
export function buildErrorResponse(errorCode: number, errorMessage: string): ErrorResponse {
|
||||
const errorResponse: ErrorResponse = {
|
||||
success: false,
|
||||
errorCode: errorCode,
|
||||
errorMessage: errorMessage,
|
||||
path: ''
|
||||
};
|
||||
|
||||
return errorResponse;
|
||||
}
|
||||
|
||||
@@ -24,9 +24,10 @@
|
||||
<v-card variant="flat" class="w-100 mt-0 px-4 pt-12" max-width="500">
|
||||
<v-card-text>
|
||||
<h4 class="text-h4 mb-2">{{ oauth2LoginDisplayName }}</h4>
|
||||
<p class="mb-0" v-if="!error && platform && token && !userName">{{ tt('Logging in...') }}</p>
|
||||
<p class="mb-0" v-else-if="!error && userName">{{ tt('format.misc.oauth2bindTip', { providerName: oauth2ProviderDisplayName, userName: userName }) }}</p>
|
||||
<p class="mb-0" v-if="!error && !errorMessage && platform && token && !userName">{{ tt('Logging in...') }}</p>
|
||||
<p class="mb-0" v-else-if="!error && !errorMessage && userName">{{ tt('format.misc.oauth2bindTip', { providerName: oauth2ProviderDisplayName, userName: userName }) }}</p>
|
||||
<p class="mb-0" v-else-if="error">{{ te({ error }) }}</p>
|
||||
<p class="mb-0" v-else-if="errorMessage">{{ errorMessage }}</p>
|
||||
<p class="mb-0" v-else>{{ tt('An error occurred') }}</p>
|
||||
</v-card-text>
|
||||
|
||||
@@ -106,7 +107,7 @@ import { useLoginPageBase } from '@/views/base/LoginPageBase.ts';
|
||||
import { useRootStore } from '@/stores/index.ts';
|
||||
|
||||
import { ThemeType } from '@/core/theme.ts';
|
||||
import { type ErrorResponse } from '@/core/api.ts';
|
||||
import { type ErrorResponse, buildErrorResponse } from '@/core/api.ts';
|
||||
import { APPLICATION_LOGO_PATH } from '@/consts/asset.ts';
|
||||
import { KnownErrorCode } from '@/consts/api.ts';
|
||||
|
||||
@@ -158,14 +159,7 @@ const oauth2LoginDisplayName = computed<string>(() => getLocalizedOAuth2LoginTex
|
||||
|
||||
const error = computed<ErrorResponse | undefined>(() => {
|
||||
if (props.errorCode && props.errorMessage) {
|
||||
const errorResponse: ErrorResponse = {
|
||||
success: false,
|
||||
errorCode: parseInt(props.errorCode),
|
||||
errorMessage: props.errorMessage,
|
||||
path: ''
|
||||
};
|
||||
|
||||
return errorResponse;
|
||||
return buildErrorResponse(parseInt(props.errorCode), props.errorMessage);
|
||||
} else {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -105,6 +105,12 @@
|
||||
"url": "https://golang.org/x/oauth2",
|
||||
"licenseUrl": "https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.31.0:LICENSE"
|
||||
},
|
||||
{
|
||||
"name": "go-oidc",
|
||||
"copyright": "Copyright 2014 CoreOS, Inc",
|
||||
"url": "https://github.com/coreos/go-oidc",
|
||||
"licenseUrl": "https://github.com/coreos/go-oidc/blob/v3.16.0/LICENSE"
|
||||
},
|
||||
{
|
||||
"name": "Gomail",
|
||||
"copyright": "Copyright (c) 2014 Alexandre Cesaro",
|
||||
|
||||
Reference in New Issue
Block a user