From 53aa4ff3904845744aaae3e32e85feb30dab1c64 Mon Sep 17 00:00:00 2001 From: MaysWind Date: Mon, 30 Jun 2025 22:58:17 +0800 Subject: [PATCH] code refactor --- pkg/api/transactions.go | 154 +++---------------------- pkg/services/accounts.go | 53 +++++++++ pkg/services/transaction_categories.go | 54 +++++++++ pkg/services/transaction_tags.go | 28 +++++ 4 files changed, 148 insertions(+), 141 deletions(-) diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index 4f7d18d0..18d58899 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -70,14 +70,14 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.WebContext) (any, *err uid := c.GetCurrentUid() - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionCountReq.AccountIds, uid) + allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionCountReq.AccountIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionCountHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionCountReq.CategoryIds, uid) + allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionCountReq.CategoryIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error()) @@ -88,7 +88,7 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.WebContext) (any, *err noTags := transactionCountReq.TagIds == "none" if !noTags { - allTagIds, err = a.getTagIds(transactionCountReq.TagIds) + allTagIds, err = a.transactionTags.GetTagIds(transactionCountReq.TagIds) if err != nil { log.Warnf(c, "[transactions.TransactionCountHandler] get transaction tag ids error, because %s", err.Error()) @@ -138,14 +138,14 @@ func (a *TransactionsApi) TransactionListHandler(c *core.WebContext) (any, *errs return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) + allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) + allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error()) @@ -156,7 +156,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.WebContext) (any, *errs noTags := transactionListReq.TagIds == "none" if !noTags { - allTagIds, err = a.getTagIds(transactionListReq.TagIds) + allTagIds, err = a.transactionTags.GetTagIds(transactionListReq.TagIds) if err != nil { log.Warnf(c, "[transactions.TransactionListHandler] get transaction tag ids error, because %s", err.Error()) @@ -241,14 +241,14 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.WebContext) (any, return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) + allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionMonthListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) + allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) if err != nil { log.Warnf(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error()) @@ -259,7 +259,7 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.WebContext) (any, noTags := transactionListReq.TagIds == "none" if !noTags { - allTagIds, err = a.getTagIds(transactionListReq.TagIds) + allTagIds, err = a.transactionTags.GetTagIds(transactionListReq.TagIds) if err != nil { log.Warnf(c, "[transactions.TransactionMonthListHandler] get transaction tag ids error, because %s", err.Error()) @@ -310,7 +310,7 @@ func (a *TransactionsApi) TransactionStatisticsHandler(c *core.WebContext) (any, noTags := statisticReq.TagIds == "none" if !noTags { - allTagIds, err = a.getTagIds(statisticReq.TagIds) + allTagIds, err = a.transactionTags.GetTagIds(statisticReq.TagIds) if err != nil { log.Warnf(c, "[transactions.TransactionStatisticsHandler] get transaction tag ids error, because %s", err.Error()) @@ -373,7 +373,7 @@ func (a *TransactionsApi) TransactionStatisticsTrendsHandler(c *core.WebContext) noTags := statisticTrendsReq.TagIds == "none" if !noTags { - allTagIds, err = a.getTagIds(statisticTrendsReq.TagIds) + allTagIds, err = a.transactionTags.GetTagIds(statisticTrendsReq.TagIds) if err != nil { log.Warnf(c, "[transactions.TransactionStatisticsTrendsHandler] get transaction tag ids error, because %s", err.Error()) @@ -617,7 +617,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.WebContext) (any, *errs. } if !transactionGetReq.TrimTag { - tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) + tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.transactionTags.GetTransactionTagIds(allTransactionTagIds))) if err != nil { log.Errorf(c, "[transactions.TransactionGetHandler] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error()) @@ -1532,134 +1532,6 @@ func (a *TransactionsApi) filterTransactions(c *core.WebContext, uid int64, tran return finalTransactions } -func (a *TransactionsApi) getAccountOrSubAccountIds(c *core.WebContext, accountIds string, uid int64) ([]int64, error) { - if accountIds == "" || accountIds == "0" { - return nil, nil - } - - requestAccountIds, err := utils.StringArrayToInt64Array(strings.Split(accountIds, ",")) - - if err != nil { - return nil, errs.Or(err, errs.ErrAccountIdInvalid) - } - - var allAccountIds []int64 - - if len(requestAccountIds) > 0 { - allSubAccounts, err := a.accounts.GetSubAccountsByAccountIds(c, uid, requestAccountIds) - - if err != nil { - return nil, err - } - - accountIdsMap := make(map[int64]int32, len(requestAccountIds)) - - for i := 0; i < len(requestAccountIds); i++ { - accountIdsMap[requestAccountIds[i]] = 0 - } - - for i := 0; i < len(allSubAccounts); i++ { - subAccount := allSubAccounts[i] - - if refCount, exists := accountIdsMap[subAccount.ParentAccountId]; exists { - accountIdsMap[subAccount.ParentAccountId] = refCount + 1 - } else { - accountIdsMap[subAccount.ParentAccountId] = 1 - } - - if _, exists := accountIdsMap[subAccount.AccountId]; exists { - delete(accountIdsMap, subAccount.AccountId) - } - - allAccountIds = append(allAccountIds, subAccount.AccountId) - } - - for accountId, refCount := range accountIdsMap { - if refCount < 1 { - allAccountIds = append(allAccountIds, accountId) - } - } - } - - return allAccountIds, nil -} - -func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.WebContext, categoryIds string, uid int64) ([]int64, error) { - if categoryIds == "" || categoryIds == "0" { - return nil, nil - } - - requestCategoryIds, err := utils.StringArrayToInt64Array(strings.Split(categoryIds, ",")) - - if err != nil { - return nil, errs.Or(err, errs.ErrTransactionCategoryIdInvalid) - } - - var allCategoryIds []int64 - - if len(requestCategoryIds) > 0 { - allSubCategories, err := a.transactionCategories.GetSubCategoriesByCategoryIds(c, uid, requestCategoryIds) - - if err != nil { - return nil, err - } - - categoryIdsMap := make(map[int64]int32, len(requestCategoryIds)) - - for i := 0; i < len(requestCategoryIds); i++ { - categoryIdsMap[requestCategoryIds[i]] = 0 - } - - for i := 0; i < len(allSubCategories); i++ { - subCategory := allSubCategories[i] - - if refCount, exists := categoryIdsMap[subCategory.ParentCategoryId]; exists { - categoryIdsMap[subCategory.ParentCategoryId] = refCount + 1 - } else { - categoryIdsMap[subCategory.ParentCategoryId] = 1 - } - - if _, exists := categoryIdsMap[subCategory.CategoryId]; exists { - delete(categoryIdsMap, subCategory.CategoryId) - } - - allCategoryIds = append(allCategoryIds, subCategory.CategoryId) - } - - for accountId, refCount := range categoryIdsMap { - if refCount < 1 { - allCategoryIds = append(allCategoryIds, accountId) - } - } - } - - return allCategoryIds, nil -} - -func (a *TransactionsApi) getTagIds(tagIds string) ([]int64, error) { - if tagIds == "" || tagIds == "0" { - return nil, nil - } - - requestTagIds, err := utils.StringArrayToInt64Array(strings.Split(tagIds, ",")) - - if err != nil { - return nil, errs.Or(err, errs.ErrTransactionTagIdInvalid) - } - - return requestTagIds, nil -} - -func (a *TransactionsApi) getTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 { - allTagIds := make([]int64, 0, len(allTransactionTagIds)) - - for _, tagIds := range allTransactionTagIds { - allTagIds = append(allTagIds, tagIds...) - } - - return allTagIds -} - func (a *TransactionsApi) getTransactionTagInfoResponses(tagIds []int64, allTransactionTags map[int64]*models.TransactionTag) []*models.TransactionTagInfoResponse { allTags := make([]*models.TransactionTagInfoResponse, 0, len(tagIds)) @@ -1729,7 +1601,7 @@ func (a *TransactionsApi) getTransactionResponseListResult(c *core.WebContext, u } if !trimTag { - tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds))) + tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.transactionTags.GetTransactionTagIds(allTransactionTagIds))) if err != nil { log.Errorf(c, "[transactions.getTransactionResponseListResult] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error()) diff --git a/pkg/services/accounts.go b/pkg/services/accounts.go index 48f5037f..e42e8549 100644 --- a/pkg/services/accounts.go +++ b/pkg/services/accounts.go @@ -825,3 +825,56 @@ func (s *AccountService) GetAccountNames(accounts []*models.Account) []string { return accountNames } + +// GetAccountOrSubAccountIds returns a list of account ids or sub-account ids according to given account ids +func (s *AccountService) GetAccountOrSubAccountIds(c *core.WebContext, accountIds string, uid int64) ([]int64, error) { + if accountIds == "" || accountIds == "0" { + return nil, nil + } + + requestAccountIds, err := utils.StringArrayToInt64Array(strings.Split(accountIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrAccountIdInvalid) + } + + var allAccountIds []int64 + + if len(requestAccountIds) > 0 { + allSubAccounts, err := s.GetSubAccountsByAccountIds(c, uid, requestAccountIds) + + if err != nil { + return nil, err + } + + accountIdsMap := make(map[int64]int32, len(requestAccountIds)) + + for i := 0; i < len(requestAccountIds); i++ { + accountIdsMap[requestAccountIds[i]] = 0 + } + + for i := 0; i < len(allSubAccounts); i++ { + subAccount := allSubAccounts[i] + + if refCount, exists := accountIdsMap[subAccount.ParentAccountId]; exists { + accountIdsMap[subAccount.ParentAccountId] = refCount + 1 + } else { + accountIdsMap[subAccount.ParentAccountId] = 1 + } + + if _, exists := accountIdsMap[subAccount.AccountId]; exists { + delete(accountIdsMap, subAccount.AccountId) + } + + allAccountIds = append(allAccountIds, subAccount.AccountId) + } + + for accountId, refCount := range accountIdsMap { + if refCount < 1 { + allAccountIds = append(allAccountIds, accountId) + } + } + } + + return allAccountIds, nil +} diff --git a/pkg/services/transaction_categories.go b/pkg/services/transaction_categories.go index 2431808c..b1bc3074 100644 --- a/pkg/services/transaction_categories.go +++ b/pkg/services/transaction_categories.go @@ -10,6 +10,7 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models" + "github.com/mayswind/ezbookkeeping/pkg/utils" "github.com/mayswind/ezbookkeeping/pkg/uuid" ) @@ -523,3 +524,56 @@ func (s *TransactionCategoryService) GetCategoryNames(categories []*models.Trans return categoryNames } + +// GetCategoryOrSubCategoryIds returns all category ids and sub-category ids according to given category ids +func (s *TransactionCategoryService) GetCategoryOrSubCategoryIds(c *core.WebContext, categoryIds string, uid int64) ([]int64, error) { + if categoryIds == "" || categoryIds == "0" { + return nil, nil + } + + requestCategoryIds, err := utils.StringArrayToInt64Array(strings.Split(categoryIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrTransactionCategoryIdInvalid) + } + + var allCategoryIds []int64 + + if len(requestCategoryIds) > 0 { + allSubCategories, err := s.GetSubCategoriesByCategoryIds(c, uid, requestCategoryIds) + + if err != nil { + return nil, err + } + + categoryIdsMap := make(map[int64]int32, len(requestCategoryIds)) + + for i := 0; i < len(requestCategoryIds); i++ { + categoryIdsMap[requestCategoryIds[i]] = 0 + } + + for i := 0; i < len(allSubCategories); i++ { + subCategory := allSubCategories[i] + + if refCount, exists := categoryIdsMap[subCategory.ParentCategoryId]; exists { + categoryIdsMap[subCategory.ParentCategoryId] = refCount + 1 + } else { + categoryIdsMap[subCategory.ParentCategoryId] = 1 + } + + if _, exists := categoryIdsMap[subCategory.CategoryId]; exists { + delete(categoryIdsMap, subCategory.CategoryId) + } + + allCategoryIds = append(allCategoryIds, subCategory.CategoryId) + } + + for accountId, refCount := range categoryIdsMap { + if refCount < 1 { + allCategoryIds = append(allCategoryIds, accountId) + } + } + } + + return allCategoryIds, nil +} diff --git a/pkg/services/transaction_tags.go b/pkg/services/transaction_tags.go index 5527c1f7..30764baa 100644 --- a/pkg/services/transaction_tags.go +++ b/pkg/services/transaction_tags.go @@ -1,6 +1,7 @@ package services import ( + "strings" "time" "xorm.io/xorm" @@ -507,6 +508,7 @@ func (s *TransactionTagService) GetTagNames(tags []*models.TransactionTag) []str return tagNames } +// GetGroupedTransactionTagIds returns a map of transaction tag ids grouped by transaction id func (s *TransactionTagService) GetGroupedTransactionTagIds(tagIndexes []*models.TransactionTagIndex) map[int64][]int64 { allTransactionTagIds := make(map[int64][]int64) @@ -529,3 +531,29 @@ func (s *TransactionTagService) GetGroupedTransactionTagIds(tagIndexes []*models return allTransactionTagIds } + +// GetTagIds converts a comma-separated string of tag ids into a slice of int64 +func (s *TransactionTagService) GetTagIds(tagIds string) ([]int64, error) { + if tagIds == "" || tagIds == "0" { + return nil, nil + } + + requestTagIds, err := utils.StringArrayToInt64Array(strings.Split(tagIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrTransactionTagIdInvalid) + } + + return requestTagIds, nil +} + +// GetTransactionTagIds returns a slice of all transaction tag ids from a map of transaction tag ids grouped by transaction id +func (s *TransactionTagService) GetTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 { + allTagIds := make([]int64, 0, len(allTransactionTagIds)) + + for _, tagIds := range allTransactionTagIds { + allTagIds = append(allTagIds, tagIds...) + } + + return allTagIds +}