support OIDC authentication (#242)

This commit is contained in:
MaysWind
2025-10-24 01:44:55 +08:00
parent d3ab2b94b7
commit 85b05f9e7e
24 changed files with 490 additions and 202 deletions
@@ -1,4 +1,4 @@
package oauth2
package data
import "github.com/mayswind/ezbookkeeping/pkg/core"
+36 -34
View File
@@ -1,10 +1,18 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/gitea"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/github"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/nextcloud"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/oidc"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
@@ -13,8 +21,7 @@ import (
// OAuth2Container contains the current OAuth 2.0 authentication provider
type OAuth2Container struct {
oauth2Config *oauth2.Config
oauth2Provider OAuth2Provider
current provider.OAuth2Provider
oauth2HttpClient *http.Client
externalUserAuthType core.UserExternalAuthType
}
@@ -34,24 +41,32 @@ func InitializeOAuth2Provider(config *settings.Config) error {
return errs.ErrInvalidOAuth2Config
}
var oauth2Provider OAuth2Provider
var err error
var oauth2Provider provider.OAuth2Provider
var externalUserAuthType core.UserExternalAuthType
redirectUrl := config.RootUrl + "oauth2/callback"
if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
oauth2Provider = NewNextcloudOAuth2Provider(config.OAuth2NextcloudBaseUrl)
if config.OAuth2Provider == settings.OAuth2ProviderOIDC {
oauth2Provider, err = oidc.NewOIDCProvider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_OIDC
} else if config.OAuth2Provider == settings.OAuth2ProviderNextcloud {
oauth2Provider, err = nextcloud.NewNextcloudOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_NEXTCLOUD
} else if config.OAuth2Provider == settings.OAuth2ProviderGitea {
oauth2Provider = NewGiteaOAuth2Provider(config.OAuth2GiteaBaseUrl)
oauth2Provider, err = gitea.NewGiteaOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITEA
} else if config.OAuth2Provider == settings.OAuth2ProviderGithub {
oauth2Provider = NewGithubOAuth2Provider()
oauth2Provider, err = github.NewGithubOAuth2Provider(config, redirectUrl)
externalUserAuthType = core.USER_EXTERNAL_AUTH_TYPE_OAUTH2_GITHUB
} else {
return errs.ErrInvalidOAuth2Provider
}
Container.oauth2Config = buildOAuth2Config(config, oauth2Provider)
Container.oauth2Provider = oauth2Provider
if err != nil {
return err
}
Container.current = oauth2Provider
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
Container.externalUserAuthType = externalUserAuthType
@@ -59,26 +74,29 @@ func InitializeOAuth2Provider(config *settings.Config) error {
}
// GetOAuth2AuthUrl returns the OAuth 2.0 authentication url
func GetOAuth2AuthUrl(c core.Context, state string) (string, error) {
if Container.oauth2Config == nil {
func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, error) {
if Container.current == nil {
return "", errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.AuthCodeURL(state), nil
sha256Hash := sha256.New()
sha256Hash.Write([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil))
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge)
}
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
func GetOAuth2Token(c core.Context, code string) (*oauth2.Token, error) {
if Container.oauth2Config == nil || Container.oauth2HttpClient == nil {
func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
if Container.current == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
return Container.oauth2Config.Exchange(wrapOAuth2Context(c, Container.oauth2HttpClient), code)
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier)
}
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, error) {
if Container.oauth2Config == nil || Container.oauth2Provider == nil || Container.oauth2HttpClient == nil {
func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*data.OAuth2UserInfo, error) {
if Container.current == nil || Container.oauth2HttpClient == nil {
return nil, errs.ErrOAuth2NotEnabled
}
@@ -86,26 +104,10 @@ func GetOAuth2UserInfo(c core.Context, token *oauth2.Token) (*OAuth2UserInfo, er
return nil, errs.ErrInvalidOAuth2Token
}
oauth2Client := oauth2.NewClient(wrapOAuth2Context(c, Container.oauth2HttpClient), oauth2.StaticTokenSource(token))
return Container.oauth2Provider.GetUserInfo(c, oauth2Client)
return Container.current.GetUserInfo(wrapOAuth2Context(c, Container.oauth2HttpClient), token)
}
// GetExternalUserAuthType returns the external user auth type of the current OAuth 2.0 provider
func GetExternalUserAuthType() core.UserExternalAuthType {
return Container.externalUserAuthType
}
func buildOAuth2Config(config *settings.Config, oauth2Provider OAuth2Provider) *oauth2.Config {
redirectURL := config.RootUrl + "oauth2/callback"
return &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: oauth2Provider.GetAuthUrl(),
TokenURL: oauth2Provider.GetTokenUrl(),
},
RedirectURL: redirectURL,
Scopes: oauth2Provider.GetScopes(),
}
}
-22
View File
@@ -1,22 +0,0 @@
package oauth2
import (
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetAuthUrl returns the authentication url of the provider
GetAuthUrl() string
// GetTokenUrl returns the token url of the provider
GetTokenUrl() string
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error)
// GetScopes returns the scopes required by the provider
GetScopes() []string
}
@@ -1,18 +1,24 @@
package oauth2
package common
import (
"io"
"net/http"
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
// CommonOAuth2Provider represents common OAuth 2.0 provider
type CommonOAuth2Provider struct {
OAuth2Provider
dataSource CommonOAuth2DataSource
provider.OAuth2Provider
oauth2Config *oauth2.Config
dataSource CommonOAuth2DataSource
}
// CommonOAuth2DataSource defines the structure of OAuth 2.0 data source
@@ -30,26 +36,21 @@ type CommonOAuth2DataSource interface {
GetScopes() []string
// ParseUserInfo returns the user info by parsing the response body
ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error)
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
}
// GetAuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetAuthUrl() string {
return p.dataSource.GetAuthUrl()
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
}
// GetTokenUrl returns the token url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetTokenUrl() string {
return p.dataSource.GetTokenUrl()
}
// GetScopes returns the scopes required by the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetScopes() []string {
return p.dataSource.GetScopes()
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
}
// GetUserInfo returns the user info by the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Client) (*OAuth2UserInfo, error) {
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
req, err := p.dataSource.GetUserInfoRequest()
if err != nil {
@@ -57,6 +58,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
return nil, errs.ErrFailedToRequestRemoteApi
}
oauth2Client := oauth2.NewClient(c, oauth2.StaticTokenSource(oauth2Token))
resp, err := oauth2Client.Do(req)
if err != nil {
@@ -76,3 +78,27 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Client *http.Cl
return p.dataSource.ParseUserInfo(c, body)
}
// GetDataSource returns the data source of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetDataSource() CommonOAuth2DataSource {
return p.dataSource
}
// NewCommonOAuth2Provider returns a new common OAuth 2.0 provider
func NewCommonOAuth2Provider(config *settings.Config, redirectUrl string, dataSource CommonOAuth2DataSource) *CommonOAuth2Provider {
oauth2Config := &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: dataSource.GetAuthUrl(),
TokenURL: dataSource.GetTokenUrl(),
},
RedirectURL: redirectUrl,
Scopes: dataSource.GetScopes(),
}
return &CommonOAuth2Provider{
oauth2Config: oauth2Config,
dataSource: dataSource,
}
}
@@ -1,12 +1,16 @@
package oauth2
package gitea
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type giteaUserInfoResponse struct {
@@ -17,7 +21,7 @@ type giteaUserInfoResponse struct {
// GiteaOAuth2DataSource represents Gitea OAuth 2.0 data source
type GiteaOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
baseUrl string
}
@@ -52,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &giteaUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -66,7 +70,7 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.Login,
Email: userInfoResp.Email,
NickName: userInfoResp.FullName,
@@ -74,14 +78,18 @@ func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAu
}
// NewGiteaOAuth2Provider creates a new Gitea OAuth 2.0 provider instance
func NewGiteaOAuth2Provider(baseUrl string) OAuth2Provider {
func NewGiteaOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
if len(config.OAuth2GiteaBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := config.OAuth2GiteaBaseUrl
if baseUrl[len(baseUrl)-1] != '/' {
baseUrl += "/"
}
return &CommonOAuth2Provider{
dataSource: &GiteaOAuth2DataSource{
baseUrl: baseUrl,
},
}
return common.NewCommonOAuth2Provider(config, redirectUrl, &GiteaOAuth2DataSource{
baseUrl: baseUrl,
}), nil
}
@@ -1,22 +1,33 @@
package oauth2
package gitea
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
func TestNewGiteaOAuth2Provider(t *testing.T) {
datasource := NewGiteaOAuth2Provider("https://example.com/")
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
provider, err := NewGiteaOAuth2Provider(&settings.Config{
OAuth2GiteaBaseUrl: "https://example.com/",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
datasource = NewGiteaOAuth2Provider("https://example.com")
assert.Equal(t, "https://example.com/login/oauth/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", datasource.GetTokenUrl())
provider, err = NewGiteaOAuth2Provider(&settings.Config{
OAuth2GiteaBaseUrl: "https://example.com",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
provider, err = NewGiteaOAuth2Provider(&settings.Config{}, "")
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
}
func TestGiteaOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
@@ -1,12 +1,16 @@
package oauth2
package github
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type githubUserProfileResponse struct {
@@ -17,7 +21,7 @@ type githubUserProfileResponse struct {
// GithubOAuth2DataSource represents Github OAuth 2.0 data source
type GithubOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
}
// GetAuthUrl returns the authentication url of the Github data source
@@ -51,7 +55,7 @@ func (p *GithubOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &githubUserProfileResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -65,7 +69,7 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.Login,
Email: userInfoResp.Email,
NickName: userInfoResp.Name,
@@ -73,8 +77,6 @@ func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OA
}
// NewGithubOAuth2Provider creates a new Github OAuth 2.0 provider instance
func NewGithubOAuth2Provider() OAuth2Provider {
return &CommonOAuth2Provider{
dataSource: &GithubOAuth2DataSource{},
}
func NewGithubOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
return common.NewCommonOAuth2Provider(config, redirectUrl, &GithubOAuth2DataSource{}), nil
}
@@ -1,4 +1,4 @@
package oauth2
package github
import (
"testing"
@@ -1,12 +1,16 @@
package oauth2
package nextcloud
import (
"encoding/json"
"net/http"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
type nextcloudUserInfoResponse struct {
@@ -25,7 +29,7 @@ type nextcloudUserInfoResponse struct {
// NextcloudOAuth2DataSource represents Nextcloud OAuth 2.0 data source
type NextcloudOAuth2DataSource struct {
CommonOAuth2DataSource
common.CommonOAuth2DataSource
baseUrl string
}
@@ -61,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*OAuth2UserInfo, error) {
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &nextcloudUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -85,7 +89,7 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
return nil, errs.ErrCannotRetrieveUserInfo
}
return &OAuth2UserInfo{
return &data.OAuth2UserInfo{
UserName: userInfoResp.OCS.Data.ID,
Email: userInfoResp.OCS.Data.Email,
NickName: userInfoResp.OCS.Data.DisplayName,
@@ -93,14 +97,18 @@ func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (
}
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
func NewNextcloudOAuth2Provider(baseUrl string) OAuth2Provider {
func NewNextcloudOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
if len(config.OAuth2NextcloudBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := config.OAuth2NextcloudBaseUrl
if baseUrl[len(baseUrl)-1] != '/' {
baseUrl += "/"
}
return &CommonOAuth2Provider{
dataSource: &NextcloudOAuth2DataSource{
baseUrl: baseUrl,
},
}
return common.NewCommonOAuth2Provider(config, redirectUrl, &NextcloudOAuth2DataSource{
baseUrl: baseUrl,
}), nil
}
@@ -1,22 +1,33 @@
package oauth2
package nextcloud
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
func TestNewNextcloudOAuth2Provider(t *testing.T) {
datasource := NewNextcloudOAuth2Provider("https://example.com/")
assert.Equal(t, "https://example.com/apps/oauth2/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
provider, err := NewNextcloudOAuth2Provider(&settings.Config{
OAuth2NextcloudBaseUrl: "https://example.com/",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
datasource = NewNextcloudOAuth2Provider("https://example.com/index.php")
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", datasource.GetAuthUrl())
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", datasource.GetTokenUrl())
provider, err = NewNextcloudOAuth2Provider(&settings.Config{
OAuth2NextcloudBaseUrl: "https://example.com/index.php",
}, "")
assert.Nil(t, err)
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
provider, err = NewNextcloudOAuth2Provider(&settings.Config{}, "")
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
}
func TestNextcloudOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
@@ -0,0 +1,20 @@
package provider
import (
"golang.org/x/oauth2"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/core"
)
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetOAuth2AuthUrl returns the authentication url of the provider
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
// GetOAuth2Token returns the OAuth 2.0 token of the provider
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
}
@@ -0,0 +1,173 @@
package oidc
import (
"strings"
"golang.org/x/oauth2"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)
// OIDCClaims represents OIDC claims
type OIDCClaims struct {
PreferredUserName string `json:"preferred_username"`
Name string `json:"name"`
Email string `json:"email"`
}
// OIDCProvider represents OIDC provider
type OIDCProvider struct {
provider.OAuth2Provider
oidcBaseUrl string
redirectUrl string
oauth2ClientID string
oauth2ClientSecret string
oauth2Config *oauth2.Config
oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier
}
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return "", err
}
return oauth2Config.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
}
// GetUserInfo returns the user info by the OIDC provider
func (p *OIDCProvider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
_, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
log.Errorf(c, "[oidc_provider.GetUserInfo] missing \"id_token\" field in oauth 2.0 token")
return nil, errs.ErrInvalidOAuth2Token
}
idToken, err := p.oidcVerifier.Verify(c, rawIDToken)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to verify \"id_token\" field in oauth 2.0 token, because %s", err.Error())
return nil, errs.ErrInvalidOAuth2Token
}
var claims OIDCClaims
err = idToken.Claims(&claims)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse claims in oauth 2.0 token, because %s", err.Error())
return nil, errs.ErrInvalidOAuth2Token
}
userName := claims.PreferredUserName
email := claims.Email
nickName := claims.Name
if userName == "" || email == "" || nickName == "" {
userInfo, err := p.oidcProvider.UserInfo(c, oauth2.StaticTokenSource(oauth2Token))
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to get user info, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
err = userInfo.Claims(&claims)
if err != nil {
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse user info, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
if userName == "" {
userName = claims.PreferredUserName
}
if email == "" {
email = claims.Email
}
if nickName == "" {
nickName = claims.Name
}
}
return &data.OAuth2UserInfo{
UserName: userName,
Email: email,
NickName: nickName,
}, nil
}
func (p *OIDCProvider) getOAuth2Config(c core.Context) (*oauth2.Config, error) {
if p.oauth2Config != nil {
return p.oauth2Config, nil
}
oidcProvider, err := oidc.NewProvider(c, p.oidcBaseUrl)
if err != nil {
log.Errorf(c, "[oidc_provider.getOAuth2Config] failed to create oidc provider, because %s", err.Error())
return nil, err
}
oidcVerifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.oauth2ClientID})
oauth2Config := &oauth2.Config{
ClientID: p.oauth2ClientID,
ClientSecret: p.oauth2ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: p.redirectUrl,
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
p.oauth2Config = oauth2Config
p.oidcProvider = oidcProvider
p.oidcVerifier = oidcVerifier
return oauth2Config, nil
}
// NewOIDCProvider returns a new OIDC provider
func NewOIDCProvider(config *settings.Config, redirectUrl string) (*OIDCProvider, error) {
if len(config.OAuth2OIDCProviderBaseUrl) < 1 {
return nil, errs.ErrInvalidOAuth2Config
}
baseUrl := strings.TrimSuffix(config.OAuth2OIDCProviderBaseUrl, "/")
return &OIDCProvider{
oidcBaseUrl: baseUrl,
redirectUrl: redirectUrl,
oauth2ClientID: config.OAuth2ClientID,
oauth2ClientSecret: config.OAuth2ClientSecret,
oauth2Config: nil,
}, nil
}