support OIDC authentication (#242)

This commit is contained in:
MaysWind
2025-10-24 01:44:55 +08:00
parent d3ab2b94b7
commit 85b05f9e7e
24 changed files with 490 additions and 202 deletions
+8 -5
View File
@@ -114,6 +114,11 @@ func (a *ApiUsingDuplicateChecker) GetSubmissionRemark(checkerType duplicatechec
return a.container.GetSubmissionRemark(checkerType, uid, identification)
}
// SetSubmissionRemarkWithCustomExpiration saves the identification and remark by the current duplicate checker with custom expiration time
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpiration(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
}
// SetSubmissionRemarkIfEnable saves the identification and remark by the current duplicate checker if the duplicate submission check is enabled
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string) {
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
@@ -121,11 +126,9 @@ func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkIfEnable(checkerType dupli
}
}
// SetSubmissionRemarkWithCustomExpirationIfEnable saves the identification and remark by the current duplicate checker with custom expiration time if the duplicate submission check is enabled
func (a *ApiUsingDuplicateChecker) SetSubmissionRemarkWithCustomExpirationIfEnable(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string, remark string, expiration time.Duration) {
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
a.container.SetSubmissionRemarkWithCustomExpiration(checkerType, uid, identification, remark, expiration)
}
// RemoveSubmissionRemark removes the identification and remark by the current duplicate checker
func (a *ApiUsingDuplicateChecker) RemoveSubmissionRemark(checkerType duplicatechecker.DuplicateCheckerType, uid int64, identification string) {
a.container.RemoveSubmissionRemark(checkerType, uid, identification)
}
// RemoveSubmissionRemarkIfEnable removes the identification and remark by the current duplicate checker if the duplicate submission check is enabled
+49 -46
View File
@@ -23,6 +23,7 @@ import (
const oauth2CallbackPageUrlSuccessFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&token=%s"
const oauth2CallbackPageUrlNeedVerifyFormat = "%sdesktop/#/oauth2_callback?platform=%s&provider=%s&userName=%s&token=%s"
const oauth2CallbackPageUrlFailedFormat = "%sdesktop/#/oauth2_callback?errorCode=%d&errorMessage=%s"
const oauth2CallbackPageUrlErrorMessageFormat = "%sdesktop/#/oauth2_callback?errorMessage=%s"
// OAuth2AuthenticationApi represents OAuth 2.0 authorization api
type OAuth2AuthenticationApi struct {
@@ -65,37 +66,31 @@ func (a *OAuth2AuthenticationApi) LoginHandler(c *core.WebContext) (string, *err
return "", errs.ErrInvalidOAuth2LoginRequest
}
state := fmt.Sprintf("%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId)
remark := ""
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
found := false
found, remark = a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId)
if found {
log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId)
return "", errs.ErrRepeatedRequest
}
randomString, err := utils.GetRandomNumberOrLowercaseLetter(32)
if err != nil {
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error())
return "", errs.ErrSystemError
}
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, randomString)
state = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
if found {
log.Errorf(c, "[oauth2_authentications.LoginHandler] another oauth 2.0 state \"%s\" has been processing for client session id \"%s\"", remark, oauth2LoginReq.ClientSessionId)
return "", errs.ErrRepeatedRequest
}
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state)
verifier, err := utils.GetRandomNumberOrLowercaseLetter(64)
if err != nil {
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to generate random string for oauth 2.0 state, because %s", err.Error())
return "", errs.ErrSystemError
}
remark = fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, verifier)
state := fmt.Sprintf("%s|%s|%s", oauth2LoginReq.Platform, oauth2LoginReq.ClientSessionId, utils.MD5EncodeToString([]byte(remark)))
redirectUrl, err := oauth2.GetOAuth2AuthUrl(c, state, verifier)
if err != nil {
log.Errorf(c, "[oauth2_authentications.LoginHandler] failed to get oauth 2.0 auth url, because %s", err.Error())
return "", errs.Or(err, errs.ErrSystemError)
}
a.SetSubmissionRemarkWithCustomExpirationIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
a.SetSubmissionRemarkWithCustomExpiration(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, oauth2LoginReq.ClientSessionId, remark, a.CurrentConfig().OAuth2StateExpiredTimeDuration)
return redirectUrl, nil
}
@@ -115,6 +110,11 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
}
if oauth2CallbackReq.Code == "" {
if oauth2CallbackReq.ErrorDescription != "" {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] oauth 2.0 provider returned error: %s, description: %s", oauth2CallbackReq.Error, oauth2CallbackReq.ErrorDescription)
return a.redirectToErrorMessageCallbackPage(c, oauth2CallbackReq.ErrorDescription)
}
return a.redirectToFailedCallbackPage(c, errs.ErrMissingOAuth2Code)
}
@@ -134,33 +134,32 @@ func (a *OAuth2AuthenticationApi) CallbackHandler(c *core.WebContext) (string, *
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2LoginRequest)
}
if a.CurrentConfig().EnableDuplicateSubmissionsCheck {
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
found, remark := a.GetSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
if !found {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback)
}
remarkParts := strings.Split(remark, "|")
if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, remarkParts[2])
expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState)))
if oauth2CallbackReq.State != expectedState {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
a.RemoveSubmissionRemarkIfEnable(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
if !found {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] cannot find oauth 2.0 state in duplicate checker for client session id \"%s\"", clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2Callback)
}
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code)
remarkParts := strings.Split(remark, "|")
if len(remarkParts) != 3 || remarkParts[0] != platform || remarkParts[1] != clientSessionId {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] invalid oauth 2.0 state \"%s\" in duplicate checker for client session id \"%s\"", remark, clientSessionId)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
verifier := remarkParts[2]
expectedState := fmt.Sprintf("%s|%s|%s", platform, clientSessionId, verifier)
expectedState = fmt.Sprintf("%s|%s|%s", platform, clientSessionId, utils.MD5EncodeToString([]byte(expectedState)))
if oauth2CallbackReq.State != expectedState {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] mismatched random string in oauth 2.0 state, expected \"%s\", got \"%s\"", expectedState, oauth2CallbackReq.State)
return a.redirectToFailedCallbackPage(c, errs.ErrInvalidOAuth2State)
}
a.RemoveSubmissionRemark(duplicatechecker.DUPLICATE_CHECKER_TYPE_OAUTH2_REDIRECT, 0, clientSessionId)
oauth2Token, err := oauth2.GetOAuth2Token(c, oauth2CallbackReq.Code, verifier)
if err != nil {
log.Errorf(c, "[oauth2_authentications.CallbackHandler] failed to retrieve oauth 2.0 token, because %s", err.Error())
@@ -347,3 +346,7 @@ func (a *OAuth2AuthenticationApi) redirectToVerifyCallbackPage(c *core.WebContex
func (a *OAuth2AuthenticationApi) redirectToFailedCallbackPage(c *core.WebContext, err *errs.Error) (string, *errs.Error) {
return fmt.Sprintf(oauth2CallbackPageUrlFailedFormat, a.CurrentConfig().RootUrl, err.Code(), url.QueryEscape(utils.GetDisplayErrorMessage(err))), nil
}
func (a *OAuth2AuthenticationApi) redirectToErrorMessageCallbackPage(c *core.WebContext, message string) (string, *errs.Error) {
return fmt.Sprintf(oauth2CallbackPageUrlErrorMessageFormat, a.CurrentConfig().RootUrl, url.QueryEscape(message)), nil
}
+4
View File
@@ -47,6 +47,10 @@ func (a *ServerSettingsApi) ServerSettingsJavascriptHandler(c *core.WebContext)
a.appendStringSetting(builder, "op", config.OAuth2Provider)
if config.OAuth2Provider == settings.OAuth2ProviderOIDC && config.OAuth2OIDCCustomDisplayNameConfig.Enabled {
a.appendMultiLanguageTipSetting(builder, "ocn", config.OAuth2OIDCCustomDisplayNameConfig)
}
if config.EnableMCPServer {
a.appendBooleanSetting(builder, "mcp", config.EnableMCPServer)
}
@@ -1,4 +1,4 @@
package oauth2
package data
import "github.com/mayswind/ezbookkeeping/pkg/core"
+36 -34
View File
@@ -1,10 +1,18 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/gitea"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/github"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/nextcloud"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/oidc"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
@@ -13,8 +21,7 @@ import (
// OAuth2Container contains the current OAuth 2.0 authentication provider
type OAuth2Container struct {
oauth2Config *oauth2.Config
oauth2Provider OAuth2Provider
current provider.OAuth2Provider
oauth2HttpClient *http.Client
externalUserAuthType core.UserExternalAuthType
}
@@ -34,24 +41,32 @@ func InitializeOAuth2Provider(config *settings.Config) error {
return errs.ErrInvalidOAuth2Config
}
var oauth2Provider OAuth2Provider
var err error
var oauth2Provider provider.OAuth2Provider
var externalUserAuthType core.UserExternalAuthType
redirectUrl := config.RootUrl + "oauth2/callback"
if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl)
if config.OAuth2Provider == settings.OAuth2ProviderOIDC {
oauth2Provider, err = oidc.NewOIDCProvider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC
} else if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
oauth2Provider, err = nextcloud.NewNextcloudOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD
} else if config.OAuth2Provider == settings.OAuth2ProviderGitea {
oauth2Provider = NewGiteaOAuth2Provider(config.OAuth2GiteaBaseUrl)
oauth2Provider, err = gitea.NewGiteaOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA
} else if config.OAuth2Provider == settings.OAuth2ProviderGithub {
oauth2Provider = NewGithubOAuth2Provider()
oauth2Provider, err = github.NewGithubOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB
} else {
return errs.ErrInvalidOAuth2Provider
}
Container.oauth2Config = buildOAuth2Config(config, oauth2Provider)
Container.oauth2Provider = oauth2Provider
if err != nil {
return err
}
Container.current = oauth2Provider
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
Container.externalUserAuthType = externalUserAuthType
@@ -59,26 +74,29 @@ func InitializeOAuth2Provider(config *settings.Config) error {
}
// GetOAuth2AuthUrl returns the OAuth 2.0 authentication url
func GetOAuth2AuthUrl(c core.Context, state string) (string, error) {
if Container.oauth2Config == nil {
func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, error) {
if Container.current == nil {
return "", errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.AuthCodeURL(state), nil
sha256Hash := sha256.New()
sha256Hash.Write([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil))
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge)
}
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) {
if Container.oauth2Config == nil || Container.oauth2HttpClient == nil {
func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
if Container.current == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code)
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier)
}
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) {
if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil {
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*data.OAuth2UserInfo, error) {
if Container.current == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
@@ -86,26 +104,10 @@ func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, er
return nil, errs.ErrInvalidOAuth2Token
}
oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token))
return Container.oauth2Provider.GetUserInfo(c, oauth2Client)
return Container.current.GetUserInfo(wrapOAuth2Context(c, Container.oauth2HttpClient), token)
}
// GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider
func GetExternalUserAuthType() core.UserExternalAuthType {
return Container.externalUserAuthType
}
func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config {
redirectURL := config.RootUrl + "oauth2/callback"
return &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: oauth2Provider.GetAuthUrl(),
TokenURL: oauth2Provider.GetTokenUrl(),
},
RedirectURL: redirectURL,
Scopes: oauth2Provider.GetScopes(),
}
}
-22
View File
@@ -1,22 +0,0 @@
package oauth2
import (
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetAuthUrl returns the authentication url of the provider
GetAuthUrl() string
// GetTokenUrl returns the token url of the provider
GetTokenUrl() string
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error)
// GetScopes returns the scopes required by the provider
GetScopes() []string
}
@@ -1,18 +1,24 @@
package oauth2
package common
import (
"io"
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
// CommonOAuth2Provider represents common OAuth 2.0 provider
type CommonOAuth2Provider struct {
OAuth2Provider
dataSource CommonOAuth2DataSource
provider.OAuth2Provider
oauth2Config *oauth2.Config
dataSource CommonOAuth2DataSource
}
// CommonOAuth2DataSource defines the structure of OAuth 2.0 data source
@@ -30,26 +36,21 @@ type CommonOAuth2DataSource interface {
GetScopes() []string
// ParseUserInfo returns the user info by parsing the response body
ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error)
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
}
// GetAuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetAuthUrl() string {
return p.dataSource.GetAuthUrl()
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
}
// GetTokenUrl returns the token url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetTokenUrl() string {
return p.dataSource.GetTokenUrl()
}
// GetScopes returns the scopes required by the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetScopes() []string {
return p.dataSource.GetScopes()
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
}
// GetUserInfo returns the user info by the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) {
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
req, err := p.dataSource.GetUserInfoRequest()
if err != nil {
@@ -57,6 +58,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
return nil, errs.ErrFailedToRequestRemoteApi
}
oauth2Client := oauth2.NewClient(c, oauth2.StaticTokenSource(oauth2Token))
resp, err := oauth2Client.Do(req)
if err != nil {
@@ -76,3 +78,27 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
return p.dataSource.ParseUserInfo(c, body)
}
// GetDataSource returns the data source of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetDataSource() CommonOAuth2DataSource {
return p.dataSource
}
// NewCommonOAuth2Provider returns a new common OAuth 2.0 provider
func NewCommonOAuth2Provider(config *settings.Config, redirectUrl string, dataSource CommonOAuth2DataSource) *CommonOAuth2Provider {
oauth2Config := &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: dataSource.GetAuthUrl(),
TokenURL: dataSource.GetTokenUrl(),
},
RedirectURL: redirectUrl,
Scopes: dataSource.GetScopes(),
}
return &CommonOAuth2Provider{
oauth2Config: oauth2Config,
dataSource: dataSource,
}
}
@@ -1,12 +1,16 @@
package oauth2
package gitea
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type giteaUserInfoResponse struct {
@@ -17,7 +21,7 @@ type giteaUserInfoResponse struct {
// GiteaOAuth2DataSource represents Gitea OAuth 2.0 data source
type GiteaOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
baseUrl string
}
@@ -52,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &giteaUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -66,7 +70,7 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.Login,
Email: userInfoResp.Email,
NickName: userInfoResp.FullName,
@@ -74,14 +78,18 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
}
// NewGiteaOAuth2Provider creates a new Gitea OAuth 2.0 provider instance
func NewGiteaOAuth2Provider(baseUrl string) OAuth2Provider {
func NewGiteaOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
if len(config.OAuth2GiteaBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := config.OAuth2GiteaBaseUrl
if baseUrl[len(baseUrl)-1] != '/' {
baseUrl += "/"
}
return &CommonOAuth2Provider{
dataSource: &GiteaOAuth2DataSource{
baseUrl: baseUrl,
},
}
return common.NewCommonOAuth2Provider(config, redirectUrl, &GiteaOAuth2DataSource{
baseUrl: baseUrl,
}), nil
}
@@ -1,22 +1,33 @@
package oauth2
package gitea
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
func TestNewGiteaOAuth2Provider(t *testing.T) {
datasource := NewGiteaOAuth2Provider("https://example.com/")
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
provider, err := NewGiteaOAuth2Provider(&settings.Config{
OAuth2GiteaBaseUrl: "https://example.com/",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
datasource = NewGiteaOAuth2Provider("https://example.com")
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
provider, err = NewGiteaOAuth2Provider(&settings.Config{
OAuth2GiteaBaseUrl: "https://example.com",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
provider, err = NewGiteaOAuth2Provider(&settings.Config{}, "")
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
}
func TestGiteaOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
@@ -1,12 +1,16 @@
package oauth2
package github
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type githubUserProfileResponse struct {
@@ -17,7 +21,7 @@ type githubUserProfileResponse struct {
// GithubOAuth2DataSource represents Github OAuth 2.0 data source
type GithubOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
}
// GetAuthUrl returns the authentication url of the Github data source
@@ -51,7 +55,7 @@ func (p *GithubOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &githubUserProfileResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -65,7 +69,7 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.Login,
Email: userInfoResp.Email,
NickName: userInfoResp.Name,
@@ -73,8 +77,6 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
}
// NewGithubOAuth2Provider creates a new Github OAuth 2.0 provider instance
func NewGithubOAuth2Provider() OAuth2Provider {
return &CommonOAuth2Provider{
dataSource: &GithubOAuth2DataSource{},
}
func NewGithubOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
return common.NewCommonOAuth2Provider(config, redirectUrl, &GithubOAuth2DataSource{}), nil
}
@@ -1,4 +1,4 @@
package oauth2
package github
import (
"testing"
@@ -1,12 +1,16 @@
package oauth2
package nextcloud
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type nextcloudUserInfoResponse struct {
@@ -25,7 +29,7 @@ type nextcloudUserInfoResponse struct {
// NextcloudOAuth2DataSource represents Nextcloud OAuth 2.0 data source
type NextcloudOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
baseUrl string
}
@@ -61,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &nextcloudUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -85,7 +89,7 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.OCS.Data.ID,
Email: userInfoResp.OCS.Data.Email,
NickName: userInfoResp.OCS.Data.DisplayName,
@@ -93,14 +97,18 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
}
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider {
func NewNextcloudOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
if len(config.OAuth2NextcloudBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := config.OAuth2NextcloudBaseUrl
if baseUrl[len(baseUrl)-1] != '/' {
baseUrl += "/"
}
return &CommonOAuth2Provider{
dataSource: &NextcloudOAuth2DataSource{
baseUrl: baseUrl,
},
}
return common.NewCommonOAuth2Provider(config, redirectUrl, &NextcloudOAuth2DataSource{
baseUrl: baseUrl,
}), nil
}
@@ -1,22 +1,33 @@
package oauth2
package nextcloud
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
func TestNewNextcloudOAuth2Provider(t *testing.T) {
datasource := NewNextcloudOAuth2Provider("https://example.com/")
assert.Equal(t, "https://example.com/apps/oauth2/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
provider, err := NewNextcloudOAuth2Provider(&settings.Config{
OAuth2NextcloudBaseUrl: "https://example.com/",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
datasource = NewNextcloudOAuth2Provider("https://example.com/index.php")
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
provider, err = NewNextcloudOAuth2Provider(&settings.Config{
OAuth2NextcloudBaseUrl: "https://example.com/index.php",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
provider, err = NewNextcloudOAuth2Provider(&settings.Config{}, "")
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
}
func TestNextcloudOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
@@ -0,0 +1,20 @@
package provider
import (
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetOAuth2AuthUrl returns the authentication url of the provider
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
// GetOAuth2Token returns the OAuth 2.0 token of the provider
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
}
@@ -0,0 +1,173 @@
package oidc
import (
"strings"
"golang.org/x/oauth2"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
// OIDCClaims represents OIDC claims
type OIDCClaims struct {
PreferredUserName string `json:"preferred_username"`
Name string `json:"name"`
Email string `json:"email"`
}
// OIDCProvider represents OIDC provider
type OIDCProvider struct {
provider.OAuth2Provider
oidcBaseUrl string
redirectUrl string
oauth2ClientID string
oauth2ClientSecret string
oauth2Config *oauth2.Config
oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier
}
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return "", err
}
return oauth2Config.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
}
// GetUserInfo returns the user info by the OIDC provider
func (p *OIDCProvider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
_, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
log.Errorf(c, "[oidc_provider.GetUserInfo] missing \"id_token\" field in oauth 2.0 token")
return nil, errs.ErrInvalidOAuth2Token
}
idToken, err := p.oidcVerifier.Verify(c, rawIDToken)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to verify \"id_token\" field in oauth 2.0 token, because %s", err.Error())
return nil, errs.ErrInvalidOAuth2Token
}
var claims OIDCClaims
err = idToken.Claims(&claims)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse claims in oauth 2.0 token, because %s", err.Error())
return nil, errs.ErrInvalidOAuth2Token
}
userName := claims.PreferredUserName
email := claims.Email
nickName := claims.Name
if userName == "" || email == "" || nickName == "" {
userInfo, err := p.oidcProvider.UserInfo(c, oauth2.StaticTokenSource(oauth2Token))
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to get user info, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
err = userInfo.Claims(&claims)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse user info, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
if userName == "" {
userName = claims.PreferredUserName
}
if email == "" {
email = claims.Email
}
if nickName == "" {
nickName = claims.Name
}
}
return &data.OAuth2UserInfo{
UserName: userName,
Email: email,
NickName: nickName,
}, nil
}
func (p *OIDCProvider) getOAuth2Config(c core.Context) (*oauth2.Config, error) {
if p.oauth2Config != nil {
return p.oauth2Config, nil
}
oidcProvider, err := oidc.NewProvider(c, p.oidcBaseUrl)
if err != nil {
log.Errorf(c, "[oidc_provider.getOAuth2Config] failed to create oidc provider, because %s", err.Error())
return nil, err
}
oidcVerifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.oauth2ClientID})
oauth2Config := &oauth2.Config{
ClientID: p.oauth2ClientID,
ClientSecret: p.oauth2ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: p.redirectUrl,
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
p.oauth2Config = oauth2Config
p.oidcProvider = oidcProvider
p.oidcVerifier = oidcVerifier
return oauth2Config, nil
}
// NewOIDCProvider returns a new OIDC provider
func NewOIDCProvider(config *settings.Config, redirectUrl string) (*OIDCProvider, error) {
if len(config.OAuth2OIDCProviderBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := strings.TrimSuffix(config.OAuth2OIDCProviderBaseUrl, "/")
return &OIDCProvider{
oidcBaseUrl: baseUrl,
redirectUrl: redirectUrl,
oauth2ClientID: config.OAuth2ClientID,
oauth2ClientSecret: config.OAuth2ClientSecret,
oauth2Config: nil,
}, nil
}
+3 -1
View File
@@ -5,6 +5,7 @@ type UserExternalAuthType string
// User External Auth Type
const (
USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC UserExternalAuthType = "oidc"
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD UserExternalAuthType = "nextcloud"
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA UserExternalAuthType = "gitea"
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB UserExternalAuthType = "github"
@@ -13,7 +14,8 @@ const (
// IsValid checks if the UserExternalAuthType is valid
func (t UserExternalAuthType) IsValid() bool {
switch t {
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD,
case USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC,
USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD,
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA,
USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB:
return true
+4 -2
View File
@@ -8,8 +8,10 @@ type OAuth2LoginRequest struct {
// OAuth2CallbackRequest represents all parameters of OAuth 2.0 callback request
type OAuth2CallbackRequest struct {
State string `form:"state"`
Code string `form:"code"`
State string `form:"state"`
Code string `form:"code"`
Error string `form:"error"`
ErrorDescription string `form:"error_description"`
}
// OAuth2CallbackLoginRequest represents all parameters of OAuth 2.0 callback login request
+24 -17
View File
@@ -93,6 +93,7 @@ const (
// OAuth 2.0 provider types
const (
OAuth2ProviderOIDC string = "oidc"
OAuth2ProviderNextcloud string = "nextcloud"
OAuth2ProviderGitea string = "gitea"
OAuth2ProviderGithub string = "github"
@@ -360,23 +361,25 @@ type Config struct {
MaxFailuresPerUserPerMinute uint32
// Auth
EnableInternalAuth bool
EnableOAuth2Login bool
EnableTwoFactor bool
EnableUserForgetPassword bool
ForgetPasswordRequireVerifyEmail bool
OAuth2ClientID string
OAuth2ClientSecret string
OAuth2UserIdentifier string
OAuth2AutoRegister bool
OAuth2Provider string
OAuth2StateExpiredTime uint32
OAuth2StateExpiredTimeDuration time.Duration
OAuth2RequestTimeout uint32
OAuth2Proxy string
OAuth2SkipTLSVerify bool
OAuth2NextcloudBaseUrl string
OAuth2GiteaBaseUrl string
EnableInternalAuth bool
EnableOAuth2Login bool
EnableTwoFactor bool
EnableUserForgetPassword bool
ForgetPasswordRequireVerifyEmail bool
OAuth2ClientID string
OAuth2ClientSecret string
OAuth2UserIdentifier string
OAuth2AutoRegister bool
OAuth2Provider string
OAuth2StateExpiredTime uint32
OAuth2StateExpiredTimeDuration time.Duration
OAuth2RequestTimeout uint32
OAuth2Proxy string
OAuth2SkipTLSVerify bool
OAuth2OIDCProviderBaseUrl string
OAuth2OIDCCustomDisplayNameConfig MultiLanguageContentConfig
OAuth2NextcloudBaseUrl string
OAuth2GiteaBaseUrl string
// User
EnableUserRegister bool
@@ -1003,6 +1006,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
if oauth2Provider == "" {
config.OAuth2Provider = ""
} else if oauth2Provider == OAuth2ProviderOIDC {
config.OAuth2Provider = OAuth2ProviderOIDC
} else if oauth2Provider == OAuth2ProviderNextcloud {
config.OAuth2Provider = OAuth2ProviderNextcloud
} else if oauth2Provider == OAuth2ProviderGitea {
@@ -1025,6 +1030,8 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
config.OAuth2RequestTimeout = getConfigItemUint32Value(configFile, sectionName, "oauth2_request_timeout", defaultOAuth2RequestTimeout)
config.OAuth2SkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "oauth2_skip_tls_verify", false)
config.OAuth2OIDCProviderBaseUrl = getConfigItemStringValue(configFile, sectionName, "oidc_provider_base_url")
config.OAuth2OIDCCustomDisplayNameConfig = getMultiLanguageContentConfig(configFile, sectionName, "enable_oidc_display_name", "oidc_custom_display_name")
config.OAuth2NextcloudBaseUrl = getConfigItemStringValue(configFile, sectionName, "nextcloud_base_url")
config.OAuth2GiteaBaseUrl = getConfigItemStringValue(configFile, sectionName, "gitea_base_url")