diff --git a/pkg/api/accounts.go b/pkg/api/accounts.go index 6644b0b9..4832fa51 100644 --- a/pkg/api/accounts.go +++ b/pkg/api/accounts.go @@ -34,7 +34,7 @@ func (a *AccountsApi) AccountListHandler(c *core.Context) (interface{}, *errs.Er } uid := c.GetCurrentUid() - accounts, err := a.accounts.GetAllAccountsByUid(uid) + accounts, err := a.accounts.GetAllAccountsByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountListHandler] failed to get all accounts for user \"uid:%d\", because %s", uid, err.Error()) @@ -94,7 +94,7 @@ func (a *AccountsApi) AccountGetHandler(c *core.Context) (interface{}, *errs.Err } uid := c.GetCurrentUid() - accountAndSubAccounts, err := a.accounts.GetAccountAndSubAccountsByAccountId(uid, accountGetReq.Id) + accountAndSubAccounts, err := a.accounts.GetAccountAndSubAccountsByAccountId(c, uid, accountGetReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountGetHandler] failed to get account \"id:%d\" for user \"uid:%d\", because %s", accountGetReq.Id, uid, err.Error()) @@ -193,7 +193,7 @@ func (a *AccountsApi) AccountCreateHandler(c *core.Context) (interface{}, *errs. } uid := c.GetCurrentUid() - maxOrderId, err := a.accounts.GetMaxDisplayOrder(uid, accountCreateReq.Category) + maxOrderId, err := a.accounts.GetMaxDisplayOrder(c, uid, accountCreateReq.Category) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountCreateHandler] failed to get max display order for user \"uid:%d\", because %s", uid, err.Error()) @@ -203,7 +203,7 @@ func (a *AccountsApi) AccountCreateHandler(c *core.Context) (interface{}, *errs. mainAccount := a.createNewAccountModel(uid, &accountCreateReq, maxOrderId+1) childrenAccounts := a.createSubAccountModels(uid, &accountCreateReq) - err = a.accounts.CreateAccounts(mainAccount, childrenAccounts, utcOffset) + err = a.accounts.CreateAccounts(c, mainAccount, childrenAccounts, utcOffset) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountCreateHandler] failed to create account \"id:%d\" for user \"uid:%d\", because %s", mainAccount.AccountId, uid, err.Error()) @@ -236,7 +236,7 @@ func (a *AccountsApi) AccountModifyHandler(c *core.Context) (interface{}, *errs. } uid := c.GetCurrentUid() - accountAndSubAccounts, err := a.accounts.GetAccountAndSubAccountsByAccountId(uid, accountModifyReq.Id) + accountAndSubAccounts, err := a.accounts.GetAccountAndSubAccountsByAccountId(c, uid, accountModifyReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountModifyHandler] failed to get account \"id:%d\" for user \"uid:%d\", because %s", accountModifyReq.Id, uid, err.Error()) @@ -282,7 +282,7 @@ func (a *AccountsApi) AccountModifyHandler(c *core.Context) (interface{}, *errs. return nil, errs.ErrNothingWillBeUpdated } - err = a.accounts.ModifyAccounts(uid, toUpdateAccounts) + err = a.accounts.ModifyAccounts(c, uid, toUpdateAccounts) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountModifyHandler] failed to update account \"id:%d\" for user \"uid:%d\", because %s", accountModifyReq.Id, uid, err.Error()) @@ -342,7 +342,7 @@ func (a *AccountsApi) AccountHideHandler(c *core.Context) (interface{}, *errs.Er } uid := c.GetCurrentUid() - err = a.accounts.HideAccount(uid, []int64{accountHideReq.Id}, accountHideReq.Hidden) + err = a.accounts.HideAccount(c, uid, []int64{accountHideReq.Id}, accountHideReq.Hidden) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountHideHandler] failed to hide account \"id:%d\" for user \"uid:%d\", because %s", accountHideReq.Id, uid, err.Error()) @@ -377,7 +377,7 @@ func (a *AccountsApi) AccountMoveHandler(c *core.Context) (interface{}, *errs.Er accounts[i] = account } - err = a.accounts.ModifyAccountDisplayOrders(uid, accounts) + err = a.accounts.ModifyAccountDisplayOrders(c, uid, accounts) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountMoveHandler] failed to move accounts for user \"uid:%d\", because %s", uid, err.Error()) @@ -399,7 +399,7 @@ func (a *AccountsApi) AccountDeleteHandler(c *core.Context) (interface{}, *errs. } uid := c.GetCurrentUid() - err = a.accounts.DeleteAccount(uid, accountDeleteReq.Id) + err = a.accounts.DeleteAccount(c, uid, accountDeleteReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[accounts.AccountDeleteHandler] failed to delete account \"id:%d\" for user \"uid:%d\", because %s", accountDeleteReq.Id, uid, err.Error()) diff --git a/pkg/api/authorizations.go b/pkg/api/authorizations.go index f75a8df6..a562d95c 100644 --- a/pkg/api/authorizations.go +++ b/pkg/api/authorizations.go @@ -36,7 +36,7 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.Context) (interface{}, *err return nil, errs.ErrLoginNameOrPasswordInvalid } - user, err := a.users.GetUserByUsernameOrEmailAndPassword(credential.LoginName, credential.Password) + user, err := a.users.GetUserByUsernameOrEmailAndPassword(c, credential.LoginName, credential.Password) if err != nil { log.WarnfWithRequestId(c, "[authorizations.AuthorizeHandler] login failed for user \"%s\", because %s", credential.LoginName, err.Error()) @@ -48,7 +48,7 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.Context) (interface{}, *err return nil, errs.ErrUserIsDisabled } - err = a.users.UpdateUserLastLoginTime(user.Uid) + err = a.users.UpdateUserLastLoginTime(c, user.Uid) if err != nil { log.WarnfWithRequestId(c, "[authorizations.AuthorizeHandler] failed to update last login time for user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -57,7 +57,7 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.Context) (interface{}, *err twoFactorEnable := a.tokens.CurrentConfig().EnableTwoFactor if twoFactorEnable { - twoFactorEnable, err = a.twoFactorAuthorizations.ExistsTwoFactorSetting(user.Uid) + twoFactorEnable, err = a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, user.Uid) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.AuthorizeHandler] failed to check two factor setting for user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -69,9 +69,9 @@ func (a *AuthorizationsApi) AuthorizeHandler(c *core.Context) (interface{}, *err var claims *core.UserTokenClaims if twoFactorEnable { - token, claims, err = a.tokens.CreateRequire2FAToken(user, c) + token, claims, err = a.tokens.CreateRequire2FAToken(c, user) } else { - token, claims, err = a.tokens.CreateToken(user, c) + token, claims, err = a.tokens.CreateToken(c, user) } if err != nil { @@ -102,7 +102,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.Context) (interfac } uid := c.GetCurrentUid() - twoFactorSetting, err := a.twoFactorAuthorizations.GetUserTwoFactorSettingByUid(uid) + twoFactorSetting, err := a.twoFactorAuthorizations.GetUserTwoFactorSettingByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get two factor setting for user \"uid:%d\", because %s", uid, err.Error()) @@ -114,7 +114,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.Context) (interfac return nil, errs.ErrPasscodeInvalid } - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.TwoFactorAuthorizeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error()) @@ -122,13 +122,13 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeHandler(c *core.Context) (interfac } oldTokenClaims := c.GetTokenClaims() - err = a.tokens.DeleteTokenByClaims(oldTokenClaims) + err = a.tokens.DeleteTokenByClaims(c, oldTokenClaims) if err != nil { log.WarnfWithRequestId(c, "[authorizations.TwoFactorAuthorizeHandler] failed to revoke temporary token \"utid:%s\" for user \"uid:%d\", because %s", oldTokenClaims.UserTokenId, user.Uid, err.Error()) } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.TwoFactorAuthorizeHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -155,7 +155,7 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.Cont } uid := c.GetCurrentUid() - enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, uid) if err != nil { log.WarnfWithRequestId(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get two factor setting for user \"uid:%d\", because %s", uid, err.Error()) @@ -166,14 +166,14 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.Cont return nil, errs.ErrTwoFactorIsNotEnabled } - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get user \"uid:%d\" info, because %s", user.Uid, err.Error()) return nil, errs.ErrUserNotFound } - err = a.twoFactorAuthorizations.GetAndUseUserTwoFactorRecoveryCode(uid, credential.RecoveryCode, user.Salt) + err = a.twoFactorAuthorizations.GetAndUseUserTwoFactorRecoveryCode(c, uid, credential.RecoveryCode, user.Salt) if err != nil { log.WarnfWithRequestId(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to get two factor recovery code for user \"uid:%d\", because %s", uid, err.Error()) @@ -181,13 +181,13 @@ func (a *AuthorizationsApi) TwoFactorAuthorizeByRecoveryCodeHandler(c *core.Cont } oldTokenClaims := c.GetTokenClaims() - err = a.tokens.DeleteTokenByClaims(oldTokenClaims) + err = a.tokens.DeleteTokenByClaims(c, oldTokenClaims) if err != nil { log.WarnfWithRequestId(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to revoke temporary token \"utid:%s\" for user \"uid:%d\", because %s", oldTokenClaims.UserTokenId, user.Uid, err.Error()) } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.ErrorfWithRequestId(c, "[authorizations.TwoFactorAuthorizeByRecoveryCodeHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) diff --git a/pkg/api/data_managements.go b/pkg/api/data_managements.go index f333a301..b8b42788 100644 --- a/pkg/api/data_managements.go +++ b/pkg/api/data_managements.go @@ -57,7 +57,7 @@ func (a *DataManagementsApi) ExportDataHandler(c *core.Context) ([]byte, string, } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -67,28 +67,28 @@ func (a *DataManagementsApi) ExportDataHandler(c *core.Context) ([]byte, string, return nil, "", errs.ErrUserNotFound } - accounts, err := a.accounts.GetAllAccountsByUid(uid) + accounts, err := a.accounts.GetAllAccountsByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ExportDataHandler] failed to get all accounts for user \"uid:%d\", because %s", uid, err.Error()) return nil, "", errs.ErrOperationFailed } - categories, err := a.categories.GetAllCategoriesByUid(uid, 0, -1) + categories, err := a.categories.GetAllCategoriesByUid(c, uid, 0, -1) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ExportDataHandler] failed to get categories for user \"uid:%d\", because %s", uid, err.Error()) return nil, "", errs.ErrOperationFailed } - tags, err := a.tags.GetAllTagsByUid(uid) + tags, err := a.tags.GetAllTagsByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ExportDataHandler] failed to get tags for user \"uid:%d\", because %s", uid, err.Error()) return nil, "", errs.ErrOperationFailed } - tagIndexs, err := a.tags.GetAllTagIdsOfAllTransactions(uid) + tagIndexs, err := a.tags.GetAllTagIdsOfAllTransactions(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ExportDataHandler] failed to get tag index for user \"uid:%d\", because %s", uid, err.Error()) @@ -99,7 +99,7 @@ func (a *DataManagementsApi) ExportDataHandler(c *core.Context) ([]byte, string, categoryMap := a.categories.GetCategoryMapByList(categories) tagMap := a.tags.GetTagMapByList(tags) - allTransactions, err := a.transactions.GetAllTransactions(uid, pageCountForDataExport, true) + allTransactions, err := a.transactions.GetAllTransactions(c, uid, pageCountForDataExport, true) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ExportDataHandler] failed to all transactions user \"uid:%d\", because %s", uid, err.Error()) @@ -121,28 +121,28 @@ func (a *DataManagementsApi) ExportDataHandler(c *core.Context) ([]byte, string, // DataStatisticsHandler returns user data statistics func (a *DataManagementsApi) DataStatisticsHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - totalAccountCount, err := a.accounts.GetTotalAccountCountByUid(uid) + totalAccountCount, err := a.accounts.GetTotalAccountCountByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.DataStatisticsHandler] failed to get total account count for user \"uid:%d\", because %s", uid, err.Error()) return nil, errs.ErrOperationFailed } - totalTransactionCategoryCount, err := a.categories.GetTotalCategoryCountByUid(uid) + totalTransactionCategoryCount, err := a.categories.GetTotalCategoryCountByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.DataStatisticsHandler] failed to get total transaction category count for user \"uid:%d\", because %s", uid, err.Error()) return nil, errs.ErrOperationFailed } - totalTransactionTagCount, err := a.tags.GetTotalTagCountByUid(uid) + totalTransactionTagCount, err := a.tags.GetTotalTagCountByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.DataStatisticsHandler] failed to get total transaction tag count for user \"uid:%d\", because %s", uid, err.Error()) return nil, errs.ErrOperationFailed } - totalTransactionCount, err := a.transactions.GetTotalTransactionCountByUid(uid) + totalTransactionCount, err := a.transactions.GetTotalTransactionCountByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.DataStatisticsHandler] failed to get total transaction count for user \"uid:%d\", because %s", uid, err.Error()) @@ -170,7 +170,7 @@ func (a *DataManagementsApi) ClearDataHandler(c *core.Context) (interface{}, *er } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -184,21 +184,21 @@ func (a *DataManagementsApi) ClearDataHandler(c *core.Context) (interface{}, *er return nil, errs.ErrUserPasswordWrong } - err = a.transactions.DeleteAllTransactions(uid) + err = a.transactions.DeleteAllTransactions(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ClearDataHandler] failed to delete all transactions, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.categories.DeleteAllCategories(uid) + err = a.categories.DeleteAllCategories(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ClearDataHandler] failed to delete all transaction categories, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.tags.DeleteAllTags(uid) + err = a.tags.DeleteAllTags(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[data_managements.ClearDataHandler] failed to delete all transaction tags, because %s", err.Error()) diff --git a/pkg/api/forget_passwords.go b/pkg/api/forget_passwords.go index b93c33f2..2117b61f 100644 --- a/pkg/api/forget_passwords.go +++ b/pkg/api/forget_passwords.go @@ -36,7 +36,7 @@ func (a *ForgetPasswordsApi) UserForgetPasswordRequestHandler(c *core.Context) ( return nil, errs.ErrEmailIsEmptyOrInvalid } - user, err := a.users.GetUserByEmail(request.Email) + user, err := a.users.GetUserByEmail(c, request.Email) if err != nil { if !errs.IsCustomError(err) { @@ -51,14 +51,14 @@ func (a *ForgetPasswordsApi) UserForgetPasswordRequestHandler(c *core.Context) ( return nil, errs.ErrEmptyIsNotVerified } - token, _, err := a.tokens.CreatePasswordResetToken(user, c) + token, _, err := a.tokens.CreatePasswordResetToken(c, user) if err != nil { log.ErrorfWithRequestId(c, "[forget_passwords.UserForgetPasswordRequestHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) return nil, errs.ErrTokenGenerating } - err = a.forgetPasswords.SendPasswordResetEmail(user, token, c.GetClientLocale()) + err = a.forgetPasswords.SendPasswordResetEmail(c, user, token, c.GetClientLocale()) if err != nil { log.WarnfWithRequestId(c, "[forget_passwords.UserForgetPasswordRequestHandler] cannot send email to \"%s\", because %s", user.Email, err.Error()) @@ -79,7 +79,7 @@ func (a *ForgetPasswordsApi) UserResetPasswordHandler(c *core.Context) (interfac } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -96,7 +96,7 @@ func (a *ForgetPasswordsApi) UserResetPasswordHandler(c *core.Context) (interfac if a.users.IsPasswordEqualsUserPassword(request.Password, user) { oldTokenClaims := c.GetTokenClaims() - err = a.tokens.DeleteTokenByClaims(oldTokenClaims) + err = a.tokens.DeleteTokenByClaims(c, oldTokenClaims) if err != nil { log.WarnfWithRequestId(c, "[forget_passwords.UserResetPasswordHandler] failed to revoke password reset token \"utid:%s\" for user \"uid:%d\", because %s", oldTokenClaims.UserTokenId, user.Uid, err.Error()) @@ -111,7 +111,7 @@ func (a *ForgetPasswordsApi) UserResetPasswordHandler(c *core.Context) (interfac Password: request.Password, } - _, err = a.users.UpdateUser(userNew, false) + _, err = a.users.UpdateUser(c, userNew, false) if err != nil { log.ErrorfWithRequestId(c, "[forget_passwords.UserResetPasswordHandler] failed to update user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -119,7 +119,7 @@ func (a *ForgetPasswordsApi) UserResetPasswordHandler(c *core.Context) (interfac } now := time.Now().Unix() - err = a.tokens.DeleteTokensBeforeTime(uid, now) + err = a.tokens.DeleteTokensBeforeTime(c, uid, now) if err == nil { log.InfofWithRequestId(c, "[forget_passwords.UserResetPasswordHandler] revoke old tokens before unix time \"%d\" for user \"uid:%d\"", now, user.Uid) diff --git a/pkg/api/tokens.go b/pkg/api/tokens.go index 73b3a2b0..fc294fa7 100644 --- a/pkg/api/tokens.go +++ b/pkg/api/tokens.go @@ -28,7 +28,7 @@ var ( // TokenListHandler returns available token list of current user func (a *TokensApi) TokenListHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - tokens, err := a.tokens.GetAllUnexpiredNormalTokensByUid(uid) + tokens, err := a.tokens.GetAllUnexpiredNormalTokensByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[tokens.TokenListHandler] failed to get all tokens for user \"uid:%d\", because %s", uid, err.Error()) @@ -82,7 +82,7 @@ func (a *TokensApi) TokenRevokeCurrentHandler(c *core.Context) (interface{}, *er } tokenId := a.tokens.GenerateTokenId(tokenRecord) - err = a.tokens.DeleteToken(tokenRecord) + err = a.tokens.DeleteToken(c, tokenRecord) if err != nil { log.ErrorfWithRequestId(c, "[token.TokenRevokeCurrentHandler] failed to revoke token \"id:%s\" for user \"uid:%d\", because %s", tokenId, claims.Uid, err.Error()) @@ -120,7 +120,7 @@ func (a *TokensApi) TokenRevokeHandler(c *core.Context) (interface{}, *errs.Erro return nil, errs.ErrInvalidTokenId } - err = a.tokens.DeleteToken(tokenRecord) + err = a.tokens.DeleteToken(c, tokenRecord) if err != nil { log.ErrorfWithRequestId(c, "[token.TokenRevokeHandler] failed to revoke token \"id:%s\" for user \"uid:%d\", because %s", tokenRevokeReq.TokenId, uid, err.Error()) @@ -134,7 +134,7 @@ func (a *TokensApi) TokenRevokeHandler(c *core.Context) (interface{}, *errs.Erro // TokenRevokeAllHandler revokes all tokens of current user except current token func (a *TokensApi) TokenRevokeAllHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - tokens, err := a.tokens.GetAllTokensByUid(uid) + tokens, err := a.tokens.GetAllTokensByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[tokens.TokenRevokeAllHandler] failed to get all tokens for user \"uid:%d\", because %s", uid, err.Error()) @@ -155,7 +155,7 @@ func (a *TokensApi) TokenRevokeAllHandler(c *core.Context) (interface{}, *errs.E tokens = append(tokens[:currentTokenIndex], tokens[currentTokenIndex+1:]...) - err = a.tokens.DeleteTokens(uid, tokens) + err = a.tokens.DeleteTokens(c, uid, tokens) if err != nil { log.ErrorfWithRequestId(c, "[token.TokenRevokeAllHandler] failed to revoke all tokens for user \"uid:%d\", because %s", uid, err.Error()) @@ -169,14 +169,14 @@ func (a *TokensApi) TokenRevokeAllHandler(c *core.Context) (interface{}, *errs.E // TokenRefreshHandler refresh current token of current user func (a *TokensApi) TokenRefreshHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { log.WarnfWithRequestId(c, "[token.TokenRefreshHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) return nil, errs.ErrUserNotFound } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.ErrorfWithRequestId(c, "[token.TokenRefreshHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) diff --git a/pkg/api/transaction_categories.go b/pkg/api/transaction_categories.go index 1cd145b7..82c03a64 100644 --- a/pkg/api/transaction_categories.go +++ b/pkg/api/transaction_categories.go @@ -33,7 +33,7 @@ func (a *TransactionCategoriesApi) CategoryListHandler(c *core.Context) (interfa } uid := c.GetCurrentUid() - categories, err := a.categories.GetAllCategoriesByUid(uid, categoryListReq.Type, categoryListReq.ParentId) + categories, err := a.categories.GetAllCategoriesByUid(c, uid, categoryListReq.Type, categoryListReq.ParentId) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryListHandler] failed to get categories for user \"uid:%d\", because %s", uid, err.Error()) @@ -54,7 +54,7 @@ func (a *TransactionCategoriesApi) CategoryGetHandler(c *core.Context) (interfac } uid := c.GetCurrentUid() - category, err := a.categories.GetCategoryByCategoryId(uid, categoryGetReq.Id) + category, err := a.categories.GetCategoryByCategoryId(c, uid, categoryGetReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryGetHandler] failed to get category \"id:%d\" for user \"uid:%d\", because %s", categoryGetReq.Id, uid, err.Error()) @@ -84,7 +84,7 @@ func (a *TransactionCategoriesApi) CategoryCreateHandler(c *core.Context) (inter uid := c.GetCurrentUid() if categoryCreateReq.ParentId > 0 { - parentCategory, err := a.categories.GetCategoryByCategoryId(uid, categoryCreateReq.ParentId) + parentCategory, err := a.categories.GetCategoryByCategoryId(c, uid, categoryCreateReq.ParentId) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryCreateHandler] failed to get parent category \"id:%d\" for user \"uid:%d\", because %s", categoryCreateReq.ParentId, uid, err.Error()) @@ -105,9 +105,9 @@ func (a *TransactionCategoriesApi) CategoryCreateHandler(c *core.Context) (inter var maxOrderId int32 if categoryCreateReq.ParentId <= 0 { - maxOrderId, err = a.categories.GetMaxDisplayOrder(uid, categoryCreateReq.Type) + maxOrderId, err = a.categories.GetMaxDisplayOrder(c, uid, categoryCreateReq.Type) } else { - maxOrderId, err = a.categories.GetMaxSubCategoryDisplayOrder(uid, categoryCreateReq.Type, categoryCreateReq.ParentId) + maxOrderId, err = a.categories.GetMaxSubCategoryDisplayOrder(c, uid, categoryCreateReq.Type, categoryCreateReq.ParentId) } if err != nil { @@ -117,7 +117,7 @@ func (a *TransactionCategoriesApi) CategoryCreateHandler(c *core.Context) (inter category := a.createNewCategoryModel(uid, &categoryCreateReq, maxOrderId+1) - err = a.categories.CreateCategory(category) + err = a.categories.CreateCategory(c, category) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryCreateHandler] failed to create category \"id:%d\" for user \"uid:%d\", because %s", category.CategoryId, uid, err.Error()) @@ -153,7 +153,7 @@ func (a *TransactionCategoriesApi) CategoryCreateBatchHandler(c *core.Context) ( var maxOrderId, exists = categoryTypeMaxOrderMap[categoryCreateReq.Type] if !exists { - maxOrderId, err = a.categories.GetMaxDisplayOrder(uid, categoryCreateReq.Type) + maxOrderId, err = a.categories.GetMaxDisplayOrder(c, uid, categoryCreateReq.Type) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryCreateBatchHandler] failed to get max display order for user \"uid:%d\", because %s", uid, err.Error()) @@ -181,7 +181,7 @@ func (a *TransactionCategoriesApi) CategoryCreateBatchHandler(c *core.Context) ( totalCount++ } - categories, err := a.categories.CreateCategories(uid, categoriesMap) + categories, err := a.categories.CreateCategories(c, uid, categoriesMap) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryCreateBatchHandler] failed to create categories for user \"uid:%d\", because %s", uid, err.Error()) @@ -204,7 +204,7 @@ func (a *TransactionCategoriesApi) CategoryModifyHandler(c *core.Context) (inter } uid := c.GetCurrentUid() - category, err := a.categories.GetCategoryByCategoryId(uid, categoryModifyReq.Id) + category, err := a.categories.GetCategoryByCategoryId(c, uid, categoryModifyReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryModifyHandler] failed to get category \"id:%d\" for user \"uid:%d\", because %s", categoryModifyReq.Id, uid, err.Error()) @@ -229,7 +229,7 @@ func (a *TransactionCategoriesApi) CategoryModifyHandler(c *core.Context) (inter return nil, errs.ErrNothingWillBeUpdated } - err = a.categories.ModifyCategory(newCategory) + err = a.categories.ModifyCategory(c, newCategory) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryModifyHandler] failed to update category \"id:%d\" for user \"uid:%d\", because %s", categoryModifyReq.Id, uid, err.Error()) @@ -257,7 +257,7 @@ func (a *TransactionCategoriesApi) CategoryHideHandler(c *core.Context) (interfa } uid := c.GetCurrentUid() - err = a.categories.HideCategory(uid, []int64{categoryHideReq.Id}, categoryHideReq.Hidden) + err = a.categories.HideCategory(c, uid, []int64{categoryHideReq.Id}, categoryHideReq.Hidden) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryHideHandler] failed to hide category \"id:%d\" for user \"uid:%d\", because %s", categoryHideReq.Id, uid, err.Error()) @@ -292,7 +292,7 @@ func (a *TransactionCategoriesApi) CategoryMoveHandler(c *core.Context) (interfa categories[i] = category } - err = a.categories.ModifyCategoryDisplayOrders(uid, categories) + err = a.categories.ModifyCategoryDisplayOrders(c, uid, categories) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryMoveHandler] failed to move categories for user \"uid:%d\", because %s", uid, err.Error()) @@ -314,7 +314,7 @@ func (a *TransactionCategoriesApi) CategoryDeleteHandler(c *core.Context) (inter } uid := c.GetCurrentUid() - err = a.categories.DeleteCategory(uid, categoryDeleteReq.Id) + err = a.categories.DeleteCategory(c, uid, categoryDeleteReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_categories.CategoryDeleteHandler] failed to delete category \"id:%d\" for user \"uid:%d\", because %s", categoryDeleteReq.Id, uid, err.Error()) diff --git a/pkg/api/transaction_tags.go b/pkg/api/transaction_tags.go index 10fa5b6d..ebe6c2d3 100644 --- a/pkg/api/transaction_tags.go +++ b/pkg/api/transaction_tags.go @@ -25,7 +25,7 @@ var ( // TagListHandler returns transaction tag list of current user func (a *TransactionTagsApi) TagListHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - tags, err := a.tags.GetAllTagsByUid(uid) + tags, err := a.tags.GetAllTagsByUid(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagListHandler] failed to get tags for user \"uid:%d\", because %s", uid, err.Error()) @@ -54,7 +54,7 @@ func (a *TransactionTagsApi) TagGetHandler(c *core.Context) (interface{}, *errs. } uid := c.GetCurrentUid() - tag, err := a.tags.GetTagByTagId(uid, tagGetReq.Id) + tag, err := a.tags.GetTagByTagId(c, uid, tagGetReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagGetHandler] failed to get tag \"id:%d\" for user \"uid:%d\", because %s", tagGetReq.Id, uid, err.Error()) @@ -78,7 +78,7 @@ func (a *TransactionTagsApi) TagCreateHandler(c *core.Context) (interface{}, *er uid := c.GetCurrentUid() - maxOrderId, err := a.tags.GetMaxDisplayOrder(uid) + maxOrderId, err := a.tags.GetMaxDisplayOrder(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagCreateHandler] failed to get max display order for user \"uid:%d\", because %s", uid, err.Error()) @@ -87,7 +87,7 @@ func (a *TransactionTagsApi) TagCreateHandler(c *core.Context) (interface{}, *er tag := a.createNewTagModel(uid, &tagCreateReq, maxOrderId+1) - err = a.tags.CreateTag(tag) + err = a.tags.CreateTag(c, tag) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagCreateHandler] failed to create tag \"id:%d\" for user \"uid:%d\", because %s", tag.TagId, uid, err.Error()) @@ -112,7 +112,7 @@ func (a *TransactionTagsApi) TagModifyHandler(c *core.Context) (interface{}, *er } uid := c.GetCurrentUid() - tag, err := a.tags.GetTagByTagId(uid, tagModifyReq.Id) + tag, err := a.tags.GetTagByTagId(c, uid, tagModifyReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagModifyHandler] failed to get tag \"id:%d\" for user \"uid:%d\", because %s", tagModifyReq.Id, uid, err.Error()) @@ -129,7 +129,7 @@ func (a *TransactionTagsApi) TagModifyHandler(c *core.Context) (interface{}, *er return nil, errs.ErrNothingWillBeUpdated } - err = a.tags.ModifyTag(newTag) + err = a.tags.ModifyTag(c, newTag) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagModifyHandler] failed to update tag \"id:%d\" for user \"uid:%d\", because %s", tagModifyReq.Id, uid, err.Error()) @@ -155,7 +155,7 @@ func (a *TransactionTagsApi) TagHideHandler(c *core.Context) (interface{}, *errs } uid := c.GetCurrentUid() - err = a.tags.HideTag(uid, []int64{tagHideReq.Id}, tagHideReq.Hidden) + err = a.tags.HideTag(c, uid, []int64{tagHideReq.Id}, tagHideReq.Hidden) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.CategoryHideHandler] failed to hide tag \"id:%d\" for user \"uid:%d\", because %s", tagHideReq.Id, uid, err.Error()) @@ -190,7 +190,7 @@ func (a *TransactionTagsApi) TagMoveHandler(c *core.Context) (interface{}, *errs tags[i] = tag } - err = a.tags.ModifyTagDisplayOrders(uid, tags) + err = a.tags.ModifyTagDisplayOrders(c, uid, tags) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.CategoryMoveHandler] failed to move tags for user \"uid:%d\", because %s", uid, err.Error()) @@ -212,7 +212,7 @@ func (a *TransactionTagsApi) TagDeleteHandler(c *core.Context) (interface{}, *er } uid := c.GetCurrentUid() - err = a.tags.DeleteTag(uid, tagDeleteReq.Id) + err = a.tags.DeleteTag(c, uid, tagDeleteReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transaction_tags.TagDeleteHandler] failed to delete tag \"id:%d\" for user \"uid:%d\", because %s", tagDeleteReq.Id, uid, err.Error()) diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index b4b4cd15..e579fb2e 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -48,21 +48,21 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.Context) (interface{}, uid := c.GetCurrentUid() - allAccountIds, err := a.getAccountOrSubAccountIds(transactionCountReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionCountReq.AccountId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(transactionCountReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionCountReq.CategoryId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - totalCount, err := a.transactions.GetTransactionCount(uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, transactionCountReq.Keyword) + totalCount, err := a.transactions.GetTransactionCount(c, uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, transactionCountReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionCountHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) @@ -94,7 +94,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{}, } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -104,14 +104,14 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{}, return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(transactionListReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(transactionListReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error()) @@ -121,7 +121,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{}, var totalCount int64 if transactionListReq.WithCount { - totalCount, err = a.transactions.GetTransactionCount(uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword) + totalCount, err = a.transactions.GetTransactionCount(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) @@ -129,7 +129,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{}, } } - transactions, err := a.transactions.GetTransactionsByMaxTime(uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true) + transactions, err := a.transactions.GetTransactionsByMaxTime(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionListHandler] failed to get transactions earlier than \"%d\" for user \"uid:%d\", because %s", transactionListReq.MaxTime, uid, err.Error()) @@ -185,7 +185,7 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (interfac } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -195,21 +195,21 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (interfac return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(transactionListReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(transactionListReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryId, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - transactions, err := a.transactions.GetTransactionsInMonthByPage(uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword) + transactions, err := a.transactions.GetTransactionsInMonthByPage(c, uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionMonthListHandler] failed to get transactions in month \"%d-%d\" for user \"uid:%d\", because %s", transactionListReq.Year, transactionListReq.Month, uid, err.Error()) @@ -242,7 +242,7 @@ func (a *TransactionsApi) TransactionStatisticsHandler(c *core.Context) (interfa } uid := c.GetCurrentUid() - totalAmounts, err := a.transactions.GetAccountsAndCategoriesTotalIncomeAndExpense(uid, statisticReq.StartTime, statisticReq.EndTime) + totalAmounts, err := a.transactions.GetAccountsAndCategoriesTotalIncomeAndExpense(c, uid, statisticReq.StartTime, statisticReq.EndTime) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionStatisticsHandler] failed to get accounts and categories total income and expense for user \"uid:%d\", because %s", uid, err.Error()) @@ -297,7 +297,7 @@ func (a *TransactionsApi) TransactionAmountsHandler(c *core.Context) (interface{ uid := c.GetCurrentUid() - accounts, err := a.accounts.GetAllAccountsByUid(uid) + accounts, err := a.accounts.GetAllAccountsByUid(c, uid) accountMap := a.accounts.GetAccountMapByList(accounts) if err != nil { @@ -310,7 +310,7 @@ func (a *TransactionsApi) TransactionAmountsHandler(c *core.Context) (interface{ for i := 0; i < len(requestItems); i++ { requestItem := requestItems[i] - incomeAmounts, expenseAmounts, err := a.transactions.GetAccountsTotalIncomeAndExpense(uid, requestItem.StartTime, requestItem.EndTime) + incomeAmounts, expenseAmounts, err := a.transactions.GetAccountsTotalIncomeAndExpense(c, uid, requestItem.StartTime, requestItem.EndTime) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionAmountsHandler] failed to get transaction amounts item for user \"uid:%d\", because %s", uid, err.Error()) @@ -407,7 +407,7 @@ func (a *TransactionsApi) TransactionMonthAmountsHandler(c *core.Context) (inter uid := c.GetCurrentUid() - accounts, err := a.accounts.GetAllAccountsByUid(uid) + accounts, err := a.accounts.GetAllAccountsByUid(c, uid) accountMap := a.accounts.GetAccountMapByList(accounts) if err != nil { @@ -415,7 +415,7 @@ func (a *TransactionsApi) TransactionMonthAmountsHandler(c *core.Context) (inter return nil, errs.Or(err, errs.ErrOperationFailed) } - totalAmounts, err := a.transactions.GetAccountsMonthTotalIncomeAndExpense(uid, startTime, endTime, pageCountForLoadTransactionAmounts) + totalAmounts, err := a.transactions.GetAccountsMonthTotalIncomeAndExpense(c, uid, startTime, endTime, pageCountForLoadTransactionAmounts) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionMonthAmountsHandler] failed to get accounts month total income and expense for user \"uid:%d\", because %s", uid, err.Error()) @@ -513,7 +513,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -523,7 +523,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * return nil, errs.ErrUserNotFound } - transaction, err := a.transactions.GetTransactionByTransactionId(uid, transactionGetReq.Id) + transaction, err := a.transactions.GetTransactionByTransactionId(c, uid, transactionGetReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionGetHandler] failed to get transaction \"id:%d\" for user \"uid:%d\", because %s", transactionGetReq.Id, uid, err.Error()) @@ -542,7 +542,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * accountIds = utils.ToUniqueInt64Slice(accountIds) } - accountMap, err := a.accounts.GetAccountsByAccountIds(uid, accountIds) + accountMap, err := a.accounts.GetAccountsByAccountIds(c, uid, accountIds) if _, exists := accountMap[transaction.AccountId]; !exists { log.WarnfWithRequestId(c, "[transactions.TransactionGetHandler] account of transaction \"id:%d\" does not exist for user \"uid:%d\"", transaction.TransactionId, uid) @@ -556,7 +556,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * } } - allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(uid, []int64{transaction.TransactionId}) + allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(c, uid, []int64{transaction.TransactionId}) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionGetHandler] failed to get transactions tag ids for user \"uid:%d\", because %s", uid, err.Error()) @@ -567,7 +567,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * var tagMap map[int64]*models.TransactionTag if !transactionGetReq.TrimCategory { - category, err = a.transactionCategories.GetCategoryByCategoryId(uid, transaction.CategoryId) + category, err = a.transactionCategories.GetCategoryByCategoryId(c, uid, transaction.CategoryId) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionGetHandler] failed to get transactions category for user \"uid:%d\", because %s", uid, err.Error()) @@ -576,7 +576,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.Context) (interface{}, * } if !transactionGetReq.TrimTag { - tagMap, err = a.transactionTags.GetTagsByTagIds(uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) + tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionGetHandler] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error()) @@ -652,7 +652,7 @@ func (a *TransactionsApi) TransactionCreateHandler(c *core.Context) (interface{} } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -669,7 +669,7 @@ func (a *TransactionsApi) TransactionCreateHandler(c *core.Context) (interface{} return nil, errs.ErrCannotCreateTransactionWithThisTransactionTime } - err = a.transactions.CreateTransaction(transaction, tagIds) + err = a.transactions.CreateTransaction(c, transaction, tagIds) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionCreateHandler] failed to create transaction \"id:%d\" for user \"uid:%d\", because %s", transaction.TransactionId, uid, err.Error()) @@ -701,7 +701,7 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -711,7 +711,7 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} return nil, errs.ErrUserNotFound } - transaction, err := a.transactions.GetTransactionByTransactionId(uid, transactionModifyReq.Id) + transaction, err := a.transactions.GetTransactionByTransactionId(c, uid, transactionModifyReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionModifyHandler] failed to get transaction \"id:%d\" for user \"uid:%d\", because %s", transactionModifyReq.Id, uid, err.Error()) @@ -723,7 +723,7 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} return nil, errs.ErrTransactionTypeInvalid } - allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(uid, []int64{transaction.TransactionId}) + allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(c, uid, []int64{transaction.TransactionId}) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionModifyHandler] failed to get transactions tag ids for user \"uid:%d\", because %s", uid, err.Error()) @@ -788,7 +788,7 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} return nil, errs.ErrCannotModifyTransactionWithThisTransactionTime } - err = a.transactions.ModifyTransaction(newTransaction, addTransactionTagIds, removeTransactionTagIds) + err = a.transactions.ModifyTransaction(c, newTransaction, addTransactionTagIds, removeTransactionTagIds) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionModifyHandler] failed to update transaction \"id:%d\" for user \"uid:%d\", because %s", transactionModifyReq.Id, uid, err.Error()) @@ -821,7 +821,7 @@ func (a *TransactionsApi) TransactionDeleteHandler(c *core.Context) (interface{} } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -831,7 +831,7 @@ func (a *TransactionsApi) TransactionDeleteHandler(c *core.Context) (interface{} return nil, errs.ErrUserNotFound } - transaction, err := a.transactions.GetTransactionByTransactionId(uid, transactionDeleteReq.Id) + transaction, err := a.transactions.GetTransactionByTransactionId(c, uid, transactionDeleteReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionDeleteHandler] failed to get transaction \"id:%d\" for user \"uid:%d\", because %s", transactionDeleteReq.Id, uid, err.Error()) @@ -849,7 +849,7 @@ func (a *TransactionsApi) TransactionDeleteHandler(c *core.Context) (interface{} return nil, errs.ErrCannotDeleteTransactionWithThisTransactionTime } - err = a.transactions.DeleteTransaction(uid, transactionDeleteReq.Id) + err = a.transactions.DeleteTransaction(c, uid, transactionDeleteReq.Id) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionDeleteHandler] failed to delete transaction \"id:%d\" for user \"uid:%d\", because %s", transactionDeleteReq.Id, uid, err.Error()) @@ -884,11 +884,11 @@ func (a *TransactionsApi) filterTransactions(c *core.Context, uid int64, transac return finalTransactions } -func (a *TransactionsApi) getAccountOrSubAccountIds(accountId int64, uid int64) ([]int64, error) { +func (a *TransactionsApi) getAccountOrSubAccountIds(c *core.Context, accountId int64, uid int64) ([]int64, error) { var allAccountIds []int64 if accountId > 0 { - allSubAccounts, err := a.accounts.GetSubAccountsByAccountId(uid, accountId) + allSubAccounts, err := a.accounts.GetSubAccountsByAccountId(c, uid, accountId) if err != nil { return nil, err @@ -906,11 +906,11 @@ func (a *TransactionsApi) getAccountOrSubAccountIds(accountId int64, uid int64) return allAccountIds, nil } -func (a *TransactionsApi) getCategoryOrSubCategoryIds(categoryId int64, uid int64) ([]int64, error) { +func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.Context, categoryId int64, uid int64) ([]int64, error) { var allCategoryIds []int64 if categoryId > 0 { - allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, categoryId) + allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(c, uid, 0, categoryId) if err != nil { return nil, err @@ -977,7 +977,7 @@ func (a *TransactionsApi) getTransactionListResult(c *core.Context, user *models categoryIds = append(categoryIds, transactions[i].CategoryId) } - allAccounts, err := a.accounts.GetAccountsByAccountIds(uid, utils.ToUniqueInt64Slice(accountIds)) + allAccounts, err := a.accounts.GetAccountsByAccountIds(c, uid, utils.ToUniqueInt64Slice(accountIds)) if err != nil { log.ErrorfWithRequestId(c, "[transactions.getTransactionListResult] failed to get accounts for user \"uid:%d\", because %s", uid, err.Error()) @@ -986,7 +986,7 @@ func (a *TransactionsApi) getTransactionListResult(c *core.Context, user *models transactions = a.filterTransactions(c, uid, transactions, allAccounts) - allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(uid, transactionIds) + allTransactionTagIds, err := a.transactionTags.GetAllTagIdsOfTransactions(c, uid, transactionIds) if err != nil { log.ErrorfWithRequestId(c, "[transactions.getTransactionListResult] failed to get transactions tag ids for user \"uid:%d\", because %s", uid, err.Error()) @@ -997,7 +997,7 @@ func (a *TransactionsApi) getTransactionListResult(c *core.Context, user *models var tagMap map[int64]*models.TransactionTag if !trimCategory { - categoryMap, err = a.transactionCategories.GetCategoriesByCategoryIds(uid, utils.ToUniqueInt64Slice(categoryIds)) + categoryMap, err = a.transactionCategories.GetCategoriesByCategoryIds(c, uid, utils.ToUniqueInt64Slice(categoryIds)) if err != nil { log.ErrorfWithRequestId(c, "[transactions.getTransactionListResult] failed to get transactions categories for user \"uid:%d\", because %s", uid, err.Error()) @@ -1006,7 +1006,7 @@ func (a *TransactionsApi) getTransactionListResult(c *core.Context, user *models } if !trimTag { - tagMap, err = a.transactionTags.GetTagsByTagIds(uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) + tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) if err != nil { log.ErrorfWithRequestId(c, "[transactions.getTransactionListResult] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error()) diff --git a/pkg/api/twofactor_authorizations.go b/pkg/api/twofactor_authorizations.go index 9cdd71a9..3fa8e9c1 100644 --- a/pkg/api/twofactor_authorizations.go +++ b/pkg/api/twofactor_authorizations.go @@ -34,7 +34,7 @@ var ( // TwoFactorStatusHandler returns 2fa status of current user func (a *TwoFactorAuthorizationsApi) TwoFactorStatusHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - twoFactorSetting, err := a.twoFactorAuthorizations.GetUserTwoFactorSettingByUid(uid) + twoFactorSetting, err := a.twoFactorAuthorizations.GetUserTwoFactorSettingByUid(c, uid) if err == errs.ErrTwoFactorIsNotEnabled { statusResp := &models.TwoFactorStatusResponse{ @@ -60,7 +60,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorStatusHandler(c *core.Context) (in // TwoFactorEnableRequestHandler returns a new 2fa secret and qr code for current user to set 2fa and verify passcode next func (a *TwoFactorAuthorizationsApi) TwoFactorEnableRequestHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - enabled, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + enabled, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableRequestHandler] failed to check two factor setting, because %s", err.Error()) @@ -71,7 +71,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableRequestHandler(c *core.Conte return nil, errs.ErrTwoFactorAlreadyEnabled } - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -81,7 +81,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableRequestHandler(c *core.Conte return nil, errs.ErrUserNotFound } - key, err := a.twoFactorAuthorizations.GenerateTwoFactorSecret(user) + key, err := a.twoFactorAuthorizations.GenerateTwoFactorSecret(c, user) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableRequestHandler] failed to generate two factor secret, because %s", err.Error()) @@ -120,7 +120,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableConfirmHandler(c *core.Conte } uid := c.GetCurrentUid() - exists, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + exists, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] failed to check two factor setting, because %s", err.Error()) @@ -131,7 +131,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableConfirmHandler(c *core.Conte return nil, errs.ErrTwoFactorAlreadyEnabled } - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -158,14 +158,14 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableConfirmHandler(c *core.Conte return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.twoFactorAuthorizations.CreateTwoFactorRecoveryCodes(uid, recoveryCodes, user.Salt) + err = a.twoFactorAuthorizations.CreateTwoFactorRecoveryCodes(c, uid, recoveryCodes, user.Salt) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] failed to create two factor recovery codes for user \"uid:%d\", because %s", uid, err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.twoFactorAuthorizations.CreateTwoFactorSetting(twoFactorSetting) + err = a.twoFactorAuthorizations.CreateTwoFactorSetting(c, twoFactorSetting) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] failed to create two factor setting for user \"uid:%d\", because %s", uid, err.Error()) @@ -175,7 +175,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableConfirmHandler(c *core.Conte log.InfofWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] user \"uid:%d\" has enabled two factor authorization", uid) now := time.Now().Unix() - err = a.tokens.DeleteTokensBeforeTime(uid, now) + err = a.tokens.DeleteTokensBeforeTime(c, uid, now) if err == nil { log.InfofWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] revoke old tokens before unix time \"%d\" for user \"uid:%d\"", now, user.Uid) @@ -183,7 +183,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorEnableConfirmHandler(c *core.Conte log.WarnfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] failed to revoke old tokens for user \"uid:%d\", because %s", user.Uid, err.Error()) } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.WarnfWithRequestId(c, "[twofactor_authorizations.TwoFactorEnableConfirmHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -219,7 +219,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorDisableHandler(c *core.Context) (i } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -233,7 +233,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorDisableHandler(c *core.Context) (i return nil, errs.ErrUserPasswordWrong } - enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorDisableHandler] failed to check two factor setting, because %s", err.Error()) @@ -244,14 +244,14 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorDisableHandler(c *core.Context) (i return nil, errs.ErrTwoFactorIsNotEnabled } - err = a.twoFactorAuthorizations.DeleteTwoFactorRecoveryCodes(uid) + err = a.twoFactorAuthorizations.DeleteTwoFactorRecoveryCodes(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorDisableHandler] failed to delete two factor recovery codes for user \"uid:%d\"", uid) return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.twoFactorAuthorizations.DeleteTwoFactorSetting(uid) + err = a.twoFactorAuthorizations.DeleteTwoFactorSetting(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorDisableHandler] failed to delete two factor setting for user \"uid:%d\"", uid) @@ -274,7 +274,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorRecoveryCodeRegenerateHandler(c *c } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -288,7 +288,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorRecoveryCodeRegenerateHandler(c *c return nil, errs.ErrUserPasswordWrong } - enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + enableTwoFactor, err := a.twoFactorAuthorizations.ExistsTwoFactorSetting(c, uid) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorRecoveryCodeRegenerateHandler] failed to check two factor setting, because %s", err.Error()) @@ -306,7 +306,7 @@ func (a *TwoFactorAuthorizationsApi) TwoFactorRecoveryCodeRegenerateHandler(c *c return nil, errs.Or(err, errs.ErrOperationFailed) } - err = a.twoFactorAuthorizations.CreateTwoFactorRecoveryCodes(uid, recoveryCodes, user.Salt) + err = a.twoFactorAuthorizations.CreateTwoFactorRecoveryCodes(c, uid, recoveryCodes, user.Salt) if err != nil { log.ErrorfWithRequestId(c, "[twofactor_authorizations.TwoFactorRecoveryCodeRegenerateHandler] failed to create two factor recovery codes for user \"uid:%d\", because %s", uid, err.Error()) diff --git a/pkg/api/users.go b/pkg/api/users.go index 321aea84..34bcd772 100644 --- a/pkg/api/users.go +++ b/pkg/api/users.go @@ -63,7 +63,7 @@ func (a *UsersApi) UserRegisterHandler(c *core.Context) (interface{}, *errs.Erro TransactionEditScope: models.TRANSACTION_EDIT_SCOPE_ALL, } - err = a.users.CreateUser(user) + err = a.users.CreateUser(c, user) if err != nil { log.ErrorfWithRequestId(c, "[users.UserRegisterHandler] failed to create user \"%s\", because %s", user.Username, err.Error()) @@ -77,7 +77,7 @@ func (a *UsersApi) UserRegisterHandler(c *core.Context) (interface{}, *errs.Erro User: user.ToUserBasicInfo(), } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.WarnfWithRequestId(c, "[users.UserRegisterHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -96,7 +96,7 @@ func (a *UsersApi) UserRegisterHandler(c *core.Context) (interface{}, *errs.Erro // UserProfileHandler returns user profile of current user func (a *UsersApi) UserProfileHandler(c *core.Context) (interface{}, *errs.Error) { uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -121,7 +121,7 @@ func (a *UsersApi) UserUpdateProfileHandler(c *core.Context) (interface{}, *errs } uid := c.GetCurrentUid() - user, err := a.users.GetUserById(uid) + user, err := a.users.GetUserById(c, uid) if err != nil { if !errs.IsCustomError(err) { @@ -164,7 +164,7 @@ func (a *UsersApi) UserUpdateProfileHandler(c *core.Context) (interface{}, *errs } if userUpdateReq.DefaultAccountId > 0 && userUpdateReq.DefaultAccountId != user.DefaultAccountId { - accounts, err := a.accounts.GetAccountsByAccountIds(uid, []int64{userUpdateReq.DefaultAccountId}) + accounts, err := a.accounts.GetAccountsByAccountIds(c, uid, []int64{userUpdateReq.DefaultAccountId}) if err != nil || len(accounts) < 1 { return nil, errs.Or(err, errs.ErrUserDefaultAccountIsInvalid) @@ -242,7 +242,7 @@ func (a *UsersApi) UserUpdateProfileHandler(c *core.Context) (interface{}, *errs return nil, errs.ErrNothingWillBeUpdated } - keyProfileUpdated, err := a.users.UpdateUser(userNew, modifyUserLanguage) + keyProfileUpdated, err := a.users.UpdateUser(c, userNew, modifyUserLanguage) if err != nil { log.ErrorfWithRequestId(c, "[users.UserUpdateProfileHandler] failed to update user \"uid:%d\", because %s", user.Uid, err.Error()) @@ -257,7 +257,7 @@ func (a *UsersApi) UserUpdateProfileHandler(c *core.Context) (interface{}, *errs if keyProfileUpdated { now := time.Now().Unix() - err = a.tokens.DeleteTokensBeforeTime(uid, now) + err = a.tokens.DeleteTokensBeforeTime(c, uid, now) if err == nil { log.InfofWithRequestId(c, "[users.UserUpdateProfileHandler] revoke old tokens before unix time \"%d\" for user \"uid:%d\"", now, user.Uid) @@ -265,7 +265,7 @@ func (a *UsersApi) UserUpdateProfileHandler(c *core.Context) (interface{}, *errs log.WarnfWithRequestId(c, "[users.UserUpdateProfileHandler] failed to revoke old tokens for user \"uid:%d\", because %s", user.Uid, err.Error()) } - token, claims, err := a.tokens.CreateToken(user, c) + token, claims, err := a.tokens.CreateToken(c, user) if err != nil { log.WarnfWithRequestId(c, "[users.UserUpdateProfileHandler] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) diff --git a/pkg/cli/user_data.go b/pkg/cli/user_data.go index c1eb333d..1fc3caf1 100644 --- a/pkg/cli/user_data.go +++ b/pkg/cli/user_data.go @@ -86,7 +86,7 @@ func (l *UserDataCli) AddNewUser(c *cli.Context, username string, email string, TransactionEditScope: models.TRANSACTION_EDIT_SCOPE_ALL, } - err := l.users.CreateUser(user) + err := l.users.CreateUser(nil, user) if err != nil { log.BootErrorf("[user_data.AddNewUser] failed to create user \"%s\", because %s", user.Username, err.Error()) @@ -105,7 +105,7 @@ func (l *UserDataCli) GetUserByUsername(c *cli.Context, username string) (*model return nil, errs.ErrUsernameIsEmpty } - user, err := l.users.GetUserByUsername(username) + user, err := l.users.GetUserByUsername(nil, username) if err != nil { log.BootErrorf("[user_data.GetUserByUsername] failed to get user by user name \"%s\", because %s", username, err.Error()) @@ -127,7 +127,7 @@ func (l *UserDataCli) ModifyUserPassword(c *cli.Context, username string, passwo return errs.ErrPasswordIsEmpty } - user, err := l.users.GetUserByUsername(username) + user, err := l.users.GetUserByUsername(nil, username) if err != nil { log.BootErrorf("[user_data.ModifyUserPassword] failed to get user by user name \"%s\", because %s", username, err.Error()) @@ -144,7 +144,7 @@ func (l *UserDataCli) ModifyUserPassword(c *cli.Context, username string, passwo Password: password, } - _, err = l.users.UpdateUser(userNew, false) + _, err = l.users.UpdateUser(nil, userNew, false) if err != nil { log.BootErrorf("[user_data.ModifyUserPassword] failed to update user \"%s\" password, because %s", user.Username, err.Error()) @@ -152,7 +152,7 @@ func (l *UserDataCli) ModifyUserPassword(c *cli.Context, username string, passwo } now := time.Now().Unix() - err = l.tokens.DeleteTokensBeforeTime(user.Uid, now) + err = l.tokens.DeleteTokensBeforeTime(nil, user.Uid, now) if err == nil { log.BootInfof("[user_data.ModifyUserPassword] revoke old tokens before unix time \"%d\" for user \"%s\"", now, user.Username) @@ -170,7 +170,7 @@ func (l *UserDataCli) SendPasswordResetMail(c *cli.Context, username string) err return errs.ErrUsernameIsEmpty } - user, err := l.users.GetUserByUsername(username) + user, err := l.users.GetUserByUsername(nil, username) if err != nil { log.BootErrorf("[user_data.SendPasswordResetMail] failed to get user by user name \"%s\", because %s", username, err.Error()) @@ -182,14 +182,14 @@ func (l *UserDataCli) SendPasswordResetMail(c *cli.Context, username string) err return errs.ErrEmptyIsNotVerified } - token, _, err := l.tokens.CreatePasswordResetToken(user, nil) + token, _, err := l.tokens.CreatePasswordResetToken(nil, user) if err != nil { log.BootErrorf("[user_data.SendPasswordResetMail] failed to create token for user \"uid:%d\", because %s", user.Uid, err.Error()) return err } - err = l.forgetPasswords.SendPasswordResetEmail(user, token, "") + err = l.forgetPasswords.SendPasswordResetEmail(nil, user, token, "") if err != nil { log.BootWarnf("[user_data.SendPasswordResetMail] cannot send email to \"%s\", because %s", user.Email, err.Error()) @@ -206,7 +206,7 @@ func (l *UserDataCli) EnableUser(c *cli.Context, username string) error { return errs.ErrUsernameIsEmpty } - err := l.users.EnableUser(username) + err := l.users.EnableUser(nil, username) if err != nil { log.BootErrorf("[user_data.EnableUser] failed to set user enabled by user name \"%s\", because %s", username, err.Error()) @@ -223,7 +223,7 @@ func (l *UserDataCli) DisableUser(c *cli.Context, username string) error { return errs.ErrUsernameIsEmpty } - err := l.users.DisableUser(username) + err := l.users.DisableUser(nil, username) if err != nil { log.BootErrorf("[user_data.DisableUser] failed to set user disabled by user name \"%s\", because %s", username, err.Error()) @@ -240,7 +240,7 @@ func (l *UserDataCli) SetUserEmailVerified(c *cli.Context, username string) erro return errs.ErrUsernameIsEmpty } - err := l.users.SetUserEmailVerified(username) + err := l.users.SetUserEmailVerified(nil, username) if err != nil { log.BootErrorf("[user_data.SetUserEmailVerified] failed to set user email address verified by user name \"%s\", because %s", username, err.Error()) @@ -257,7 +257,7 @@ func (l *UserDataCli) SetUserEmailUnverified(c *cli.Context, username string) er return errs.ErrUsernameIsEmpty } - err := l.users.SetUserEmailUnverified(username) + err := l.users.SetUserEmailUnverified(nil, username) if err != nil { log.BootErrorf("[user_data.SetUserEmailUnverified] failed to set user email address unverified by user name \"%s\", because %s", username, err.Error()) @@ -274,7 +274,7 @@ func (l *UserDataCli) DeleteUser(c *cli.Context, username string) error { return errs.ErrUsernameIsEmpty } - err := l.users.DeleteUser(username) + err := l.users.DeleteUser(nil, username) if err != nil { log.BootErrorf("[user_data.DeleteUser] failed to delete user by user name \"%s\", because %s", username, err.Error()) @@ -298,7 +298,7 @@ func (l *UserDataCli) ListUserTokens(c *cli.Context, username string) ([]*models return nil, err } - tokens, err := l.tokens.GetAllUnexpiredNormalTokensByUid(uid) + tokens, err := l.tokens.GetAllUnexpiredNormalTokensByUid(nil, uid) if err != nil { log.BootErrorf("[user_data.ListUserTokens] failed to get tokens of user \"%s\", because %s", username, err.Error()) @@ -323,7 +323,7 @@ func (l *UserDataCli) ClearUserTokens(c *cli.Context, username string) error { } now := time.Now().Unix() - err = l.tokens.DeleteTokensBeforeTime(uid, now) + err = l.tokens.DeleteTokensBeforeTime(nil, uid, now) if err != nil { log.BootErrorf("[user_data.ClearUserTokens] failed to delete tokens of user \"%s\", because %s", username, err.Error()) @@ -347,7 +347,7 @@ func (l *UserDataCli) DisableUserTwoFactorAuthorization(c *cli.Context, username return err } - enableTwoFactor, err := l.twoFactorAuthorizations.ExistsTwoFactorSetting(uid) + enableTwoFactor, err := l.twoFactorAuthorizations.ExistsTwoFactorSetting(nil, uid) if err != nil { log.BootErrorf("[user_data.DisableUserTwoFactorAuthorization] failed to check two factor setting, because %s", err.Error()) @@ -358,14 +358,14 @@ func (l *UserDataCli) DisableUserTwoFactorAuthorization(c *cli.Context, username return errs.ErrTwoFactorIsNotEnabled } - err = l.twoFactorAuthorizations.DeleteTwoFactorRecoveryCodes(uid) + err = l.twoFactorAuthorizations.DeleteTwoFactorRecoveryCodes(nil, uid) if err != nil { log.BootErrorf("[user_data.DisableUserTwoFactorAuthorization] failed to delete two factor recovery codes for user \"%s\"", username) return err } - err = l.twoFactorAuthorizations.DeleteTwoFactorSetting(uid) + err = l.twoFactorAuthorizations.DeleteTwoFactorSetting(nil, uid) if err != nil { log.BootErrorf("[user_data.DisableUserTwoFactorAuthorization] failed to delete two factor setting for user \"%s\"", username) @@ -404,7 +404,7 @@ func (l *UserDataCli) CheckTransactionAndAccount(c *cli.Context, username string } } - allTransactions, err := l.transactions.GetAllTransactions(uid, pageCountForGettingTransactions, false) + allTransactions, err := l.transactions.GetAllTransactions(nil, uid, pageCountForGettingTransactions, false) if err != nil { log.BootErrorf("[user_data.CheckTransactionAndAccount] failed to all transactions for user \"%s\", because %s", username, err.Error()) @@ -516,7 +516,7 @@ func (l *UserDataCli) ExportTransaction(c *cli.Context, username string) ([]byte return nil, err } - allTransactions, err := l.transactions.GetAllTransactions(uid, pageCountForDataExport, true) + allTransactions, err := l.transactions.GetAllTransactions(nil, uid, pageCountForDataExport, true) if err != nil { log.BootErrorf("[user_data.ExportTransaction] failed to all transactions for user \"%s\", because %s", username, err.Error()) @@ -550,7 +550,7 @@ func (l *UserDataCli) getUserEssentialData(uid int64, username string) (accountM return nil, nil, nil, nil, errs.ErrUserIdInvalid } - accounts, err := l.accounts.GetAllAccountsByUid(uid) + accounts, err := l.accounts.GetAllAccountsByUid(nil, uid) if err != nil { log.BootErrorf("[user_data.getUserEssentialData] failed to get accounts for user \"%s\", because %s", username, err.Error()) @@ -559,7 +559,7 @@ func (l *UserDataCli) getUserEssentialData(uid int64, username string) (accountM accountMap = l.accounts.GetAccountMapByList(accounts) - categories, err := l.categories.GetAllCategoriesByUid(uid, 0, -1) + categories, err := l.categories.GetAllCategoriesByUid(nil, uid, 0, -1) if err != nil { log.BootErrorf("[user_data.getUserEssentialData] failed to get categories for user \"%s\", because %s", username, err.Error()) @@ -568,7 +568,7 @@ func (l *UserDataCli) getUserEssentialData(uid int64, username string) (accountM categoryMap = l.categories.GetCategoryMapByList(categories) - tags, err := l.tags.GetAllTagsByUid(uid) + tags, err := l.tags.GetAllTagsByUid(nil, uid) if err != nil { log.BootErrorf("[user_data.getUserEssentialData] failed to get tags for user \"%s\", because %s", username, err.Error()) @@ -577,7 +577,7 @@ func (l *UserDataCli) getUserEssentialData(uid int64, username string) (accountM tagMap = l.tags.GetTagMapByList(tags) - tagIndexs, err = l.tags.GetAllTagIdsOfAllTransactions(uid) + tagIndexs, err = l.tags.GetAllTagIdsOfAllTransactions(nil, uid) if err != nil { log.BootErrorf("[user_data.getUserEssentialData] failed to get tag index for user \"%s\", because %s", username, err.Error()) diff --git a/pkg/datastore/database.go b/pkg/datastore/database.go index 82b7577a..f460dddb 100644 --- a/pkg/datastore/database.go +++ b/pkg/datastore/database.go @@ -1,15 +1,29 @@ package datastore -import "xorm.io/xorm" +import ( + "xorm.io/xorm" + + "github.com/mayswind/ezbookkeeping/pkg/core" +) // Database represents a database instance type Database struct { - *xorm.EngineGroup + engineGroup *xorm.EngineGroup +} + +// NewSession starts a new session with the specified context +func (db *Database) NewSession(c *core.Context) *xorm.Session { + return db.engineGroup.Context(NewXOrmContextAdapter(c)) } // DoTransaction runs a new database transaction -func (db *Database) DoTransaction(fn func(sess *xorm.Session) error) (err error) { - sess := db.NewSession() +func (db *Database) DoTransaction(c *core.Context, fn func(sess *xorm.Session) error) (err error) { + sess := db.engineGroup.NewSession() + + if c != nil { + sess.Context(NewXOrmContextAdapter(c)) + } + defer sess.Close() if err = sess.Begin(); err != nil { diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index f7f7e61b..1b3832ac 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -3,6 +3,7 @@ package datastore import ( "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" ) @@ -17,13 +18,13 @@ func (s *DataStore) Choose(key int64) *Database { } // Query returns a new database session in a specific database by sharding key -func (s *DataStore) Query(key int64) *xorm.Session { - return s.Choose(key).NewSession() +func (s *DataStore) Query(c *core.Context, key int64) *xorm.Session { + return s.Choose(key).NewSession(c) } // DoTransaction runs a new database transaction in a specific database by sharding key -func (s *DataStore) DoTransaction(key int64, fn func(sess *xorm.Session) error) (err error) { - return s.Choose(key).DoTransaction(fn) +func (s *DataStore) DoTransaction(key int64, c *core.Context, fn func(sess *xorm.Session) error) (err error) { + return s.Choose(key).DoTransaction(c, fn) } // SyncStructs updates database structs by database models @@ -31,7 +32,7 @@ func (s *DataStore) SyncStructs(beans ...interface{}) error { var err error for i := 0; i < len(s.databases); i++ { - err = s.databases[i].Sync2(beans...) + err = s.databases[i].engineGroup.Sync2(beans...) if err != nil { return err diff --git a/pkg/datastore/datastore_container.go b/pkg/datastore/datastore_container.go index acb02bd9..83904823 100644 --- a/pkg/datastore/datastore_container.go +++ b/pkg/datastore/datastore_container.go @@ -104,14 +104,14 @@ func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) { engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second) return &Database{ - EngineGroup: engineGroup, + engineGroup: engineGroup, }, nil } func setDatabaseLogger(database *Database, config *settings.Config) { if config.EnableQueryLog { - database.SetLogger(NewXOrmLoggerAdapter(config.EnableQueryLog, config.LogLevel)) - database.ShowSQL(true) + database.engineGroup.SetLogger(NewXOrmLoggerAdapter(config.EnableQueryLog, config.LogLevel)) + database.engineGroup.ShowSQL(true) } } diff --git a/pkg/datastore/query_context.go b/pkg/datastore/query_context.go new file mode 100644 index 00000000..260cf296 --- /dev/null +++ b/pkg/datastore/query_context.go @@ -0,0 +1,50 @@ +package datastore + +import ( + "fmt" + "time" + + "xorm.io/xorm/log" + + "github.com/mayswind/ezbookkeeping/pkg/core" +) + +// XOrmContextAdapter represents the context adapter for xorm +type XOrmContextAdapter struct { + requestId string +} + +// Deadline does nothing +func (c *XOrmContextAdapter) Deadline() (deadline time.Time, ok bool) { + return +} + +// Done always returns nil +func (c *XOrmContextAdapter) Done() <-chan struct{} { + return nil +} + +// Err always returns nil +func (c *XOrmContextAdapter) Err() error { + return nil +} + +// Value returns the value associated with this context for key, or nil +// if no value is associated with key. +func (c *XOrmContextAdapter) Value(key any) any { + if key == log.SessionIDKey && c.requestId != "" { + return fmt.Sprintf("r=%s", c.requestId) + } + + return nil +} + +func NewXOrmContextAdapter(c *core.Context) *XOrmContextAdapter { + if c != nil { + return &XOrmContextAdapter{ + requestId: c.GetRequestId(), + } + } + + return &XOrmContextAdapter{} +} diff --git a/pkg/services/accounts.go b/pkg/services/accounts.go index 34f86d5d..fadb1942 100644 --- a/pkg/services/accounts.go +++ b/pkg/services/accounts.go @@ -5,6 +5,7 @@ import ( "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" @@ -31,30 +32,30 @@ var ( ) // GetTotalAccountCountByUid returns total account count of user -func (s *AccountService) GetTotalAccountCountByUid(uid int64) (int64, error) { +func (s *AccountService) GetTotalAccountCountByUid(c *core.Context, uid int64) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } - count, err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).Count(&models.Account{}) + count, err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).Count(&models.Account{}) return count, err } // GetAllAccountsByUid returns all account models of user -func (s *AccountService) GetAllAccountsByUid(uid int64) ([]*models.Account, error) { +func (s *AccountService) GetAllAccountsByUid(c *core.Context, uid int64) ([]*models.Account, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } var accounts []*models.Account - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).OrderBy("parent_account_id asc, display_order asc").Find(&accounts) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).OrderBy("parent_account_id asc, display_order asc").Find(&accounts) return accounts, err } // GetAccountAndSubAccountsByAccountId returns account model and sub account models according to account id -func (s *AccountService) GetAccountAndSubAccountsByAccountId(uid int64, accountId int64) ([]*models.Account, error) { +func (s *AccountService) GetAccountAndSubAccountsByAccountId(c *core.Context, uid int64, accountId int64) ([]*models.Account, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -64,13 +65,13 @@ func (s *AccountService) GetAccountAndSubAccountsByAccountId(uid int64, accountI } var accounts []*models.Account - err := s.UserDataDB(uid).Where("uid=? AND deleted=? AND (account_id=? OR parent_account_id=?)", uid, false, accountId, accountId).OrderBy("parent_account_id asc, display_order asc").Find(&accounts) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=? AND (account_id=? OR parent_account_id=?)", uid, false, accountId, accountId).OrderBy("parent_account_id asc, display_order asc").Find(&accounts) return accounts, err } // GetSubAccountsByAccountId returns sub account models according to account id -func (s *AccountService) GetSubAccountsByAccountId(uid int64, accountId int64) ([]*models.Account, error) { +func (s *AccountService) GetSubAccountsByAccountId(c *core.Context, uid int64, accountId int64) ([]*models.Account, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -80,13 +81,13 @@ func (s *AccountService) GetSubAccountsByAccountId(uid int64, accountId int64) ( } var accounts []*models.Account - err := s.UserDataDB(uid).Where("uid=? AND deleted=? AND parent_account_id=?", uid, false, accountId).OrderBy("display_order asc").Find(&accounts) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=? AND parent_account_id=?", uid, false, accountId).OrderBy("display_order asc").Find(&accounts) return accounts, err } // GetAccountsByAccountIds returns account models according to account ids -func (s *AccountService) GetAccountsByAccountIds(uid int64, accountIds []int64) (map[int64]*models.Account, error) { +func (s *AccountService) GetAccountsByAccountIds(c *core.Context, uid int64, accountIds []int64) (map[int64]*models.Account, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -96,7 +97,7 @@ func (s *AccountService) GetAccountsByAccountIds(uid int64, accountIds []int64) } var accounts []*models.Account - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).In("account_id", accountIds).Find(&accounts) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).In("account_id", accountIds).Find(&accounts) if err != nil { return nil, err @@ -107,13 +108,13 @@ func (s *AccountService) GetAccountsByAccountIds(uid int64, accountIds []int64) } // GetMaxDisplayOrder returns the max display order according to account category -func (s *AccountService) GetMaxDisplayOrder(uid int64, category models.AccountCategory) (int32, error) { +func (s *AccountService) GetMaxDisplayOrder(c *core.Context, uid int64, category models.AccountCategory) (int32, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } account := &models.Account{} - has, err := s.UserDataDB(uid).Cols("uid", "deleted", "parent_account_id", "display_order").Where("uid=? AND deleted=? AND parent_account_id=? AND category=?", uid, false, models.LevelOneAccountParentId, category).OrderBy("display_order desc").Limit(1).Get(account) + has, err := s.UserDataDB(uid).NewSession(c).Cols("uid", "deleted", "parent_account_id", "display_order").Where("uid=? AND deleted=? AND parent_account_id=? AND category=?", uid, false, models.LevelOneAccountParentId, category).OrderBy("display_order desc").Limit(1).Get(account) if err != nil { return 0, err @@ -127,7 +128,7 @@ func (s *AccountService) GetMaxDisplayOrder(uid int64, category models.AccountCa } // GetMaxSubAccountDisplayOrder returns the max display order of sub account according to account category and parent account id -func (s *AccountService) GetMaxSubAccountDisplayOrder(uid int64, category models.AccountCategory, parentAccountId int64) (int32, error) { +func (s *AccountService) GetMaxSubAccountDisplayOrder(c *core.Context, uid int64, category models.AccountCategory, parentAccountId int64) (int32, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } @@ -137,7 +138,7 @@ func (s *AccountService) GetMaxSubAccountDisplayOrder(uid int64, category models } account := &models.Account{} - has, err := s.UserDataDB(uid).Cols("uid", "deleted", "parent_account_id", "display_order").Where("uid=? AND deleted=? AND parent_account_id=? AND category=?", uid, false, parentAccountId, category).OrderBy("display_order desc").Limit(1).Get(account) + has, err := s.UserDataDB(uid).NewSession(c).Cols("uid", "deleted", "parent_account_id", "display_order").Where("uid=? AND deleted=? AND parent_account_id=? AND category=?", uid, false, parentAccountId, category).OrderBy("display_order desc").Limit(1).Get(account) if err != nil { return 0, err @@ -151,7 +152,7 @@ func (s *AccountService) GetMaxSubAccountDisplayOrder(uid int64, category models } // CreateAccounts saves a new account model to database -func (s *AccountService) CreateAccounts(mainAccount *models.Account, childrenAccounts []*models.Account, utcOffset int16) error { +func (s *AccountService) CreateAccounts(c *core.Context, mainAccount *models.Account, childrenAccounts []*models.Account, utcOffset int16) error { if mainAccount.Uid <= 0 { return errs.ErrUserIdInvalid } @@ -204,7 +205,7 @@ func (s *AccountService) CreateAccounts(mainAccount *models.Account, childrenAcc } } - return s.UserDataDB(mainAccount.Uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(mainAccount.Uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(allAccounts); i++ { account := allAccounts[i] _, err := sess.Insert(account) @@ -228,7 +229,7 @@ func (s *AccountService) CreateAccounts(mainAccount *models.Account, childrenAcc } // ModifyAccounts saves an existed account model to database -func (s *AccountService) ModifyAccounts(uid int64, accounts []*models.Account) error { +func (s *AccountService) ModifyAccounts(c *core.Context, uid int64, accounts []*models.Account) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -239,7 +240,7 @@ func (s *AccountService) ModifyAccounts(uid int64, accounts []*models.Account) e accounts[i].UpdatedUnixTime = now } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(accounts); i++ { account := accounts[i] updatedRows, err := sess.ID(account.AccountId).Cols("name", "category", "icon", "color", "comment", "hidden", "updated_unix_time").Where("uid=? AND deleted=?", uid, false).Update(account) @@ -256,7 +257,7 @@ func (s *AccountService) ModifyAccounts(uid int64, accounts []*models.Account) e } // HideAccount updates hidden field of given accounts -func (s *AccountService) HideAccount(uid int64, ids []int64, hidden bool) error { +func (s *AccountService) HideAccount(c *core.Context, uid int64, ids []int64, hidden bool) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -268,7 +269,7 @@ func (s *AccountService) HideAccount(uid int64, ids []int64, hidden bool) error UpdatedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { updatedRows, err := sess.Cols("hidden", "updated_unix_time").Where("uid=? AND deleted=?", uid, false).In("account_id", ids).Update(updateModel) if err != nil { @@ -282,7 +283,7 @@ func (s *AccountService) HideAccount(uid int64, ids []int64, hidden bool) error } // ModifyAccountDisplayOrders updates display order of given accounts -func (s *AccountService) ModifyAccountDisplayOrders(uid int64, accounts []*models.Account) error { +func (s *AccountService) ModifyAccountDisplayOrders(c *core.Context, uid int64, accounts []*models.Account) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -291,7 +292,7 @@ func (s *AccountService) ModifyAccountDisplayOrders(uid int64, accounts []*model accounts[i].UpdatedUnixTime = time.Now().Unix() } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(accounts); i++ { account := accounts[i] updatedRows, err := sess.ID(account.AccountId).Cols("display_order", "updated_unix_time").Where("uid=? AND deleted=?", uid, false).Update(account) @@ -308,7 +309,7 @@ func (s *AccountService) ModifyAccountDisplayOrders(uid int64, accounts []*model } // DeleteAccount deletes an existed account from database -func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { +func (s *AccountService) DeleteAccount(c *core.Context, uid int64, accountId int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -321,7 +322,7 @@ func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { DeletedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { var accountAndSubAccounts []*models.Account err := sess.Where("uid=? AND deleted=? AND (account_id=? OR parent_account_id=?)", uid, false, accountId, accountId).Find(&accountAndSubAccounts) diff --git a/pkg/services/forget_passwords.go b/pkg/services/forget_passwords.go index ae6c38d2..ebe60fdd 100644 --- a/pkg/services/forget_passwords.go +++ b/pkg/services/forget_passwords.go @@ -5,6 +5,7 @@ import ( "fmt" "net/url" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/locales" "github.com/mayswind/ezbookkeeping/pkg/mail" @@ -34,7 +35,7 @@ var ( ) // SendPasswordResetEmail sends password reset email according to specified parameters -func (s *ForgetPasswordService) SendPasswordResetEmail(user *models.User, passwordResetToken string, backupLocale string) error { +func (s *ForgetPasswordService) SendPasswordResetEmail(c *core.Context, user *models.User, passwordResetToken string, backupLocale string) error { if !s.CurrentConfig().EnableSMTP { return errs.ErrSMTPServerNotEnabled } diff --git a/pkg/services/tokens.go b/pkg/services/tokens.go index 94e4b837..87134a7e 100644 --- a/pkg/services/tokens.go +++ b/pkg/services/tokens.go @@ -38,19 +38,19 @@ var ( ) // GetAllTokensByUid returns all token models of given user -func (s *TokenService) GetAllTokensByUid(uid int64) ([]*models.TokenRecord, error) { +func (s *TokenService) GetAllTokensByUid(c *core.Context, uid int64) ([]*models.TokenRecord, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } var tokenRecords []*models.TokenRecord - err := s.TokenDB(uid).Cols("uid", "user_token_id", "token_type", "user_agent", "created_unix_time", "expired_unix_time").Where("uid=?", uid).Find(&tokenRecords) + err := s.TokenDB(uid).NewSession(c).Cols("uid", "user_token_id", "token_type", "user_agent", "created_unix_time", "expired_unix_time").Where("uid=?", uid).Find(&tokenRecords) return tokenRecords, err } // GetAllUnexpiredNormalTokensByUid returns all available token models of given user -func (s *TokenService) GetAllUnexpiredNormalTokensByUid(uid int64) ([]*models.TokenRecord, error) { +func (s *TokenService) GetAllUnexpiredNormalTokensByUid(c *core.Context, uid int64) ([]*models.TokenRecord, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -58,7 +58,7 @@ func (s *TokenService) GetAllUnexpiredNormalTokensByUid(uid int64) ([]*models.To now := time.Now().Unix() var tokenRecords []*models.TokenRecord - err := s.TokenDB(uid).Cols("uid", "user_token_id", "token_type", "user_agent", "created_unix_time", "expired_unix_time").Where("uid=? AND token_type=? AND expired_unix_time>?", uid, core.USER_TOKEN_TYPE_NORMAL, now).Find(&tokenRecords) + err := s.TokenDB(uid).NewSession(c).Cols("uid", "user_token_id", "token_type", "user_agent", "created_unix_time", "expired_unix_time").Where("uid=? AND token_type=? AND expired_unix_time>?", uid, core.USER_TOKEN_TYPE_NORMAL, now).Find(&tokenRecords) return tokenRecords, err } @@ -79,22 +79,22 @@ func (s *TokenService) ParseTokenByCookie(c *core.Context, tokenCookieName strin } // CreateToken generates a new normal token and saves to database -func (s *TokenService) CreateToken(user *models.User, ctx *core.Context) (string, *core.UserTokenClaims, error) { - return s.createToken(user, core.USER_TOKEN_TYPE_NORMAL, s.getUserAgent(ctx), s.CurrentConfig().TokenExpiredTimeDuration) +func (s *TokenService) CreateToken(c *core.Context, user *models.User) (string, *core.UserTokenClaims, error) { + return s.createToken(c, user, core.USER_TOKEN_TYPE_NORMAL, s.getUserAgent(c), s.CurrentConfig().TokenExpiredTimeDuration) } // CreateRequire2FAToken generates a new token requiring user to verify 2fa passcode and saves to database -func (s *TokenService) CreateRequire2FAToken(user *models.User, ctx *core.Context) (string, *core.UserTokenClaims, error) { - return s.createToken(user, core.USER_TOKEN_TYPE_REQUIRE_2FA, s.getUserAgent(ctx), s.CurrentConfig().TemporaryTokenExpiredTimeDuration) +func (s *TokenService) CreateRequire2FAToken(c *core.Context, user *models.User) (string, *core.UserTokenClaims, error) { + return s.createToken(c, user, core.USER_TOKEN_TYPE_REQUIRE_2FA, s.getUserAgent(c), s.CurrentConfig().TemporaryTokenExpiredTimeDuration) } // CreatePasswordResetToken generates a new password reset token and saves to database -func (s *TokenService) CreatePasswordResetToken(user *models.User, ctx *core.Context) (string, *core.UserTokenClaims, error) { - return s.createToken(user, core.USER_TOKEN_TYPE_PASSWORD_RESET, s.getUserAgent(ctx), s.CurrentConfig().PasswordResetTokenExpiredTimeDuration) +func (s *TokenService) CreatePasswordResetToken(c *core.Context, user *models.User) (string, *core.UserTokenClaims, error) { + return s.createToken(c, user, core.USER_TOKEN_TYPE_PASSWORD_RESET, s.getUserAgent(c), s.CurrentConfig().PasswordResetTokenExpiredTimeDuration) } // DeleteToken deletes given token from database -func (s *TokenService) DeleteToken(tokenRecord *models.TokenRecord) error { +func (s *TokenService) DeleteToken(c *core.Context, tokenRecord *models.TokenRecord) error { if tokenRecord.Uid <= 0 { return errs.ErrUserIdInvalid } @@ -103,7 +103,7 @@ func (s *TokenService) DeleteToken(tokenRecord *models.TokenRecord) error { return errs.ErrInvalidUserTokenId } - return s.TokenDB(tokenRecord.Uid).DoTransaction(func(sess *xorm.Session) error { + return s.TokenDB(tokenRecord.Uid).DoTransaction(c, func(sess *xorm.Session) error { deletedRows, err := sess.Where("uid=? AND user_token_id=? AND created_unix_time=?", tokenRecord.Uid, tokenRecord.UserTokenId, tokenRecord.CreatedUnixTime).Delete(&models.TokenRecord{}) if err != nil { @@ -117,12 +117,12 @@ func (s *TokenService) DeleteToken(tokenRecord *models.TokenRecord) error { } // DeleteTokens deletes given tokens from database -func (s *TokenService) DeleteTokens(uid int64, tokenRecords []*models.TokenRecord) error { +func (s *TokenService) DeleteTokens(c *core.Context, uid int64, tokenRecords []*models.TokenRecord) error { if uid <= 0 { return errs.ErrUserIdInvalid } - return s.TokenDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.TokenDB(uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(tokenRecords); i++ { tokenRecord := tokenRecords[i] deletedRows, err := sess.Where("uid=? AND user_token_id=? AND created_unix_time=?", uid, tokenRecord.UserTokenId, tokenRecord.CreatedUnixTime).Delete(&models.TokenRecord{}) @@ -139,14 +139,14 @@ func (s *TokenService) DeleteTokens(uid int64, tokenRecords []*models.TokenRecor } // DeleteTokenByClaims deletes given token from database -func (s *TokenService) DeleteTokenByClaims(claims *core.UserTokenClaims) error { +func (s *TokenService) DeleteTokenByClaims(c *core.Context, claims *core.UserTokenClaims) error { userTokenId, err := utils.StringToInt64(claims.UserTokenId) if err != nil { return errs.ErrInvalidUserTokenId } - return s.DeleteToken(&models.TokenRecord{ + return s.DeleteToken(c, &models.TokenRecord{ Uid: claims.Uid, UserTokenId: userTokenId, CreatedUnixTime: claims.IssuedAt, @@ -154,12 +154,12 @@ func (s *TokenService) DeleteTokenByClaims(claims *core.UserTokenClaims) error { } // DeleteTokensBeforeTime deletes tokens that is created before specific time -func (s *TokenService) DeleteTokensBeforeTime(uid int64, expireTime int64) error { +func (s *TokenService) DeleteTokensBeforeTime(c *core.Context, uid int64, expireTime int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } - return s.TokenDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.TokenDB(uid).DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Where("uid=? AND created_unix_time?", uid, false, 0).Limit(1).Exist(&models.Transaction{}) if err != nil { diff --git a/pkg/services/transaction_tags.go b/pkg/services/transaction_tags.go index 777fc1df..e74747fc 100644 --- a/pkg/services/transaction_tags.go +++ b/pkg/services/transaction_tags.go @@ -5,6 +5,7 @@ import ( "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" @@ -30,30 +31,30 @@ var ( ) // GetTotalTagCountByUid returns total tag count of user -func (s *TransactionTagService) GetTotalTagCountByUid(uid int64) (int64, error) { +func (s *TransactionTagService) GetTotalTagCountByUid(c *core.Context, uid int64) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } - count, err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).Count(&models.TransactionTag{}) + count, err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).Count(&models.TransactionTag{}) return count, err } // GetAllTagsByUid returns all transaction tag models of user -func (s *TransactionTagService) GetAllTagsByUid(uid int64) ([]*models.TransactionTag, error) { +func (s *TransactionTagService) GetAllTagsByUid(c *core.Context, uid int64) ([]*models.TransactionTag, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } var tags []*models.TransactionTag - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).Find(&tags) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).Find(&tags) return tags, err } // GetTagByTagId returns a transaction tag model according to transaction tag id -func (s *TransactionTagService) GetTagByTagId(uid int64, tagId int64) (*models.TransactionTag, error) { +func (s *TransactionTagService) GetTagByTagId(c *core.Context, uid int64, tagId int64) (*models.TransactionTag, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -63,7 +64,7 @@ func (s *TransactionTagService) GetTagByTagId(uid int64, tagId int64) (*models.T } tag := &models.TransactionTag{} - has, err := s.UserDataDB(uid).ID(tagId).Where("uid=? AND deleted=?", uid, false).Get(tag) + has, err := s.UserDataDB(uid).NewSession(c).ID(tagId).Where("uid=? AND deleted=?", uid, false).Get(tag) if err != nil { return nil, err @@ -75,7 +76,7 @@ func (s *TransactionTagService) GetTagByTagId(uid int64, tagId int64) (*models.T } // GetTagsByTagIds returns transaction tag models according to transaction tag ids -func (s *TransactionTagService) GetTagsByTagIds(uid int64, tagIds []int64) (map[int64]*models.TransactionTag, error) { +func (s *TransactionTagService) GetTagsByTagIds(c *core.Context, uid int64, tagIds []int64) (map[int64]*models.TransactionTag, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -85,7 +86,7 @@ func (s *TransactionTagService) GetTagsByTagIds(uid int64, tagIds []int64) (map[ } var tags []*models.TransactionTag - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).In("tag_id", tagIds).Find(&tags) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).In("tag_id", tagIds).Find(&tags) if err != nil { return nil, err @@ -96,13 +97,13 @@ func (s *TransactionTagService) GetTagsByTagIds(uid int64, tagIds []int64) (map[ } // GetMaxDisplayOrder returns the max display order -func (s *TransactionTagService) GetMaxDisplayOrder(uid int64) (int32, error) { +func (s *TransactionTagService) GetMaxDisplayOrder(c *core.Context, uid int64) (int32, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } tag := &models.TransactionTag{} - has, err := s.UserDataDB(uid).Cols("uid", "deleted", "display_order").Where("uid=? AND deleted=?", uid, false).OrderBy("display_order desc").Limit(1).Get(tag) + has, err := s.UserDataDB(uid).NewSession(c).Cols("uid", "deleted", "display_order").Where("uid=? AND deleted=?", uid, false).OrderBy("display_order desc").Limit(1).Get(tag) if err != nil { return 0, err @@ -116,13 +117,13 @@ func (s *TransactionTagService) GetMaxDisplayOrder(uid int64) (int32, error) { } // GetAllTagIdsOfAllTransactions returns all transaction tag ids -func (s *TransactionTagService) GetAllTagIdsOfAllTransactions(uid int64) (map[int64][]int64, error) { +func (s *TransactionTagService) GetAllTagIdsOfAllTransactions(c *core.Context, uid int64) (map[int64][]int64, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } var tagIndexs []*models.TransactionTagIndex - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).Find(&tagIndexs) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).Find(&tagIndexs) allTransactionTagIds := s.getGroupedTransactionTagIds(tagIndexs) @@ -130,13 +131,13 @@ func (s *TransactionTagService) GetAllTagIdsOfAllTransactions(uid int64) (map[in } // GetAllTagIdsOfTransactions returns transaction tag ids for given transactions -func (s *TransactionTagService) GetAllTagIdsOfTransactions(uid int64, transactionIds []int64) (map[int64][]int64, error) { +func (s *TransactionTagService) GetAllTagIdsOfTransactions(c *core.Context, uid int64, transactionIds []int64) (map[int64][]int64, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } var tagIndexs []*models.TransactionTagIndex - err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).In("transaction_id", transactionIds).Find(&tagIndexs) + err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).In("transaction_id", transactionIds).Find(&tagIndexs) allTransactionTagIds := s.getGroupedTransactionTagIds(tagIndexs) @@ -144,12 +145,12 @@ func (s *TransactionTagService) GetAllTagIdsOfTransactions(uid int64, transactio } // CreateTag saves a new transaction tag model to database -func (s *TransactionTagService) CreateTag(tag *models.TransactionTag) error { +func (s *TransactionTagService) CreateTag(c *core.Context, tag *models.TransactionTag) error { if tag.Uid <= 0 { return errs.ErrUserIdInvalid } - exists, err := s.ExistsTagName(tag.Uid, tag.Name) + exists, err := s.ExistsTagName(c, tag.Uid, tag.Name) if err != nil { return err @@ -163,19 +164,19 @@ func (s *TransactionTagService) CreateTag(tag *models.TransactionTag) error { tag.CreatedUnixTime = time.Now().Unix() tag.UpdatedUnixTime = time.Now().Unix() - return s.UserDataDB(tag.Uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(tag.Uid).DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Insert(tag) return err }) } // ModifyTag saves an existed transaction tag model to database -func (s *TransactionTagService) ModifyTag(tag *models.TransactionTag) error { +func (s *TransactionTagService) ModifyTag(c *core.Context, tag *models.TransactionTag) error { if tag.Uid <= 0 { return errs.ErrUserIdInvalid } - exists, err := s.ExistsTagName(tag.Uid, tag.Name) + exists, err := s.ExistsTagName(c, tag.Uid, tag.Name) if err != nil { return err @@ -185,7 +186,7 @@ func (s *TransactionTagService) ModifyTag(tag *models.TransactionTag) error { tag.UpdatedUnixTime = time.Now().Unix() - return s.UserDataDB(tag.Uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(tag.Uid).DoTransaction(c, func(sess *xorm.Session) error { updatedRows, err := sess.ID(tag.TagId).Cols("name", "updated_unix_time").Where("uid=? AND deleted=?", tag.Uid, false).Update(tag) if err != nil { @@ -199,7 +200,7 @@ func (s *TransactionTagService) ModifyTag(tag *models.TransactionTag) error { } // HideTag updates hidden field of given transaction tags -func (s *TransactionTagService) HideTag(uid int64, ids []int64, hidden bool) error { +func (s *TransactionTagService) HideTag(c *core.Context, uid int64, ids []int64, hidden bool) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -211,7 +212,7 @@ func (s *TransactionTagService) HideTag(uid int64, ids []int64, hidden bool) err UpdatedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { updatedRows, err := sess.Cols("hidden", "updated_unix_time").Where("uid=? AND deleted=?", uid, false).In("tag_id", ids).Update(updateModel) if err != nil { @@ -225,7 +226,7 @@ func (s *TransactionTagService) HideTag(uid int64, ids []int64, hidden bool) err } // ModifyTagDisplayOrders updates display order of given transaction tags -func (s *TransactionTagService) ModifyTagDisplayOrders(uid int64, tags []*models.TransactionTag) error { +func (s *TransactionTagService) ModifyTagDisplayOrders(c *core.Context, uid int64, tags []*models.TransactionTag) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -234,7 +235,7 @@ func (s *TransactionTagService) ModifyTagDisplayOrders(uid int64, tags []*models tags[i].UpdatedUnixTime = time.Now().Unix() } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(tags); i++ { tag := tags[i] updatedRows, err := sess.ID(tag.TagId).Cols("display_order", "updated_unix_time").Where("uid=? AND deleted=?", uid, false).Update(tag) @@ -251,7 +252,7 @@ func (s *TransactionTagService) ModifyTagDisplayOrders(uid int64, tags []*models } // DeleteTag deletes an existed transaction tag from database -func (s *TransactionTagService) DeleteTag(uid int64, tagId int64) error { +func (s *TransactionTagService) DeleteTag(c *core.Context, uid int64, tagId int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -263,7 +264,7 @@ func (s *TransactionTagService) DeleteTag(uid int64, tagId int64) error { DeletedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { exists, err := sess.Cols("uid", "tag_id").Where("uid=? AND deleted=? AND tag_id=?", uid, false, tagId).Limit(1).Exist(&models.TransactionTagIndex{}) if err != nil { @@ -285,7 +286,7 @@ func (s *TransactionTagService) DeleteTag(uid int64, tagId int64) error { } // DeleteAllTags deletes all existed transaction tags from database -func (s *TransactionTagService) DeleteAllTags(uid int64) error { +func (s *TransactionTagService) DeleteAllTags(c *core.Context, uid int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -297,7 +298,7 @@ func (s *TransactionTagService) DeleteAllTags(uid int64) error { DeletedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { exists, err := sess.Cols("uid", "deleted").Where("uid=? AND deleted=?", uid, false).Limit(1).Exist(&models.TransactionTagIndex{}) if err != nil { @@ -317,12 +318,12 @@ func (s *TransactionTagService) DeleteAllTags(uid int64) error { } // ExistsTagName returns whether the given tag name exists -func (s *TransactionTagService) ExistsTagName(uid int64, name string) (bool, error) { +func (s *TransactionTagService) ExistsTagName(c *core.Context, uid int64, name string) (bool, error) { if name == "" { return false, errs.ErrTransactionTagNameIsEmpty } - return s.UserDataDB(uid).Cols("name").Where("uid=? AND deleted=? AND name=?", uid, false, name).Exist(&models.TransactionTag{}) + return s.UserDataDB(uid).NewSession(c).Cols("name").Where("uid=? AND deleted=? AND name=?", uid, false, name).Exist(&models.TransactionTag{}) } // GetTagMapByList returns a transaction tag map by a list diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index 4ce364a5..e3b366f0 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -7,6 +7,7 @@ import ( "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" @@ -33,23 +34,23 @@ var ( ) // GetTotalTransactionCountByUid returns total transaction count of user -func (s *TransactionService) GetTotalTransactionCountByUid(uid int64) (int64, error) { +func (s *TransactionService) GetTotalTransactionCountByUid(c *core.Context, uid int64) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } - count, err := s.UserDataDB(uid).Where("uid=? AND deleted=?", uid, false).Count(&models.Transaction{}) + count, err := s.UserDataDB(uid).NewSession(c).Where("uid=? AND deleted=?", uid, false).Count(&models.Transaction{}) return count, err } // GetAllTransactions returns all transactions -func (s *TransactionService) GetAllTransactions(uid int64, pageCount int32, noDuplicated bool) ([]*models.Transaction, error) { +func (s *TransactionService) GetAllTransactions(c *core.Context, uid int64, pageCount int32, noDuplicated bool) ([]*models.Transaction, error) { maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(time.Now().Unix()) var allTransactions []*models.Transaction for maxTransactionTime > 0 { - transactions, err := s.GetAllTransactionsByMaxTime(uid, maxTransactionTime, pageCount, noDuplicated) + transactions, err := s.GetAllTransactionsByMaxTime(c, uid, maxTransactionTime, pageCount, noDuplicated) if err != nil { return nil, err @@ -69,12 +70,12 @@ func (s *TransactionService) GetAllTransactions(uid int64, pageCount int32, noDu } // GetAllTransactionsByMaxTime returns all transactions before given time -func (s *TransactionService) GetAllTransactionsByMaxTime(uid int64, maxTransactionTime int64, count int32, noDuplicated bool) ([]*models.Transaction, error) { - return s.GetTransactionsByMaxTime(uid, maxTransactionTime, 0, 0, nil, nil, "", 1, count, false, noDuplicated) +func (s *TransactionService) GetAllTransactionsByMaxTime(c *core.Context, uid int64, maxTransactionTime int64, count int32, noDuplicated bool) ([]*models.Transaction, error) { + return s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, 0, 0, nil, nil, "", 1, count, false, noDuplicated) } // GetTransactionsByMaxTime returns transactions before given time -func (s *TransactionService) GetTransactionsByMaxTime(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsByMaxTime(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -99,13 +100,13 @@ func (s *TransactionService) GetTransactionsByMaxTime(uid int64, maxTransactionT } condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, keyword, noDuplicated) - err = s.UserDataDB(uid).Where(condition, conditionParams...).Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions) + err = s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions) return transactions, err } // GetTransactionsInMonthByPage returns all transactions in given year and month -func (s *TransactionService) GetTransactionsInMonthByPage(uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsInMonthByPage(c *core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -125,7 +126,7 @@ func (s *TransactionService) GetTransactionsInMonthByPage(uid int64, year int32, var transactions []*models.Transaction condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, keyword, true) - err = s.UserDataDB(uid).Where(condition, conditionParams...).OrderBy("transaction_time desc").Find(&transactions) + err = s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).OrderBy("transaction_time desc").Find(&transactions) transactionsInMonth := make([]*models.Transaction, 0, len(transactions)) @@ -143,7 +144,7 @@ func (s *TransactionService) GetTransactionsInMonthByPage(uid int64, year int32, } // GetTransactionByTransactionId returns a transaction model according to transaction id -func (s *TransactionService) GetTransactionByTransactionId(uid int64, transactionId int64) (*models.Transaction, error) { +func (s *TransactionService) GetTransactionByTransactionId(c *core.Context, uid int64, transactionId int64) (*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -153,7 +154,7 @@ func (s *TransactionService) GetTransactionByTransactionId(uid int64, transactio } transaction := &models.Transaction{} - has, err := s.UserDataDB(uid).ID(transactionId).Where("uid=? AND deleted=?", uid, false).Get(transaction) + has, err := s.UserDataDB(uid).NewSession(c).ID(transactionId).Where("uid=? AND deleted=?", uid, false).Get(transaction) if err != nil { return nil, err @@ -165,12 +166,12 @@ func (s *TransactionService) GetTransactionByTransactionId(uid int64, transactio } // GetAllTransactionCount returns total count of transactions -func (s *TransactionService) GetAllTransactionCount(uid int64) (int64, error) { - return s.GetTransactionCount(uid, 0, 0, 0, nil, nil, "") +func (s *TransactionService) GetAllTransactionCount(c *core.Context, uid int64) (int64, error) { + return s.GetTransactionCount(c, uid, 0, 0, 0, nil, nil, "") } // GetMonthTransactionCount returns total count of transactions in given year and month -func (s *TransactionService) GetMonthTransactionCount(uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string, utcOffset int16) (int64, error) { +func (s *TransactionService) GetMonthTransactionCount(c *core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string, utcOffset int16) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } @@ -186,21 +187,21 @@ func (s *TransactionService) GetMonthTransactionCount(uid int64, year int32, mon minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(startTime.Unix()) maxTransactionTime := utils.GetMinTransactionTimeFromUnixTime(endTime.Unix()) - 1 - return s.GetTransactionCount(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, keyword) + return s.GetTransactionCount(c, uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, keyword) } // GetTransactionCount returns count of transactions -func (s *TransactionService) GetTransactionCount(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string) (int64, error) { +func (s *TransactionService) GetTransactionCount(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, keyword string) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, keyword, true) - return s.UserDataDB(uid).Where(condition, conditionParams...).Count(&models.Transaction{}) + return s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).Count(&models.Transaction{}) } // CreateTransaction saves a new transaction to database -func (s *TransactionService) CreateTransaction(transaction *models.Transaction, tagIds []int64) error { +func (s *TransactionService) CreateTransaction(c *core.Context, transaction *models.Transaction, tagIds []int64) error { if transaction.Uid <= 0 { return errs.ErrUserIdInvalid } @@ -247,7 +248,7 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, } } - return s.UserDataDB(transaction.Uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(transaction.Uid).DoTransaction(c, func(sess *xorm.Session) error { // Get and verify source and destination account sourceAccount, destinationAccount, err := s.getAccountModels(sess, transaction) @@ -411,7 +412,7 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, } // ModifyTransaction saves an existed transaction to database -func (s *TransactionService) ModifyTransaction(transaction *models.Transaction, addTagIds []int64, removeTagIds []int64) error { +func (s *TransactionService) ModifyTransaction(c *core.Context, transaction *models.Transaction, addTagIds []int64, removeTagIds []int64) error { if transaction.Uid <= 0 { return errs.ErrUserIdInvalid } @@ -441,7 +442,7 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction, } } - err := s.UserDataDB(transaction.Uid).DoTransaction(func(sess *xorm.Session) error { + err := s.UserDataDB(transaction.Uid).DoTransaction(c, func(sess *xorm.Session) error { // Get and verify current transaction oldTransaction := &models.Transaction{} has, err := sess.ID(transaction.TransactionId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(oldTransaction) @@ -782,7 +783,7 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction, } // DeleteTransaction deletes an existed transaction from database -func (s *TransactionService) DeleteTransaction(uid int64, transactionId int64) error { +func (s *TransactionService) DeleteTransaction(c *core.Context, uid int64, transactionId int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -799,7 +800,7 @@ func (s *TransactionService) DeleteTransaction(uid int64, transactionId int64) e DeletedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { // Get and verify current transaction oldTransaction := &models.Transaction{} has, err := sess.ID(transactionId).Where("uid=? AND deleted=?", uid, false).Get(oldTransaction) @@ -902,7 +903,7 @@ func (s *TransactionService) DeleteTransaction(uid int64, transactionId int64) e } // DeleteAllTransactions deletes all existed transactions from database -func (s *TransactionService) DeleteAllTransactions(uid int64) error { +func (s *TransactionService) DeleteAllTransactions(c *core.Context, uid int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } @@ -925,7 +926,7 @@ func (s *TransactionService) DeleteAllTransactions(uid int64) error { DeletedUnixTime: now, } - return s.UserDataDB(uid).DoTransaction(func(sess *xorm.Session) error { + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { // Update all transaction to deleted _, err := sess.Cols("deleted", "deleted_unix_time").Where("uid=? AND deleted=?", uid, false).Update(updateModel) @@ -992,7 +993,7 @@ func (s *TransactionService) GetRelatedTransferTransaction(originalTransaction * } // GetAccountsTotalIncomeAndExpense returns the every accounts total income and expense amount by specific date range -func (s *TransactionService) GetAccountsTotalIncomeAndExpense(uid int64, startUnixTime int64, endUnixTime int64) (map[int64]int64, map[int64]int64, error) { +func (s *TransactionService) GetAccountsTotalIncomeAndExpense(c *core.Context, uid int64, startUnixTime int64, endUnixTime int64) (map[int64]int64, map[int64]int64, error) { if uid <= 0 { return nil, nil, errs.ErrUserIdInvalid } @@ -1001,7 +1002,7 @@ func (s *TransactionService) GetAccountsTotalIncomeAndExpense(uid int64, startUn endTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(endUnixTime) var transactionTotalAmounts []*models.Transaction - err := s.UserDataDB(uid).Select("type, account_id, SUM(amount) as amount").Where("uid=? AND deleted=? AND (type=? OR type=?) AND transaction_time>=? AND transaction_time<=?", uid, false, models.TRANSACTION_DB_TYPE_INCOME, models.TRANSACTION_DB_TYPE_EXPENSE, startTransactionTime, endTransactionTime).GroupBy("type, account_id").Find(&transactionTotalAmounts) + err := s.UserDataDB(uid).NewSession(c).Select("type, account_id, SUM(amount) as amount").Where("uid=? AND deleted=? AND (type=? OR type=?) AND transaction_time>=? AND transaction_time<=?", uid, false, models.TRANSACTION_DB_TYPE_INCOME, models.TRANSACTION_DB_TYPE_EXPENSE, startTransactionTime, endTransactionTime).GroupBy("type, account_id").Find(&transactionTotalAmounts) if err != nil { return nil, nil, err @@ -1024,7 +1025,7 @@ func (s *TransactionService) GetAccountsTotalIncomeAndExpense(uid int64, startUn } // GetAccountsMonthTotalIncomeAndExpense returns the every accounts total income and expense amount in month by specific date range -func (s *TransactionService) GetAccountsMonthTotalIncomeAndExpense(uid int64, startUnixTime int64, endUnixTime int64, pageCount int) (map[string]models.TransactionAccountsAmount, error) { +func (s *TransactionService) GetAccountsMonthTotalIncomeAndExpense(c *core.Context, uid int64, startUnixTime int64, endUnixTime int64, pageCount int) (map[string]models.TransactionAccountsAmount, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -1039,7 +1040,7 @@ func (s *TransactionService) GetAccountsMonthTotalIncomeAndExpense(uid int64, st for maxTransactionTime > 0 { var transactions []*models.Transaction - err := s.UserDataDB(uid).Select("uid, type, account_id, transaction_time, timezone_utc_offset, amount").Where("uid=? AND deleted=? AND (type=? OR type=?) AND transaction_time>=? AND transaction_time<=?", uid, false, models.TRANSACTION_DB_TYPE_INCOME, models.TRANSACTION_DB_TYPE_EXPENSE, minTransactionTime, maxTransactionTime).Limit(pageCount, 0).OrderBy("transaction_time desc").Find(&transactions) + err := s.UserDataDB(uid).NewSession(c).Select("uid, type, account_id, transaction_time, timezone_utc_offset, amount").Where("uid=? AND deleted=? AND (type=? OR type=?) AND transaction_time>=? AND transaction_time<=?", uid, false, models.TRANSACTION_DB_TYPE_INCOME, models.TRANSACTION_DB_TYPE_EXPENSE, minTransactionTime, maxTransactionTime).Limit(pageCount, 0).OrderBy("transaction_time desc").Find(&transactions) if err != nil { return nil, err @@ -1091,7 +1092,7 @@ func (s *TransactionService) GetAccountsMonthTotalIncomeAndExpense(uid int64, st } // GetAccountsAndCategoriesTotalIncomeAndExpense returns the every accounts and categories total income and expense amount by specific date range -func (s *TransactionService) GetAccountsAndCategoriesTotalIncomeAndExpense(uid int64, startUnixTime int64, endUnixTime int64) ([]*models.Transaction, error) { +func (s *TransactionService) GetAccountsAndCategoriesTotalIncomeAndExpense(c *core.Context, uid int64, startUnixTime int64, endUnixTime int64) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -1114,7 +1115,7 @@ func (s *TransactionService) GetAccountsAndCategoriesTotalIncomeAndExpense(uid i } var transactionTotalAmounts []*models.Transaction - err := s.UserDataDB(uid).Select("category_id, account_id, SUM(amount) as amount").Where(condition, conditionParams...).GroupBy("category_id, account_id").Find(&transactionTotalAmounts) + err := s.UserDataDB(uid).NewSession(c).Select("category_id, account_id, SUM(amount) as amount").Where(condition, conditionParams...).GroupBy("category_id, account_id").Find(&transactionTotalAmounts) if err != nil { return nil, err diff --git a/pkg/services/twofactor_authorizations.go b/pkg/services/twofactor_authorizations.go index e74797ba..822b1682 100644 --- a/pkg/services/twofactor_authorizations.go +++ b/pkg/services/twofactor_authorizations.go @@ -7,6 +7,7 @@ import ( "github.com/pquerna/otp/totp" "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" @@ -45,13 +46,13 @@ var ( ) // GetUserTwoFactorSettingByUid returns the 2fa setting model according to user uid -func (s *TwoFactorAuthorizationService) GetUserTwoFactorSettingByUid(uid int64) (*models.TwoFactor, error) { +func (s *TwoFactorAuthorizationService) GetUserTwoFactorSettingByUid(c *core.Context, uid int64) (*models.TwoFactor, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } twoFactor := &models.TwoFactor{} - has, err := s.UserDB().Where("uid=?", uid).Get(twoFactor) + has, err := s.UserDB().NewSession(c).Where("uid=?", uid).Get(twoFactor) if err != nil { return nil, err @@ -69,7 +70,7 @@ func (s *TwoFactorAuthorizationService) GetUserTwoFactorSettingByUid(uid int64) } // GenerateTwoFactorSecret generates a new 2fa secret -func (s *TwoFactorAuthorizationService) GenerateTwoFactorSecret(user *models.User) (*otp.Key, error) { +func (s *TwoFactorAuthorizationService) GenerateTwoFactorSecret(c *core.Context, user *models.User) (*otp.Key, error) { if user == nil { return nil, errs.ErrUserNotFound } @@ -85,7 +86,7 @@ func (s *TwoFactorAuthorizationService) GenerateTwoFactorSecret(user *models.Use } // CreateTwoFactorSetting saves a new 2fa setting to database -func (s *TwoFactorAuthorizationService) CreateTwoFactorSetting(twoFactor *models.TwoFactor) error { +func (s *TwoFactorAuthorizationService) CreateTwoFactorSetting(c *core.Context, twoFactor *models.TwoFactor) error { if twoFactor.Uid <= 0 { return errs.ErrUserIdInvalid } @@ -99,19 +100,19 @@ func (s *TwoFactorAuthorizationService) CreateTwoFactorSetting(twoFactor *models twoFactor.CreatedUnixTime = time.Now().Unix() - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Insert(twoFactor) return err }) } // DeleteTwoFactorSetting deletes an existed 2fa setting from database -func (s *TwoFactorAuthorizationService) DeleteTwoFactorSetting(uid int64) error { +func (s *TwoFactorAuthorizationService) DeleteTwoFactorSetting(c *core.Context, uid int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { deletedRows, err := sess.Where("uid=?", uid).Delete(&models.TwoFactor{}) if err != nil { @@ -125,22 +126,22 @@ func (s *TwoFactorAuthorizationService) DeleteTwoFactorSetting(uid int64) error } // ExistsTwoFactorSetting returns whether the given user has existed 2fa setting -func (s *TwoFactorAuthorizationService) ExistsTwoFactorSetting(uid int64) (bool, error) { +func (s *TwoFactorAuthorizationService) ExistsTwoFactorSetting(c *core.Context, uid int64) (bool, error) { if uid <= 0 { return false, errs.ErrUserIdInvalid } - return s.UserDB().Cols("uid").Where("uid=?", uid).Exist(&models.TwoFactor{}) + return s.UserDB().NewSession(c).Cols("uid").Where("uid=?", uid).Exist(&models.TwoFactor{}) } // GetAndUseUserTwoFactorRecoveryCode checks whether the given 2fa recovery code exists and marks it used -func (s *TwoFactorAuthorizationService) GetAndUseUserTwoFactorRecoveryCode(uid int64, recoveryCode string, salt string) error { +func (s *TwoFactorAuthorizationService) GetAndUseUserTwoFactorRecoveryCode(c *core.Context, uid int64, recoveryCode string, salt string) error { if uid <= 0 { return errs.ErrUserIdInvalid } recoveryCode = utils.EncodePassword(recoveryCode, salt) - exists, err := s.UserDB().Cols("uid", "recovery_code").Where("uid=? AND recovery_code=? AND used=?", uid, recoveryCode, false).Exist(&models.TwoFactorRecoveryCode{}) + exists, err := s.UserDB().NewSession(c).Cols("uid", "recovery_code").Where("uid=? AND recovery_code=? AND used=?", uid, recoveryCode, false).Exist(&models.TwoFactorRecoveryCode{}) if err != nil { return err @@ -148,7 +149,7 @@ func (s *TwoFactorAuthorizationService) GetAndUseUserTwoFactorRecoveryCode(uid i return errs.ErrTwoFactorRecoveryCodeNotExist } - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Cols("used", "used_unix_time").Where("uid=? AND recovery_code=?", uid, recoveryCode).Update(&models.TwoFactorRecoveryCode{Used: true, UsedUnixTime: time.Now().Unix()}) return err }) @@ -172,7 +173,7 @@ func (s *TwoFactorAuthorizationService) GenerateTwoFactorRecoveryCodes() ([]stri } // CreateTwoFactorRecoveryCodes saves new 2fa recovery codes to database -func (s *TwoFactorAuthorizationService) CreateTwoFactorRecoveryCodes(uid int64, recoveryCodes []string, salt string) error { +func (s *TwoFactorAuthorizationService) CreateTwoFactorRecoveryCodes(c *core.Context, uid int64, recoveryCodes []string, salt string) error { twoFactorRecoveryCodes := make([]*models.TwoFactorRecoveryCode, len(recoveryCodes)) for i := 0; i < len(recoveryCodes); i++ { @@ -184,7 +185,7 @@ func (s *TwoFactorAuthorizationService) CreateTwoFactorRecoveryCodes(uid int64, } } - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Where("uid=?", uid).Delete(&models.TwoFactorRecoveryCode{}) if err != nil { @@ -205,12 +206,12 @@ func (s *TwoFactorAuthorizationService) CreateTwoFactorRecoveryCodes(uid int64, } // DeleteTwoFactorRecoveryCodes deletes existed 2fa recovery codes from database -func (s *TwoFactorAuthorizationService) DeleteTwoFactorRecoveryCodes(uid int64) error { +func (s *TwoFactorAuthorizationService) DeleteTwoFactorRecoveryCodes(c *core.Context, uid int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Where("uid=?", uid).Delete(&models.TwoFactorRecoveryCode{}) return err }) diff --git a/pkg/services/users.go b/pkg/services/users.go index a4e759b0..50989f67 100644 --- a/pkg/services/users.go +++ b/pkg/services/users.go @@ -5,6 +5,7 @@ import ( "xorm.io/xorm" + "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" @@ -31,14 +32,14 @@ var ( ) // GetUserByUsernameOrEmailAndPassword returns the user model according to login name and password -func (s *UserService) GetUserByUsernameOrEmailAndPassword(loginname string, password string) (*models.User, error) { +func (s *UserService) GetUserByUsernameOrEmailAndPassword(c *core.Context, loginname string, password string) (*models.User, error) { var user *models.User var err error if utils.IsValidUsername(loginname) { - user, err = s.GetUserByUsername(loginname) + user, err = s.GetUserByUsername(c, loginname) } else if utils.IsValidEmail(loginname) { - user, err = s.GetUserByEmail(loginname) + user, err = s.GetUserByEmail(c, loginname) } else { err = errs.ErrLoginNameInvalid } @@ -55,13 +56,13 @@ func (s *UserService) GetUserByUsernameOrEmailAndPassword(loginname string, pass } // GetUserById returns the user model according to user uid -func (s *UserService) GetUserById(uid int64) (*models.User, error) { +func (s *UserService) GetUserById(c *core.Context, uid int64) (*models.User, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } user := &models.User{} - has, err := s.UserDB().ID(uid).Where("deleted=?", false).Get(user) + has, err := s.UserDB().NewSession(c).ID(uid).Where("deleted=?", false).Get(user) if err != nil { return nil, err @@ -73,13 +74,13 @@ func (s *UserService) GetUserById(uid int64) (*models.User, error) { } // GetUserByUsername returns the user model according to user name -func (s *UserService) GetUserByUsername(username string) (*models.User, error) { +func (s *UserService) GetUserByUsername(c *core.Context, username string) (*models.User, error) { if username == "" { return nil, errs.ErrUsernameIsEmpty } user := &models.User{} - has, err := s.UserDB().Where("username=? AND deleted=?", username, false).Get(user) + has, err := s.UserDB().NewSession(c).Where("username=? AND deleted=?", username, false).Get(user) if err != nil { return nil, err @@ -91,13 +92,13 @@ func (s *UserService) GetUserByUsername(username string) (*models.User, error) { } // GetUserByEmail returns the user model according to user email -func (s *UserService) GetUserByEmail(email string) (*models.User, error) { +func (s *UserService) GetUserByEmail(c *core.Context, email string) (*models.User, error) { if email == "" { return nil, errs.ErrEmailIsEmpty } user := &models.User{} - has, err := s.UserDB().Where("email=? AND deleted=?", email, false).Get(user) + has, err := s.UserDB().NewSession(c).Where("email=? AND deleted=?", email, false).Get(user) if err != nil { return nil, err @@ -109,8 +110,8 @@ func (s *UserService) GetUserByEmail(email string) (*models.User, error) { } // CreateUser saves a new user model to database -func (s *UserService) CreateUser(user *models.User) error { - exists, err := s.ExistsUsername(user.Username) +func (s *UserService) CreateUser(c *core.Context, user *models.User) error { + exists, err := s.ExistsUsername(c, user.Username) if err != nil { return err @@ -118,7 +119,7 @@ func (s *UserService) CreateUser(user *models.User) error { return errs.ErrUsernameAlreadyExists } - exists, err = s.ExistsEmail(user.Email) + exists, err = s.ExistsEmail(c, user.Email) if err != nil { return err @@ -143,14 +144,14 @@ func (s *UserService) CreateUser(user *models.User) error { user.UpdatedUnixTime = time.Now().Unix() user.LastLoginUnixTime = time.Now().Unix() - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.Insert(user) return err }) } // UpdateUser saves an existed user model to database -func (s *UserService) UpdateUser(user *models.User, modifyUserLanguage bool) (keyProfileUpdated bool, err error) { +func (s *UserService) UpdateUser(c *core.Context, user *models.User, modifyUserLanguage bool) (keyProfileUpdated bool, err error) { if user.Uid <= 0 { return false, errs.ErrUserIdInvalid } @@ -161,7 +162,7 @@ func (s *UserService) UpdateUser(user *models.User, modifyUserLanguage bool) (ke keyProfileUpdated = false if user.Email != "" { - exists, err := s.ExistsEmail(user.Email) + exists, err := s.ExistsEmail(c, user.Email) if err != nil { return false, err @@ -225,7 +226,7 @@ func (s *UserService) UpdateUser(user *models.User, modifyUserLanguage bool) (ke user.UpdatedUnixTime = now updateCols = append(updateCols, "updated_unix_time") - err = s.UserDB().DoTransaction(func(sess *xorm.Session) error { + err = s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { updatedRows, err := sess.ID(user.Uid).Cols(updateCols...).Where("deleted=?", false).Update(user) if err != nil { @@ -245,19 +246,19 @@ func (s *UserService) UpdateUser(user *models.User, modifyUserLanguage bool) (ke } // UpdateUserLastLoginTime updates the last login time field -func (s *UserService) UpdateUserLastLoginTime(uid int64) error { +func (s *UserService) UpdateUserLastLoginTime(c *core.Context, uid int64) error { if uid <= 0 { return errs.ErrUserIdInvalid } - return s.UserDB().DoTransaction(func(sess *xorm.Session) error { + return s.UserDB().DoTransaction(c, func(sess *xorm.Session) error { _, err := sess.ID(uid).Cols("last_login_unix_time").Where("deleted=?", false).Update(&models.User{LastLoginUnixTime: time.Now().Unix()}) return err }) } // EnableUser sets user enabled -func (s *UserService) EnableUser(username string) error { +func (s *UserService) EnableUser(c *core.Context, username string) error { if username == "" { return errs.ErrUsernameIsEmpty } @@ -269,7 +270,7 @@ func (s *UserService) EnableUser(username string) error { UpdatedUnixTime: now, } - updatedRows, err := s.UserDB().Cols("disabled", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) + updatedRows, err := s.UserDB().NewSession(c).Cols("disabled", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) if err != nil { return err @@ -280,7 +281,7 @@ func (s *UserService) EnableUser(username string) error { } // DisableUser sets user disabled -func (s *UserService) DisableUser(username string) error { +func (s *UserService) DisableUser(c *core.Context, username string) error { if username == "" { return errs.ErrUsernameIsEmpty } @@ -292,7 +293,7 @@ func (s *UserService) DisableUser(username string) error { UpdatedUnixTime: now, } - updatedRows, err := s.UserDB().Cols("disabled", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) + updatedRows, err := s.UserDB().NewSession(c).Cols("disabled", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) if err != nil { return err @@ -303,7 +304,7 @@ func (s *UserService) DisableUser(username string) error { } // SetUserEmailVerified sets user email address verified -func (s *UserService) SetUserEmailVerified(username string) error { +func (s *UserService) SetUserEmailVerified(c *core.Context, username string) error { if username == "" { return errs.ErrUsernameIsEmpty } @@ -315,7 +316,7 @@ func (s *UserService) SetUserEmailVerified(username string) error { UpdatedUnixTime: now, } - updatedRows, err := s.UserDB().Cols("email_verified", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) + updatedRows, err := s.UserDB().NewSession(c).Cols("email_verified", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) if err != nil { return err @@ -326,7 +327,7 @@ func (s *UserService) SetUserEmailVerified(username string) error { } // SetUserEmailUnverified sets user email address unverified -func (s *UserService) SetUserEmailUnverified(username string) error { +func (s *UserService) SetUserEmailUnverified(c *core.Context, username string) error { if username == "" { return errs.ErrUsernameIsEmpty } @@ -338,7 +339,7 @@ func (s *UserService) SetUserEmailUnverified(username string) error { UpdatedUnixTime: now, } - updatedRows, err := s.UserDB().Cols("email_verified", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) + updatedRows, err := s.UserDB().NewSession(c).Cols("email_verified", "updated_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) if err != nil { return err @@ -349,7 +350,7 @@ func (s *UserService) SetUserEmailUnverified(username string) error { } // DeleteUser deletes an existed user from database -func (s *UserService) DeleteUser(username string) error { +func (s *UserService) DeleteUser(c *core.Context, username string) error { if username == "" { return errs.ErrUsernameIsEmpty } @@ -361,7 +362,7 @@ func (s *UserService) DeleteUser(username string) error { DeletedUnixTime: now, } - deletedRows, err := s.UserDB().Cols("deleted", "deleted_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) + deletedRows, err := s.UserDB().NewSession(c).Cols("deleted", "deleted_unix_time").Where("username=? AND deleted=?", username, false).Update(updateModel) if err != nil { return err @@ -372,21 +373,21 @@ func (s *UserService) DeleteUser(username string) error { } // ExistsUsername returns whether the given user name exists -func (s *UserService) ExistsUsername(username string) (bool, error) { +func (s *UserService) ExistsUsername(c *core.Context, username string) (bool, error) { if username == "" { return false, errs.ErrUsernameIsEmpty } - return s.UserDB().Cols("username").Where("username=? AND deleted=?", username, false).Exist(&models.User{}) + return s.UserDB().NewSession(c).Cols("username").Where("username=? AND deleted=?", username, false).Exist(&models.User{}) } // ExistsEmail returns whether the given user email exists -func (s *UserService) ExistsEmail(email string) (bool, error) { +func (s *UserService) ExistsEmail(c *core.Context, email string) (bool, error) { if email == "" { return false, errs.ErrEmailIsEmpty } - return s.UserDB().Cols("email").Where("email=? AND deleted=?", email, false).Exist(&models.User{}) + return s.UserDB().NewSession(c).Cols("email").Where("email=? AND deleted=?", email, false).Exist(&models.User{}) } // IsPasswordEqualsUserPassword returns whether the given password is correct