code refactor

This commit is contained in:
MaysWind
2021-03-28 17:32:40 +08:00
parent d77d9f39d2
commit 841f275668
+34 -48
View File
@@ -43,23 +43,11 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.Context) (interface{},
uid := c.GetCurrentUid()
var allCategoryIds []int64
allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionCountReq.CategoryId, uid)
if transactionCountReq.CategoryId > 0 {
allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionCountReq.CategoryId)
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
if len(allSubCategories) > 0 {
for i := 0; i < len(allSubCategories); i++ {
allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId)
}
} else {
allCategoryIds = append(allCategoryIds, transactionCountReq.CategoryId)
}
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
totalCount, err := a.transactions.GetTransactionCount(uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, transactionCountReq.AccountId, transactionCountReq.Keyword)
@@ -99,23 +87,11 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (interface{},
return nil, errs.ErrUserNotFound
}
var allCategoryIds []int64
allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionListReq.CategoryId, uid)
if transactionListReq.CategoryId > 0 {
allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionListReq.CategoryId)
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
if len(allSubCategories) > 0 {
for i := 0; i < len(allSubCategories); i++ {
allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId)
}
} else {
allCategoryIds = append(allCategoryIds, transactionListReq.CategoryId)
}
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
transactions, err := a.transactions.GetTransactionsByMaxTime(uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, transactionListReq.AccountId, transactionListReq.Keyword, transactionListReq.Count+1, true)
@@ -264,23 +240,11 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (interfac
return nil, errs.ErrUserNotFound
}
var allCategoryIds []int64
allCategoryIds, err := a.getCategoryAndSubCategoryIds(transactionListReq.CategoryId, uid)
if transactionListReq.CategoryId > 0 {
allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, transactionListReq.CategoryId)
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
if len(allSubCategories) > 0 {
for i := 0; i < len(allSubCategories); i++ {
allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId)
}
} else {
allCategoryIds = append(allCategoryIds, transactionListReq.CategoryId)
}
if err != nil {
log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error())
return nil, errs.ErrOperationFailed
}
transactions, err := a.transactions.GetTransactionsInMonthByPage(uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, transactionListReq.AccountId, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, utcOffset)
@@ -765,6 +729,28 @@ func (a *TransactionsApi) filterTransactions(c *core.Context, uid int64, transac
return finalTransactions
}
func (a *TransactionsApi) getCategoryAndSubCategoryIds(categoryId int64, uid int64) ([]int64, error) {
var allCategoryIds []int64
if categoryId > 0 {
allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(uid, 0, categoryId)
if err != nil {
return nil, err
}
if len(allSubCategories) > 0 {
for i := 0; i < len(allSubCategories); i++ {
allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId)
}
} else {
allCategoryIds = append(allCategoryIds, categoryId)
}
}
return allCategoryIds, nil
}
func (a *TransactionsApi) getTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 {
allTagIds := make([]int64, 0, len(allTransactionTagIds))