diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go index c3d27d99..da04c6c4 100644 --- a/pkg/models/transaction.go +++ b/pkg/models/transaction.go @@ -19,6 +19,21 @@ const ( TRANSACTION_TYPE_TRANSFER TransactionType = 4 ) +// ToTransactionDbType returns the transaction db type for this enum +func (t TransactionType) ToTransactionDbType() (TransactionDbType, error) { + if t == TRANSACTION_TYPE_MODIFY_BALANCE { + return TRANSACTION_DB_TYPE_MODIFY_BALANCE, nil + } else if t == TRANSACTION_TYPE_EXPENSE { + return TRANSACTION_DB_TYPE_EXPENSE, nil + } else if t == TRANSACTION_TYPE_INCOME { + return TRANSACTION_DB_TYPE_INCOME, nil + } else if t == TRANSACTION_TYPE_TRANSFER { + return TRANSACTION_DB_TYPE_TRANSFER_OUT, nil + } else { + return 0, errs.ErrTransactionTypeInvalid + } +} + // TransactionDbType represents transaction type in database type TransactionDbType byte @@ -32,8 +47,8 @@ const ( ) // String returns a textual representation of the transaction types for db enum -func (s TransactionDbType) String() string { - switch s { +func (t TransactionDbType) String() string { + switch t { case TRANSACTION_DB_TYPE_MODIFY_BALANCE: return "Modify Balance" case TRANSACTION_DB_TYPE_INCOME: @@ -45,21 +60,21 @@ func (s TransactionDbType) String() string { case TRANSACTION_DB_TYPE_TRANSFER_IN: return "Transfer In" default: - return fmt.Sprintf("Invalid(%d)", int(s)) + return fmt.Sprintf("Invalid(%d)", int(t)) } } // ToTransactionType returns the transaction type for this db enum -func (s TransactionDbType) ToTransactionType() (TransactionType, error) { - if s == TRANSACTION_DB_TYPE_MODIFY_BALANCE { +func (t TransactionDbType) ToTransactionType() (TransactionType, error) { + if t == TRANSACTION_DB_TYPE_MODIFY_BALANCE { return TRANSACTION_TYPE_MODIFY_BALANCE, nil - } else if s == TRANSACTION_DB_TYPE_EXPENSE { + } else if t == TRANSACTION_DB_TYPE_EXPENSE { return TRANSACTION_TYPE_EXPENSE, nil - } else if s == TRANSACTION_DB_TYPE_INCOME { + } else if t == TRANSACTION_DB_TYPE_INCOME { return TRANSACTION_TYPE_INCOME, nil - } else if s == TRANSACTION_DB_TYPE_TRANSFER_OUT { + } else if t == TRANSACTION_DB_TYPE_TRANSFER_OUT { return TRANSACTION_TYPE_TRANSFER, nil - } else if s == TRANSACTION_DB_TYPE_TRANSFER_IN { + } else if t == TRANSACTION_DB_TYPE_TRANSFER_IN { return TRANSACTION_TYPE_TRANSFER, nil } else { return 0, errs.ErrTransactionTypeInvalid @@ -156,7 +171,7 @@ type TransactionImportProcessRequest struct { // TransactionCountRequest represents transaction count request type TransactionCountRequest struct { - Type TransactionDbType `form:"type" binding:"min=0,max=4"` + Type TransactionType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` TagIds string `form:"tag_ids"` @@ -169,7 +184,7 @@ type TransactionCountRequest struct { // TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request type TransactionListByMaxTimeRequest struct { - Type TransactionDbType `form:"type" binding:"min=0,max=4"` + Type TransactionType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` TagIds string `form:"tag_ids"` @@ -191,7 +206,7 @@ type TransactionListByMaxTimeRequest struct { type TransactionListInMonthByPageRequest struct { Year int32 `form:"year" binding:"required,min=1"` Month int32 `form:"month" binding:"required,min=1"` - Type TransactionDbType `form:"type" binding:"min=0,max=4"` + Type TransactionType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` TagIds string `form:"tag_ids"` diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index 06aaf338..696abbef 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -80,11 +80,22 @@ func (s *TransactionService) GetAllTransactionsByMaxTime(c core.Context, uid int } // GetTransactionsByMaxTime returns transactions before given time -func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } + var err error + var transactionDbType models.TransactionDbType = 0 + + if transactionType > 0 { + transactionDbType, err = transactionType.ToTransactionDbType() + + if err != nil { + return nil, err + } + } + if page < 0 { return nil, errs.ErrPageIndexInvalid } else if page == 0 { @@ -96,7 +107,6 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, } var transactions []*models.Transaction - var err error actualCount := count @@ -104,7 +114,7 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, actualCount++ } - condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated) + condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) @@ -114,11 +124,22 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, } // GetTransactionsInMonthByPage returns all transactions in given year and month -func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } + var err error + var transactionDbType models.TransactionDbType = 0 + + if transactionType > 0 { + transactionDbType, err = transactionType.ToTransactionDbType() + + if err != nil { + return nil, err + } + } + minTransactionTime, maxTransactionTime, err := utils.GetTransactionTimeRangeByYearMonth(year, month) if err != nil { @@ -127,7 +148,7 @@ func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid in var transactions []*models.Transaction - condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) + condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) @@ -176,12 +197,23 @@ func (s *TransactionService) GetAllTransactionCount(c core.Context, uid int64) ( } // GetTransactionCount returns count of transactions -func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) { +func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } - condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) + var err error + var transactionDbType models.TransactionDbType = 0 + + if transactionType > 0 { + transactionDbType, err = transactionType.ToTransactionDbType() + + if err != nil { + return 0, err + } + } + + condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) @@ -1868,7 +1900,7 @@ func (s *TransactionService) doCreateTransaction(c core.Context, database *datas return err } -func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) { +func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionDbType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) { condition := "uid=? AND deleted=?" conditionParams := make([]any, 0, 16) conditionParams = append(conditionParams, uid) @@ -1896,10 +1928,10 @@ func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransa accountIdConditionParams = append(accountIdConditionParams, accountIds[i]) } - if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionType && transactionType <= models.TRANSACTION_DB_TYPE_EXPENSE { + if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionDbType && transactionDbType <= models.TRANSACTION_DB_TYPE_EXPENSE { condition = condition + " AND type=?" - conditionParams = append(conditionParams, transactionType) - } else if transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_IN { + conditionParams = append(conditionParams, transactionDbType) + } else if transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_IN { if len(accountIds) == 0 { condition = condition + " AND type=?" conditionParams = append(conditionParams, models.TRANSACTION_DB_TYPE_TRANSFER_OUT)