code refactor

This commit is contained in:
MaysWind
2025-06-30 22:58:17 +08:00
parent 3c100b2543
commit 53aa4ff390
4 changed files with 148 additions and 141 deletions
+13 -141
View File
@@ -70,14 +70,14 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.WebContext) (any, *err
uid := c.GetCurrentUid()
allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionCountReq.AccountIds, uid)
allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionCountReq.AccountIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionCountHandler] get account error, because %s", err.Error())
return nil, errs.Or(err, errs.ErrOperationFailed)
}
allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionCountReq.CategoryIds, uid)
allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionCountReq.CategoryIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error())
@@ -88,7 +88,7 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.WebContext) (any, *err
noTags := transactionCountReq.TagIds == "none"
if !noTags {
allTagIds, err = a.getTagIds(transactionCountReq.TagIds)
allTagIds, err = a.transactionTags.GetTagIds(transactionCountReq.TagIds)
if err != nil {
log.Warnf(c, "[transactions.TransactionCountHandler] get transaction tag ids error, because %s", err.Error())
@@ -138,14 +138,14 @@ func (a *TransactionsApi) TransactionListHandler(c *core.WebContext) (any, *errs
return nil, errs.ErrUserNotFound
}
allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid)
allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionListHandler] get account error, because %s", err.Error())
return nil, errs.Or(err, errs.ErrOperationFailed)
}
allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid)
allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error())
@@ -156,7 +156,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.WebContext) (any, *errs
noTags := transactionListReq.TagIds == "none"
if !noTags {
allTagIds, err = a.getTagIds(transactionListReq.TagIds)
allTagIds, err = a.transactionTags.GetTagIds(transactionListReq.TagIds)
if err != nil {
log.Warnf(c, "[transactions.TransactionListHandler] get transaction tag ids error, because %s", err.Error())
@@ -241,14 +241,14 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.WebContext) (any,
return nil, errs.ErrUserNotFound
}
allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid)
allAccountIds, err := a.accounts.GetAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionMonthListHandler] get account error, because %s", err.Error())
return nil, errs.Or(err, errs.ErrOperationFailed)
}
allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid)
allCategoryIds, err := a.transactionCategories.GetCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid)
if err != nil {
log.Warnf(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error())
@@ -259,7 +259,7 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.WebContext) (any,
noTags := transactionListReq.TagIds == "none"
if !noTags {
allTagIds, err = a.getTagIds(transactionListReq.TagIds)
allTagIds, err = a.transactionTags.GetTagIds(transactionListReq.TagIds)
if err != nil {
log.Warnf(c, "[transactions.TransactionMonthListHandler] get transaction tag ids error, because %s", err.Error())
@@ -310,7 +310,7 @@ func (a *TransactionsApi) TransactionStatisticsHandler(c *core.WebContext) (any,
noTags := statisticReq.TagIds == "none"
if !noTags {
allTagIds, err = a.getTagIds(statisticReq.TagIds)
allTagIds, err = a.transactionTags.GetTagIds(statisticReq.TagIds)
if err != nil {
log.Warnf(c, "[transactions.TransactionStatisticsHandler] get transaction tag ids error, because %s", err.Error())
@@ -373,7 +373,7 @@ func (a *TransactionsApi) TransactionStatisticsTrendsHandler(c *core.WebContext)
noTags := statisticTrendsReq.TagIds == "none"
if !noTags {
allTagIds, err = a.getTagIds(statisticTrendsReq.TagIds)
allTagIds, err = a.transactionTags.GetTagIds(statisticTrendsReq.TagIds)
if err != nil {
log.Warnf(c, "[transactions.TransactionStatisticsTrendsHandler] get transaction tag ids error, because %s", err.Error())
@@ -617,7 +617,7 @@ func (a *TransactionsApi) TransactionGetHandler(c *core.WebContext) (any, *errs.
}
if !transactionGetReq.TrimTag {
tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds)))
tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.transactionTags.GetTransactionTagIds(allTransactionTagIds)))
if err != nil {
log.Errorf(c, "[transactions.TransactionGetHandler] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error())
@@ -1532,134 +1532,6 @@ func (a *TransactionsApi) filterTransactions(c *core.WebContext, uid int64, tran
return finalTransactions
}
func (a *TransactionsApi) getAccountOrSubAccountIds(c *core.WebContext, accountIds string, uid int64) ([]int64, error) {
if accountIds == "" || accountIds == "0" {
return nil, nil
}
requestAccountIds, err := utils.StringArrayToInt64Array(strings.Split(accountIds, ","))
if err != nil {
return nil, errs.Or(err, errs.ErrAccountIdInvalid)
}
var allAccountIds []int64
if len(requestAccountIds) > 0 {
allSubAccounts, err := a.accounts.GetSubAccountsByAccountIds(c, uid, requestAccountIds)
if err != nil {
return nil, err
}
accountIdsMap := make(map[int64]int32, len(requestAccountIds))
for i := 0; i < len(requestAccountIds); i++ {
accountIdsMap[requestAccountIds[i]] = 0
}
for i := 0; i < len(allSubAccounts); i++ {
subAccount := allSubAccounts[i]
if refCount, exists := accountIdsMap[subAccount.ParentAccountId]; exists {
accountIdsMap[subAccount.ParentAccountId] = refCount + 1
} else {
accountIdsMap[subAccount.ParentAccountId] = 1
}
if _, exists := accountIdsMap[subAccount.AccountId]; exists {
delete(accountIdsMap, subAccount.AccountId)
}
allAccountIds = append(allAccountIds, subAccount.AccountId)
}
for accountId, refCount := range accountIdsMap {
if refCount < 1 {
allAccountIds = append(allAccountIds, accountId)
}
}
}
return allAccountIds, nil
}
func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.WebContext, categoryIds string, uid int64) ([]int64, error) {
if categoryIds == "" || categoryIds == "0" {
return nil, nil
}
requestCategoryIds, err := utils.StringArrayToInt64Array(strings.Split(categoryIds, ","))
if err != nil {
return nil, errs.Or(err, errs.ErrTransactionCategoryIdInvalid)
}
var allCategoryIds []int64
if len(requestCategoryIds) > 0 {
allSubCategories, err := a.transactionCategories.GetSubCategoriesByCategoryIds(c, uid, requestCategoryIds)
if err != nil {
return nil, err
}
categoryIdsMap := make(map[int64]int32, len(requestCategoryIds))
for i := 0; i < len(requestCategoryIds); i++ {
categoryIdsMap[requestCategoryIds[i]] = 0
}
for i := 0; i < len(allSubCategories); i++ {
subCategory := allSubCategories[i]
if refCount, exists := categoryIdsMap[subCategory.ParentCategoryId]; exists {
categoryIdsMap[subCategory.ParentCategoryId] = refCount + 1
} else {
categoryIdsMap[subCategory.ParentCategoryId] = 1
}
if _, exists := categoryIdsMap[subCategory.CategoryId]; exists {
delete(categoryIdsMap, subCategory.CategoryId)
}
allCategoryIds = append(allCategoryIds, subCategory.CategoryId)
}
for accountId, refCount := range categoryIdsMap {
if refCount < 1 {
allCategoryIds = append(allCategoryIds, accountId)
}
}
}
return allCategoryIds, nil
}
func (a *TransactionsApi) getTagIds(tagIds string) ([]int64, error) {
if tagIds == "" || tagIds == "0" {
return nil, nil
}
requestTagIds, err := utils.StringArrayToInt64Array(strings.Split(tagIds, ","))
if err != nil {
return nil, errs.Or(err, errs.ErrTransactionTagIdInvalid)
}
return requestTagIds, nil
}
func (a *TransactionsApi) getTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 {
allTagIds := make([]int64, 0, len(allTransactionTagIds))
for _, tagIds := range allTransactionTagIds {
allTagIds = append(allTagIds, tagIds...)
}
return allTagIds
}
func (a *TransactionsApi) getTransactionTagInfoResponses(tagIds []int64, allTransactionTags map[int64]*models.TransactionTag) []*models.TransactionTagInfoResponse {
allTags := make([]*models.TransactionTagInfoResponse, 0, len(tagIds))
@@ -1729,7 +1601,7 @@ func (a *TransactionsApi) getTransactionResponseListResult(c *core.WebContext, u
}
if !trimTag {
tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.getTransactionTagIds(allTransactionTagIds)))
tagMap, err = a.transactionTags.GetTagsByTagIds(c, uid, utils.ToUniqueInt64Slice(a.transactionTags.GetTransactionTagIds(allTransactionTagIds)))
if err != nil {
log.Errorf(c, "[transactions.getTransactionResponseListResult] failed to get transactions tags for user \"uid:%d\", because %s", uid, err.Error())