add option to control whether PKCE is used in OAuth 2.0 authentication process

This commit is contained in:
MaysWind
2025-10-24 23:03:57 +08:00
parent beea6fe733
commit a17a2cc377
12 changed files with 61 additions and 51 deletions
+6 -3
View File
@@ -297,6 +297,9 @@ oauth2_client_secret =
# For "oauth2" authentication only, OAuth 2.0 provider user identifier claim name, supports "email" and "username", default is "email"
oauth2_user_identifier = email
# For "oauth2" authentication only, set to true to use PKCE
oauth2_use_pkce = false
# For "oauth2" authentication only, if the user returned by OAuth 2.0 is not registered, automatically create a new user (requires "enable_register" to be set to true)
oauth2_auto_register = true
@@ -316,11 +319,11 @@ oauth2_skip_tls_verify = false
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, OIDC provider base url. Make sure the ".well-known" directory is available under this path. For example, if it's set to "https://auth.example.com/", the discovery URL should be "https://auth.example.com/.well-known/openid-configuration".
oidc_provider_base_url =
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, set to true to replace the text in the "Log in with Connect ID" button with the below custom display name
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, set to true to replace the text "Connect ID" in the "Log in with Connect ID" button with the below custom provider name
enable_oidc_display_name = false
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, the custom display name to replace the text in the "Log in with Connect ID" button, it supports multi-language configuration
# Add an underscore and a language tag after the setting key to configure the display name in that language, the same below
# For "oauth2" authentication and "oidc" OAuth 2.0 provider only, the custom provider name to replace the text in the "Log in with Connect ID" button, it supports multi-language configuration
# Add an underscore and a language tag after the setting key to configure the display name in that language
# For example, oidc_custom_display_name_zh_hans means the display name in Chinese (Simplified)
oidc_custom_display_name =
+16 -7
View File
@@ -1,8 +1,6 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
@@ -22,6 +20,7 @@ import (
// OAuth2Container contains the current OAuth 2.0 authentication provider
type OAuth2Container struct {
current provider.OAuth2Provider
usePKCE bool
oauth2HttpClient *http.Client
externalUserAuthType core.UserExternalAuthType
}
@@ -67,6 +66,7 @@ func InitializeOAuth2Provider(config *settings.Config) error {
}
Container.current = oauth2Provider
Container.usePKCE = config.OAuth2UsePKCE
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
Container.externalUserAuthType = externalUserAuthType
@@ -79,10 +79,13 @@ func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, er
return "", errs.ErrOAuth2NotEnabled
}
sha256Hash := sha256.New()
sha256Hash.Write([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil))
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge)
var opts []oauth2.AuthCodeOption
if Container.usePKCE {
opts = append(opts, oauth2.S256ChallengeOption(verifier))
}
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, opts...)
}
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
@@ -91,7 +94,13 @@ func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token
return nil, errs.ErrOAuth2NotEnabled
}
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier)
var opts []oauth2.AuthCodeOption
if Container.usePKCE {
opts = append(opts, oauth2.VerifierOption(verifier))
}
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, opts...)
}
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
@@ -36,17 +36,17 @@ type CommonOAuth2DataSource interface {
GetScopes() []string
// ParseUserInfo returns the user info by parsing the response body
ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error)
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
}
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
return p.oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the common OAuth 2.0 provider
@@ -76,7 +76,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
return nil, errs.ErrFailedToRequestRemoteApi
}
return p.dataSource.ParseUserInfo(c, body, oauth2Client)
return p.dataSource.ParseUserInfo(c, body)
}
// GetDataSource returns the data source of the common OAuth 2.0 provider
@@ -56,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error) {
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &giteaUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -1,7 +1,6 @@
package gitea
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
@@ -48,7 +47,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
"full_name": "User",
"email": "user1@example.com"
}`
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Nil(t, err)
assert.Equal(t, "user1", info.UserName)
@@ -58,7 +57,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
func TestGiteaOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
datasource := &GiteaOAuth2DataSource{}
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -66,7 +65,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
func TestGiteaOAuth2Datasource_ParseUserInfo_EmptyLogin(t *testing.T) {
datasource := &GiteaOAuth2DataSource{}
responseContent := `{"login": ""}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -41,13 +41,13 @@ type GithubOAuth2Provider struct {
}
// GetOAuth2AuthUrl returns the authentication url of the GitHub OAuth 2.0 provider
func (p *GithubOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
func (p *GithubOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
return p.oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the GitHub OAuth 2.0 provider
func (p *GithubOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
func (p *GithubOAuth2Provider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the Github OAuth 2.0 provider
@@ -56,7 +56,7 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
req, err := p.buildAPIRequest(githubUserProfileApiUrl)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info request, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info request, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -64,17 +64,17 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
resp, err := oauth2Client.Do(req)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info response, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info response, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
log.Debugf(c, "[github_oauth2_datasource_test.GetUserInfo] user profile response is %s", body)
log.Debugf(c, "[github_oauth2_provider.GetUserInfo] user profile response is %s", body)
if resp.StatusCode != 200 {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -88,24 +88,24 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
req, err = p.buildAPIRequest(githubUserEmailApiUrl)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails request, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails request, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
resp, err = oauth2Client.Do(req)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails response, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails response, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
log.Debugf(c, "[github_oauth2_datasource_test.GetUserInfo] user emails response is %s", body)
log.Debugf(c, "[github_oauth2_provider.GetUserInfo] user emails response is %s", body)
if resp.StatusCode != 200 {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails response, because response code is %d", resp.StatusCode)
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails response, because response code is %d", resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -127,12 +127,12 @@ func (p *GithubOAuth2Provider) parseUserProfile(c core.Context, body []byte) (*g
err := json.Unmarshal(body, &userProfileResp)
if err != nil {
log.Warnf(c, "[github_oauth2_datasource.parseUserProfile] failed to parse user profile response body, because %s", err.Error())
log.Warnf(c, "[github_oauth2_provider.parseUserProfile] failed to parse user profile response body, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
if userProfileResp.Login == "" {
log.Warnf(c, "[github_oauth2_datasource.parseUserProfile] invalid user profile response body")
log.Warnf(c, "[github_oauth2_provider.parseUserProfile] invalid user profile response body")
return nil, errs.ErrCannotRetrieveUserInfo
}
@@ -144,7 +144,7 @@ func (p *GithubOAuth2Provider) parsePrimaryEmail(c core.Context, body []byte) (s
err := json.Unmarshal(body, &emailsResp)
if err != nil {
log.Warnf(c, "[github_oauth2_datasource.parsePrimaryEmail] failed to parse user emails response body, because %s", err.Error())
log.Warnf(c, "[github_oauth2_provider.parsePrimaryEmail] failed to parse user emails response body, because %s", err.Error())
return "", errs.ErrCannotRetrieveUserInfo
}
@@ -65,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error) {
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &nextcloudUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -1,7 +1,6 @@
package nextcloud
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
@@ -57,7 +56,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
}
}
}`
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Nil(t, err)
assert.Equal(t, "user1", info.UserName)
@@ -67,7 +66,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
func TestNextcloudOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
datasource := &NextcloudOAuth2DataSource{}
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -75,7 +74,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
func TestNextcloudOAuth2Datasource_ParseUserInfo_MissingFields(t *testing.T) {
datasource := &NextcloudOAuth2DataSource{}
responseContent := `{"ocs": {}}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -91,7 +90,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Non200StatusCode(t *testing.T)
"data": {}
}
}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -111,7 +110,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_EmptyID(t *testing.T) {
}
}
}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
+2 -2
View File
@@ -10,10 +10,10 @@ import (
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetOAuth2AuthUrl returns the authentication url of the provider
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error)
// GetOAuth2Token returns the OAuth 2.0 token of the provider
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
@@ -35,27 +35,25 @@ type OIDCProvider struct {
}
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return "", err
}
return oauth2Config.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
return oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
return oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the OIDC provider
+2
View File
@@ -368,6 +368,7 @@ type Config struct {
ForgetPasswordRequireVerifyEmail bool
OAuth2ClientID string
OAuth2ClientSecret string
OAuth2UsePKCE bool
OAuth2UserIdentifier string
OAuth2AutoRegister bool
OAuth2Provider string
@@ -989,6 +990,7 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
config.ForgetPasswordRequireVerifyEmail = getConfigItemBoolValue(configFile, sectionName, "forget_password_require_email_verify", false)
config.OAuth2ClientID = getConfigItemStringValue(configFile, sectionName, "oauth2_client_id")
config.OAuth2ClientSecret = getConfigItemStringValue(configFile, sectionName, "oauth2_client_secret")
config.OAuth2UsePKCE = getConfigItemBoolValue(configFile, sectionName, "oauth2_use_pkce", false)
oauth2UserIdentifier := getConfigItemStringValue(configFile, sectionName, "oauth2_user_identifier")