mirror of
https://github.com/mayswind/ezbookkeeping.git
synced 2026-05-19 09:14:27 +08:00
support OIDC authentication (#242)
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
// CommonOAuth2Provider represents common OAuth 2.0 provider
|
||||
type CommonOAuth2Provider struct {
|
||||
provider.OAuth2Provider
|
||||
oauth2Config *oauth2.Config
|
||||
dataSource CommonOAuth2DataSource
|
||||
}
|
||||
|
||||
// CommonOAuth2DataSource defines the structure of OAuth 2.0 data source
|
||||
type CommonOAuth2DataSource interface {
|
||||
// GetAuthUrl returns the authentication url of the data source
|
||||
GetAuthUrl() string
|
||||
|
||||
// GetTokenUrl returns the token url of the data source
|
||||
GetTokenUrl() string
|
||||
|
||||
// GetUserInfoRequest returns the user info request of the data source
|
||||
GetUserInfoRequest() (*http.Request, error)
|
||||
|
||||
// GetScopes returns the scopes required by the data source
|
||||
GetScopes() []string
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error)
|
||||
}
|
||||
|
||||
// GetOAuth2AuthUrl returns the authentication url of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
|
||||
return p.oauth2Config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
|
||||
return p.oauth2Config.Exchange(c, code)
|
||||
}
|
||||
|
||||
// GetUserInfo returns the user info by the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
|
||||
req, err := p.dataSource.GetUserInfoRequest()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[common_oauth2_provider.GetUserInfo] failed to get user info request, because %s", err.Error())
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
oauth2Client := oauth2.NewClient(c, oauth2.StaticTokenSource(oauth2Token))
|
||||
resp, err := oauth2Client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[common_oauth2_provider.GetUserInfo] failed to get user info response, because %s", err.Error())
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
log.Debugf(c, "[common_oauth2_provider.GetUserInfo] response is %s", body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Errorf(c, "[common_oauth2_provider.GetUserInfo] failed to get user info response, because response code is %d", resp.StatusCode)
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
return p.dataSource.ParseUserInfo(c, body)
|
||||
}
|
||||
|
||||
// GetDataSource returns the data source of the common OAuth 2.0 provider
|
||||
func (p *CommonOAuth2Provider) GetDataSource() CommonOAuth2DataSource {
|
||||
return p.dataSource
|
||||
}
|
||||
|
||||
// NewCommonOAuth2Provider returns a new common OAuth 2.0 provider
|
||||
func NewCommonOAuth2Provider(config *settings.Config, redirectUrl string, dataSource CommonOAuth2DataSource) *CommonOAuth2Provider {
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: dataSource.GetAuthUrl(),
|
||||
TokenURL: dataSource.GetTokenUrl(),
|
||||
},
|
||||
RedirectURL: redirectUrl,
|
||||
Scopes: dataSource.GetScopes(),
|
||||
}
|
||||
|
||||
return &CommonOAuth2Provider{
|
||||
oauth2Config: oauth2Config,
|
||||
dataSource: dataSource,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type giteaUserInfoResponse struct {
|
||||
Login string `json:"login"`
|
||||
FullName string `json:"full_name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// GiteaOAuth2DataSource represents Gitea OAuth 2.0 data source
|
||||
type GiteaOAuth2DataSource struct {
|
||||
common.CommonOAuth2DataSource
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the Gitea data source
|
||||
func (s *GiteaOAuth2DataSource) GetAuthUrl() string {
|
||||
// Reference: https://docs.gitea.com/development/oauth2-provider
|
||||
return s.baseUrl + "login/oauth/authorize"
|
||||
}
|
||||
|
||||
// GetTokenUrl returns the token url of the Gitea data source
|
||||
func (s *GiteaOAuth2DataSource) GetTokenUrl() string {
|
||||
// Reference: https://docs.gitea.com/development/oauth2-provider
|
||||
return s.baseUrl + "login/oauth/access_token"
|
||||
}
|
||||
|
||||
// GetUserInfoRequest returns the user info request of the Gitea data source
|
||||
func (s *GiteaOAuth2DataSource) GetUserInfoRequest() (*http.Request, error) {
|
||||
// Reference: https://gitea.com/api/swagger#/user/userGetCurrent
|
||||
req, err := http.NewRequest("GET", s.baseUrl+"api/v1/user", nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetScopes returns the scopes required by the Gitea provider
|
||||
func (s *GiteaOAuth2DataSource) GetScopes() []string {
|
||||
return []string{"read:user"}
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (s *GiteaOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &giteaUserInfoResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[gitea_oauth2_datasource.ParseUserInfo] failed to parse user profile response body, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.Login == "" {
|
||||
log.Warnf(c, "[gitea_oauth2_datasource.ParseUserInfo] invalid user profile response body")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.Login,
|
||||
Email: userInfoResp.Email,
|
||||
NickName: userInfoResp.FullName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewGiteaOAuth2Provider creates a new Gitea OAuth 2.0 provider instance
|
||||
func NewGiteaOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
if len(config.OAuth2GiteaBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := config.OAuth2GiteaBaseUrl
|
||||
|
||||
if baseUrl[len(baseUrl)-1] != '/' {
|
||||
baseUrl += "/"
|
||||
}
|
||||
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &GiteaOAuth2DataSource{
|
||||
baseUrl: baseUrl,
|
||||
}), nil
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
func TestNewGiteaOAuth2Provider(t *testing.T) {
|
||||
provider, err := NewGiteaOAuth2Provider(&settings.Config{
|
||||
OAuth2GiteaBaseUrl: "https://example.com/",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewGiteaOAuth2Provider(&settings.Config{
|
||||
OAuth2GiteaBaseUrl: "https://example.com",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewGiteaOAuth2Provider(&settings.Config{}, "")
|
||||
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
|
||||
}
|
||||
|
||||
func TestGiteaOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
|
||||
datasource := &GiteaOAuth2DataSource{baseUrl: "https://example.com/"}
|
||||
req, err := datasource.GetUserInfoRequest()
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "GET", req.Method)
|
||||
assert.Equal(t, "https://example.com/api/v1/user", req.URL.String())
|
||||
assert.Equal(t, "application/json", req.Header.Get("Accept"))
|
||||
}
|
||||
|
||||
func TestGiteaOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
|
||||
datasource := &GiteaOAuth2DataSource{}
|
||||
responseContent := `{
|
||||
"login": "user1",
|
||||
"full_name": "User",
|
||||
"email": "user1@example.com"
|
||||
}`
|
||||
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "user1", info.UserName)
|
||||
assert.Equal(t, "user1@example.com", info.Email)
|
||||
assert.Equal(t, "User", info.NickName)
|
||||
}
|
||||
|
||||
func TestGiteaOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
|
||||
datasource := &GiteaOAuth2DataSource{}
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
|
||||
func TestGiteaOAuth2Datasource_ParseUserInfo_EmptyLogin(t *testing.T) {
|
||||
datasource := &GiteaOAuth2DataSource{}
|
||||
responseContent := `{"login": ""}`
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type githubUserProfileResponse struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// GithubOAuth2DataSource represents Github OAuth 2.0 data source
|
||||
type GithubOAuth2DataSource struct {
|
||||
common.CommonOAuth2DataSource
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the Github data source
|
||||
func (s *GithubOAuth2DataSource) GetAuthUrl() string {
|
||||
// Reference: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps
|
||||
return "https://github.com/login/oauth/authorize"
|
||||
}
|
||||
|
||||
// GetTokenUrl returns the token url of the Github data source
|
||||
func (s *GithubOAuth2DataSource) GetTokenUrl() string {
|
||||
// Reference: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps
|
||||
return "https://github.com/login/oauth/access_token"
|
||||
}
|
||||
|
||||
// GetUserInfoRequest returns the user info request of the Github data source
|
||||
func (s *GithubOAuth2DataSource) GetUserInfoRequest() (*http.Request, error) {
|
||||
// Reference: https://docs.github.com/en/rest/users/users
|
||||
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetScopes returns the scopes required by the Github provider
|
||||
func (p *GithubOAuth2DataSource) GetScopes() []string {
|
||||
return []string{"read:user"}
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (p *GithubOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &githubUserProfileResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[github_oauth2_datasource.ParseUserInfo] failed to parse user profile response body, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.Login == "" {
|
||||
log.Warnf(c, "[github_oauth2_datasource.ParseUserInfo] invalid user profile response body")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.Login,
|
||||
Email: userInfoResp.Email,
|
||||
NickName: userInfoResp.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewGithubOAuth2Provider creates a new Github OAuth 2.0 provider instance
|
||||
func NewGithubOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &GithubOAuth2DataSource{}), nil
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
)
|
||||
|
||||
func TestGithubOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
|
||||
datasource := &GithubOAuth2DataSource{}
|
||||
req, err := datasource.GetUserInfoRequest()
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "GET", req.Method)
|
||||
assert.Equal(t, "https://api.github.com/user", req.URL.String())
|
||||
assert.Equal(t, "application/vnd.github+json", req.Header.Get("Accept"))
|
||||
}
|
||||
|
||||
func TestGithubOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
|
||||
datasource := &GithubOAuth2DataSource{}
|
||||
responseContent := `{
|
||||
"login": "octocat",
|
||||
"id": 1,
|
||||
"node_id": "MDQ6VXNlcjE=",
|
||||
"avatar_url": "https://github.com/images/error/octocat_happy.gif",
|
||||
"gravatar_id": "",
|
||||
"url": "https://api.github.com/users/octocat",
|
||||
"html_url": "https://github.com/octocat",
|
||||
"followers_url": "https://api.github.com/users/octocat/followers",
|
||||
"following_url": "https://api.github.com/users/octocat/following{/other_user}",
|
||||
"gists_url": "https://api.github.com/users/octocat/gists{/gist_id}",
|
||||
"starred_url": "https://api.github.com/users/octocat/starred{/owner}{/repo}",
|
||||
"subscriptions_url": "https://api.github.com/users/octocat/subscriptions",
|
||||
"organizations_url": "https://api.github.com/users/octocat/orgs",
|
||||
"repos_url": "https://api.github.com/users/octocat/repos",
|
||||
"events_url": "https://api.github.com/users/octocat/events{/privacy}",
|
||||
"received_events_url": "https://api.github.com/users/octocat/received_events",
|
||||
"type": "User",
|
||||
"site_admin": false,
|
||||
"name": "monalisa octocat",
|
||||
"company": "GitHub",
|
||||
"blog": "https://github.com/blog",
|
||||
"location": "San Francisco",
|
||||
"email": "octocat@github.com",
|
||||
"hireable": false,
|
||||
"bio": "There once was...",
|
||||
"twitter_username": "monatheoctocat",
|
||||
"public_repos": 2,
|
||||
"public_gists": 1,
|
||||
"followers": 20,
|
||||
"following": 0,
|
||||
"created_at": "2008-01-14T04:33:35Z",
|
||||
"updated_at": "2008-01-14T04:33:35Z",
|
||||
"private_gists": 81,
|
||||
"total_private_repos": 100,
|
||||
"owned_private_repos": 100,
|
||||
"disk_usage": 10000,
|
||||
"collaborators": 8,
|
||||
"two_factor_authentication": true,
|
||||
"plan": {
|
||||
"name": "Medium",
|
||||
"space": 400,
|
||||
"private_repos": 20,
|
||||
"collaborators": 0
|
||||
}
|
||||
}`
|
||||
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "octocat", info.UserName)
|
||||
assert.Equal(t, "octocat@github.com", info.Email)
|
||||
assert.Equal(t, "monalisa octocat", info.NickName)
|
||||
}
|
||||
|
||||
func TestGithubOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
|
||||
datasource := &GithubOAuth2DataSource{}
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
|
||||
func TestGithubOAuth2Datasource_ParseUserInfo_EmptyLogin(t *testing.T) {
|
||||
datasource := &GithubOAuth2DataSource{}
|
||||
responseContent := `{"login": ""}`
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package nextcloud
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
type nextcloudUserInfoResponse struct {
|
||||
OCS *struct {
|
||||
Meta *struct {
|
||||
Status string `json:"status"`
|
||||
StatusCode int `json:"statuscode"`
|
||||
} `json:"meta"`
|
||||
Data *struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display-name"`
|
||||
} `json:"data"`
|
||||
} `json:"ocs"`
|
||||
}
|
||||
|
||||
// NextcloudOAuth2DataSource represents Nextcloud OAuth 2.0 data source
|
||||
type NextcloudOAuth2DataSource struct {
|
||||
common.CommonOAuth2DataSource
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
// GetAuthUrl returns the authentication url of the Nextcloud data source
|
||||
func (s *NextcloudOAuth2DataSource) GetAuthUrl() string {
|
||||
// Reference: https://docs.nextcloud.com/server/stable/developer_manual/_static/openapi.html#/operations/oauth2-login_redirector-authorize
|
||||
return s.baseUrl + "apps/oauth2/authorize"
|
||||
}
|
||||
|
||||
// GetTokenUrl returns the token url of the Nextcloud data source
|
||||
func (s *NextcloudOAuth2DataSource) GetTokenUrl() string {
|
||||
// Reference: https://docs.nextcloud.com/server/stable/developer_manual/_static/openapi.html#/operations/oauth2-oauth_api-get-token
|
||||
return s.baseUrl + "apps/oauth2/api/v1/token"
|
||||
}
|
||||
|
||||
// GetUserInfoRequest returns the user info request of the Nextcloud data source
|
||||
func (s *NextcloudOAuth2DataSource) GetUserInfoRequest() (*http.Request, error) {
|
||||
// Reference: https://docs.nextcloud.com/server/stable/developer_manual/_static/openapi.html#/operations/provisioning_api-users-get-current-user
|
||||
req, err := http.NewRequest("GET", s.baseUrl+"ocs/v2.php/cloud/user", nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("OCS-APIRequest", "true")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetScopes returns the scopes required by the Nextcloud provider
|
||||
func (s *NextcloudOAuth2DataSource) GetScopes() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// ParseUserInfo returns the user info by parsing the response body
|
||||
func (s *NextcloudOAuth2DataSource) ParseUserInfo(c core.Context, body []byte) (*data.OAuth2UserInfo, error) {
|
||||
userInfoResp := &nextcloudUserInfoResponse{}
|
||||
err := json.Unmarshal(body, &userInfoResp)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(c, "[nextcloud_oauth2_datasource.ParseUserInfo] failed to parse user info response body, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS == nil || userInfoResp.OCS.Meta == nil || userInfoResp.OCS.Data == nil {
|
||||
log.Warnf(c, "[nextcloud_oauth2_datasource.ParseUserInfo] invalid user info response body")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS.Meta.StatusCode != 200 {
|
||||
log.Warnf(c, "[nextcloud_oauth2_datasource.ParseUserInfo] user info response status code is %d", userInfoResp.OCS.Meta.StatusCode)
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userInfoResp.OCS.Data.ID == "" {
|
||||
log.Warnf(c, "[nextcloud_oauth2_datasource.ParseUserInfo] user info id is empty")
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userInfoResp.OCS.Data.ID,
|
||||
Email: userInfoResp.OCS.Data.Email,
|
||||
NickName: userInfoResp.OCS.Data.DisplayName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewNextcloudOAuth2Provider creates a new Nextcloud OAuth 2.0 provider instance
|
||||
func NewNextcloudOAuth2Provider(config *settings.Config, redirectUrl string) (provider.OAuth2Provider, error) {
|
||||
if len(config.OAuth2NextcloudBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := config.OAuth2NextcloudBaseUrl
|
||||
|
||||
if baseUrl[len(baseUrl)-1] != '/' {
|
||||
baseUrl += "/"
|
||||
}
|
||||
|
||||
return common.NewCommonOAuth2Provider(config, redirectUrl, &NextcloudOAuth2DataSource{
|
||||
baseUrl: baseUrl,
|
||||
}), nil
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package nextcloud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
func TestNewNextcloudOAuth2Provider(t *testing.T) {
|
||||
provider, err := NewNextcloudOAuth2Provider(&settings.Config{
|
||||
OAuth2NextcloudBaseUrl: "https://example.com/",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewNextcloudOAuth2Provider(&settings.Config{
|
||||
OAuth2NextcloudBaseUrl: "https://example.com/index.php",
|
||||
}, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/authorize", provider.(*common.CommonOAuth2Provider).GetDataSource().GetAuthUrl())
|
||||
assert.Equal(t, "https://example.com/index.php/apps/oauth2/api/v1/token", provider.(*common.CommonOAuth2Provider).GetDataSource().GetTokenUrl())
|
||||
|
||||
provider, err = NewNextcloudOAuth2Provider(&settings.Config{}, "")
|
||||
assert.Equal(t, errs.ErrInvalidOAuth2Config, err)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_GetUserInfoRequest(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{baseUrl: "https://example.com/"}
|
||||
req, err := datasource.GetUserInfoRequest()
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "GET", req.Method)
|
||||
assert.Equal(t, "https://example.com/ocs/v2.php/cloud/user", req.URL.String())
|
||||
assert.Equal(t, "application/json", req.Header.Get("Accept"))
|
||||
assert.Equal(t, "true", req.Header.Get("OCS-APIRequest"))
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_ParseUserInfo_Success(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{}
|
||||
responseContent := `{
|
||||
"ocs": {
|
||||
"meta": {
|
||||
"status": "ok",
|
||||
"statuscode": 200
|
||||
},
|
||||
"data": {
|
||||
"id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"display-name": "User"
|
||||
}
|
||||
}
|
||||
}`
|
||||
info, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "user1", info.UserName)
|
||||
assert.Equal(t, "user1@example.com", info.Email)
|
||||
assert.Equal(t, "User", info.NickName)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_ParseUserInfo_InvalidJson(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{}
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte("invalid"))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_ParseUserInfo_MissingFields(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{}
|
||||
responseContent := `{"ocs": {}}`
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_ParseUserInfo_Non200StatusCode(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{}
|
||||
responseContent := `{
|
||||
"ocs": {
|
||||
"meta": {
|
||||
"status": "error",
|
||||
"statuscode": 400
|
||||
},
|
||||
"data": {}
|
||||
}
|
||||
}`
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
|
||||
func TestNextcloudOAuth2Datasource_ParseUserInfo_EmptyID(t *testing.T) {
|
||||
datasource := &NextcloudOAuth2DataSource{}
|
||||
responseContent := `{
|
||||
"ocs": {
|
||||
"meta": {
|
||||
"status": "ok",
|
||||
"statuscode": 200
|
||||
},
|
||||
"data": {
|
||||
"id": "",
|
||||
"email": "user1@example.com",
|
||||
"display-name": "User One"
|
||||
}
|
||||
}
|
||||
}`
|
||||
_, err := datasource.ParseUserInfo(core.NewNullContext(), []byte(responseContent))
|
||||
|
||||
assert.Equal(t, errs.ErrCannotRetrieveUserInfo, err)
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
)
|
||||
|
||||
// OAuth2Provider defines the structure of OAuth 2.0 provider
|
||||
type OAuth2Provider interface {
|
||||
// GetOAuth2AuthUrl returns the authentication url of the provider
|
||||
GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error)
|
||||
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the provider
|
||||
GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error)
|
||||
|
||||
// GetUserInfo returns the user info
|
||||
GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error)
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/auth/oauth2/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
// OIDCClaims represents OIDC claims
|
||||
type OIDCClaims struct {
|
||||
PreferredUserName string `json:"preferred_username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// OIDCProvider represents OIDC provider
|
||||
type OIDCProvider struct {
|
||||
provider.OAuth2Provider
|
||||
oidcBaseUrl string
|
||||
redirectUrl string
|
||||
oauth2ClientID string
|
||||
oauth2ClientSecret string
|
||||
oauth2Config *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// GetOAuth2AuthUrl returns the authentication url of the OIDC provider
|
||||
func (p *OIDCProvider) GetOAuth2AuthUrl(c core.Context, state string, challenge string) (string, error) {
|
||||
oauth2Config, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return oauth2Config.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256")), nil
|
||||
}
|
||||
|
||||
// GetOAuth2Token returns the OAuth 2.0 token of the OIDC provider
|
||||
func (p *OIDCProvider) GetOAuth2Token(c core.Context, code string, verifier string) (*oauth2.Token, error) {
|
||||
oauth2Config, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return oauth2Config.Exchange(c, code, oauth2.SetAuthURLParam("code_verifier", verifier))
|
||||
}
|
||||
|
||||
// GetUserInfo returns the user info by the OIDC provider
|
||||
func (p *OIDCProvider) GetUserInfo(c core.Context, oauth2Token *oauth2.Token) (*data.OAuth2UserInfo, error) {
|
||||
_, err := p.getOAuth2Config(c)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
|
||||
if !ok {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] missing \"id_token\" field in oauth 2.0 token")
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
idToken, err := p.oidcVerifier.Verify(c, rawIDToken)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to verify \"id_token\" field in oauth 2.0 token, because %s", err.Error())
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
var claims OIDCClaims
|
||||
err = idToken.Claims(&claims)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse claims in oauth 2.0 token, because %s", err.Error())
|
||||
return nil, errs.ErrInvalidOAuth2Token
|
||||
}
|
||||
|
||||
userName := claims.PreferredUserName
|
||||
email := claims.Email
|
||||
nickName := claims.Name
|
||||
|
||||
if userName == "" || email == "" || nickName == "" {
|
||||
userInfo, err := p.oidcProvider.UserInfo(c, oauth2.StaticTokenSource(oauth2Token))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to get user info, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
err = userInfo.Claims(&claims)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.GetUserInfo] failed to parse user info, because %s", err.Error())
|
||||
return nil, errs.ErrCannotRetrieveUserInfo
|
||||
}
|
||||
|
||||
if userName == "" {
|
||||
userName = claims.PreferredUserName
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
email = claims.Email
|
||||
}
|
||||
|
||||
if nickName == "" {
|
||||
nickName = claims.Name
|
||||
}
|
||||
}
|
||||
|
||||
return &data.OAuth2UserInfo{
|
||||
UserName: userName,
|
||||
Email: email,
|
||||
NickName: nickName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) getOAuth2Config(c core.Context) (*oauth2.Config, error) {
|
||||
if p.oauth2Config != nil {
|
||||
return p.oauth2Config, nil
|
||||
}
|
||||
|
||||
oidcProvider, err := oidc.NewProvider(c, p.oidcBaseUrl)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[oidc_provider.getOAuth2Config] failed to create oidc provider, because %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oidcVerifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.oauth2ClientID})
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: p.oauth2ClientID,
|
||||
ClientSecret: p.oauth2ClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: p.redirectUrl,
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
p.oauth2Config = oauth2Config
|
||||
p.oidcProvider = oidcProvider
|
||||
p.oidcVerifier = oidcVerifier
|
||||
return oauth2Config, nil
|
||||
}
|
||||
|
||||
// NewOIDCProvider returns a new OIDC provider
|
||||
func NewOIDCProvider(config *settings.Config, redirectUrl string) (*OIDCProvider, error) {
|
||||
if len(config.OAuth2OIDCProviderBaseUrl) < 1 {
|
||||
return nil, errs.ErrInvalidOAuth2Config
|
||||
}
|
||||
|
||||
baseUrl := strings.TrimSuffix(config.OAuth2OIDCProviderBaseUrl, "/")
|
||||
|
||||
return &OIDCProvider{
|
||||
oidcBaseUrl: baseUrl,
|
||||
redirectUrl: redirectUrl,
|
||||
oauth2ClientID: config.OAuth2ClientID,
|
||||
oauth2ClientSecret: config.OAuth2ClientSecret,
|
||||
oauth2Config: nil,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user