diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index 2efc26df..174fdce8 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -43,23 +43,11 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.Context) (interface{}, uid := c.GetCurrentUid() - var allCategoryIds []int64 + allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionCountReq.CategoryId, uid) - if transactionCountReq.CategoryId > 0 { - allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionCountReq.CategoryId) - - if err != nil { - log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error()) - return nil, errs.ErrOperationFailed - } - - if len(allSubCategories) > 0 { - for i := 0; i < len(allSubCategories); i++ { - allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId) - } - } else { - allCategoryIds = append(allCategoryIds, transactionCountReq.CategoryId) - } + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error()) + return nil, errs.ErrOperationFailed } totalCount, err := a.transactions.GetTransactionCount(uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, transactionCountReq.AccountId, transactionCountReq.Keyword) @@ -99,23 +87,11 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{}, return nil, errs.ErrUserNotFound } - var allCategoryIds []int64 + allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionListReq.CategoryId, uid) - if transactionListReq.CategoryId > 0 { - allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionListReq.CategoryId) - - if err != nil { - log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error()) - return nil, errs.ErrOperationFailed - } - - if len(allSubCategories) > 0 { - for i := 0; i < len(allSubCategories); i++ { - allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId) - } - } else { - allCategoryIds = append(allCategoryIds, transactionListReq.CategoryId) - } + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error()) + return nil, errs.ErrOperationFailed } transactions, err := a.transactions.GetTransactionsByMaxTime(uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, transactionListReq.AccountId, transactionListReq.Keyword, transactionListReq.Count+1, true) @@ -264,23 +240,11 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (interfac return nil, errs.ErrUserNotFound } - var allCategoryIds []int64 + allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionListReq.CategoryId, uid) - if transactionListReq.CategoryId > 0 { - allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionListReq.CategoryId) - - if err != nil { - log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error()) - return nil, errs.ErrOperationFailed - } - - if len(allSubCategories) > 0 { - for i := 0; i < len(allSubCategories); i++ { - allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId) - } - } else { - allCategoryIds = append(allCategoryIds, transactionListReq.CategoryId) - } + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error()) + return nil, errs.ErrOperationFailed } transactions, err := a.transactions.GetTransactionsInMonthByPage(uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, transactionListReq.AccountId, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, utcOffset) @@ -765,6 +729,28 @@ func (a *TransactionsApi) filterTransactions(c *core.Context, uid int64, transac return finalTransactions } +func (a *TransactionsApi) getCategoryAndSubCategoryIds(categoryId int64, uid int64) ([]int64, error) { + var allCategoryIds []int64 + + if categoryId > 0 { + allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, categoryId) + + if err != nil { + return nil, err + } + + if len(allSubCategories) > 0 { + for i := 0; i < len(allSubCategories); i++ { + allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId) + } + } else { + allCategoryIds = append(allCategoryIds, categoryId) + } + } + + return allCategoryIds, nil +} + func (a *TransactionsApi) getTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 { allTagIds := make([]int64, 0, len(allTransactionTagIds))