add option to control whether PKCE is used in OAuth 2.0 authentication process

This commit is contained in:
MaysWind
2025-10-24 23:03:57 +08:00
parent beea6fe733
commit a17a2cc377
12 changed files with 61 additions and 51 deletions
+16 -7
View File
@@ -1,8 +1,6 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
@@ -22,6 +20,7 @@ import (
// OAuth2Container contains the current OAuth 2.0 authentication provider
type OAuth2Container struct {
current provider.OAuth2Provider
usePKCE bool
oauth2HttpClient *http.Client
externalUserAuthType core.UserExternalAuthType
}
@@ -67,6 +66,7 @@ func InitializeOAuth2Provider(config *settings.Config) error {
}
Container.current = oauth2Provider
Container.usePKCE = config.OAuth2UsePKCE
Container.oauth2HttpClient = utils.NewHttpClient(config.OAuth2RequestTimeout, config.OAuth2Proxy, config.OAuth2SkipTLSVerify, settings.GetUserAgent())
Container.externalUserAuthType = externalUserAuthType
@@ -79,10 +79,13 @@ func GetOAuth2AuthUrl(c core.Context, state string, verifier string) (string, er
return "", errs.ErrOAuth2NotEnabled
}
sha256Hash := sha256.New()
sha256Hash.Write([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sha256Hash.Sum(nil))
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, challenge)
var opts []oauth2.AuthCodeOption
if Container.usePKCE {
opts = append(opts, oauth2.S256ChallengeOption(verifier))
}
return Container.current.GetOAuth2AuthUrl(wrapOAuth2Context(c, Container.oauth2HttpClient), state, opts...)
}
// GetOAuth2Token exchanges the authorization code for an OAuth 2.0 token
@@ -91,7 +94,13 @@ func GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token
return nil, errs.ErrOAuth2NotEnabled
}
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, verifier)
var opts []oauth2.AuthCodeOption
if Container.usePKCE {
opts = append(opts, oauth2.VerifierOption(verifier))
}
return Container.current.GetOAuth2Token(wrapOAuth2Context(c, Container.oauth2HttpClient), code, opts...)
}
// GetOAuth2UserInfo retrieves the OAuth 2.0 user info using the provided OAuth 2.0 token
@@ -36,17 +36,17 @@ type CommonOAuth2DataSource interface {
GetScopes() []string
// ParseUserInfo returns the user info by parsing the response body
ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error)
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
}
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
return p.oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the common OAuth 2.0 provider
@@ -76,7 +76,7 @@ func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
return nil, errs.ErrFailedToRequestRemoteApi
}
return p.dataSource.ParseUserInfo(c, body, oauth2Client)
return p.dataSource.ParseUserInfo(c, body)
}
// GetDataSource returns the data source of the common OAuth 2.0 provider
@@ -56,7 +56,7 @@ func (s *GiteaOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error) {
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &giteaUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -1,7 +1,6 @@
package gitea
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
@@ -48,7 +47,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
"full_name": "User",
"email": "user1@example.com"
}`
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Nil(t, err)
assert.Equal(t, "user1", info.UserName)
@@ -58,7 +57,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
func TestGiteaOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
datasource := &GiteaOAuth2DataSource{}
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -66,7 +65,7 @@ func TestGiteaOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
func TestGiteaOAuth2Datasource_ParseUserInfo_EmptyLogin(t *testing.T) {
datasource := &GiteaOAuth2DataSource{}
responseContent := `{"login": ""}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -41,13 +41,13 @@ type GithubOAuth2Provider struct {
}
// GetOAuth2AuthUrl returns the authentication url of the GitHub OAuth 2.0 provider
func (p *GithubOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
return p.oauth2Config.AuthCodeURL(state), nil
func (p *GithubOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
return p.oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the GitHub OAuth 2.0 provider
func (p *GithubOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code)
func (p *GithubOAuth2Provider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the Github OAuth 2.0 provider
@@ -56,7 +56,7 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
req, err := p.buildAPIRequest(githubUserProfileApiUrl)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info request, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info request, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -64,17 +64,17 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
resp, err := oauth2Client.Do(req)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info response, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info response, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
log.Debugf(c, "[github_oauth2_datasource_test.GetUserInfo] user profile response is %s", body)
log.Debugf(c, "[github_oauth2_provider.GetUserInfo] user profile response is %s", body)
if resp.StatusCode != 200 {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -88,24 +88,24 @@ func (p *GithubOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.T
req, err = p.buildAPIRequest(githubUserEmailApiUrl)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails request, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails request, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
resp, err = oauth2Client.Do(req)
if err != nil {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails response, because %s", err.Error())
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails response, because %s", err.Error())
return nil, errs.ErrFailedToRequestRemoteApi
}
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
log.Debugf(c, "[github_oauth2_datasource_test.GetUserInfo] user emails response is %s", body)
log.Debugf(c, "[github_oauth2_provider.GetUserInfo] user emails response is %s", body)
if resp.StatusCode != 200 {
log.Errorf(c, "[github_oauth2_datasource_test.GetUserInfo] failed to get user emails response, because response code is %d", resp.StatusCode)
log.Errorf(c, "[github_oauth2_provider.GetUserInfo] failed to get user emails response, because response code is %d", resp.StatusCode)
return nil, errs.ErrFailedToRequestRemoteApi
}
@@ -127,12 +127,12 @@ func (p *GithubOAuth2Provider) parseUserProfile(c core.Context, body []byte) (*g
err := json.Unmarshal(body, &userProfileResp)
if err != nil {
log.Warnf(c, "[github_oauth2_datasource.parseUserProfile] failed to parse user profile response body, because %s", err.Error())
log.Warnf(c, "[github_oauth2_provider.parseUserProfile] failed to parse user profile response body, because %s", err.Error())
return nil, errs.ErrCannotRetrieveUserInfo
}
if userProfileResp.Login == "" {
log.Warnf(c, "[github_oauth2_datasource.parseUserProfile] invalid user profile response body")
log.Warnf(c, "[github_oauth2_provider.parseUserProfile] invalid user profile response body")
return nil, errs.ErrCannotRetrieveUserInfo
}
@@ -144,7 +144,7 @@ func (p *GithubOAuth2Provider) parsePrimaryEmail(c core.Context, body []byte) (s
err := json.Unmarshal(body, &emailsResp)
if err != nil {
log.Warnf(c, "[github_oauth2_datasource.parsePrimaryEmail] failed to parse user emails response body, because %s", err.Error())
log.Warnf(c, "[github_oauth2_provider.parsePrimaryEmail] failed to parse user emails response body, because %s", err.Error())
return "", errs.ErrCannotRetrieveUserInfo
}
@@ -65,7 +65,7 @@ func (s *NextcloudOAuth2DataSource) GetScopes() []string {
}
// ParseUserInfo returns the user info by parsing the response body
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte, oauth2Client *http.Client) (*data.OAuth2UserInfo, error) {
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
userInfoResp := &nextcloudUserInfoResponse{}
err := json.Unmarshal(body, &userInfoResp)
@@ -1,7 +1,6 @@
package nextcloud
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
@@ -57,7 +56,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
}
}
}`
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Nil(t, err)
assert.Equal(t, "user1", info.UserName)
@@ -67,7 +66,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
func TestNextcloudOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
datasource := &NextcloudOAuth2DataSource{}
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -75,7 +74,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
func TestNextcloudOAuth2Datasource_ParseUserInfo_MissingFields(t *testing.T) {
datasource := &NextcloudOAuth2DataSource{}
responseContent := `{"ocs": {}}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -91,7 +90,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_Non200StatusCode(t *testing.T)
"data": {}
}
}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
@@ -111,7 +110,7 @@ func TestNextcloudOAuth2Datasource_ParseUserInfo_EmptyID(t *testing.T) {
}
}
}`
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent), &http.Client{})
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
}
+2 -2
View File
@@ -10,10 +10,10 @@ import (
// OAuth2Provider defines the structure of OAuth 2.0 provider
type OAuth2Provider interface {
// GetOAuth2AuthUrl returns the authentication url of the provider
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error)
// GetOAuth2Token returns the OAuth 2.0 token of the provider
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// GetUserInfo returns the user info
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
@@ -35,27 +35,25 @@ type OIDCProvider struct {
}
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, opts ...oauth2.AuthCodeOption) (string, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return "", err
}
return oauth2Config.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
return oauth2Config.AuthCodeURL(state, opts...), nil
}
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
oauth2Config, err := p.getOAuth2Config(c)
if err != nil {
return nil, err
}
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
return oauth2Config.Exchange(c, code, opts...)
}
// GetUserInfo returns the user info by the OIDC provider
+2
View File
@@ -368,6 +368,7 @@ type Config struct {
ForgetPasswordRequireVerifyEmail bool
OAuth2ClientID string
OAuth2ClientSecret string
OAuth2UsePKCE bool
OAuth2UserIdentifier string
OAuth2AutoRegister bool
OAuth2Provider string
@@ -989,6 +990,7 @@ func loadAuthConfiguration(config *Config, configFile *ini.File, sectionName str
config.ForgetPasswordRequireVerifyEmail = getConfigItemBoolValue(configFile, sectionName, "forget_password_require_email_verify", false)
config.OAuth2ClientID = getConfigItemStringValue(configFile, sectionName, "oauth2_client_id")
config.OAuth2ClientSecret = getConfigItemStringValue(configFile, sectionName, "oauth2_client_secret")
config.OAuth2UsePKCE = getConfigItemBoolValue(configFile, sectionName, "oauth2_use_pkce", false)
oauth2UserIdentifier := getConfigItemStringValue(configFile, sectionName, "oauth2_user_identifier")