code refactor

This commit is contained in:
MaysWind
2025-06-30 22:38:53 +08:00
parent b37cde5a8c
commit 3c100b2543
2 changed files with 70 additions and 23 deletions
+27 -12
View File
@@ -19,6 +19,21 @@ const (
TRANSACTION_TYPE_TRANSFER TransactionType = 4
)
// ToTransactionDbType returns the transaction db type for this enum
func (t TransactionType) ToTransactionDbType() (TransactionDbType, error) {
if t == TRANSACTION_TYPE_MODIFY_BALANCE {
return TRANSACTION_DB_TYPE_MODIFY_BALANCE, nil
} else if t == TRANSACTION_TYPE_EXPENSE {
return TRANSACTION_DB_TYPE_EXPENSE, nil
} else if t == TRANSACTION_TYPE_INCOME {
return TRANSACTION_DB_TYPE_INCOME, nil
} else if t == TRANSACTION_TYPE_TRANSFER {
return TRANSACTION_DB_TYPE_TRANSFER_OUT, nil
} else {
return 0, errs.ErrTransactionTypeInvalid
}
}
// TransactionDbType represents transaction type in database
type TransactionDbType byte
@@ -32,8 +47,8 @@ const (
)
// String returns a textual representation of the transaction types for db enum
func (s TransactionDbType) String() string {
switch s {
func (t TransactionDbType) String() string {
switch t {
case TRANSACTION_DB_TYPE_MODIFY_BALANCE:
return "Modify Balance"
case TRANSACTION_DB_TYPE_INCOME:
@@ -45,21 +60,21 @@ func (s TransactionDbType) String() string {
case TRANSACTION_DB_TYPE_TRANSFER_IN:
return "Transfer In"
default:
return fmt.Sprintf("Invalid(%d)", int(s))
return fmt.Sprintf("Invalid(%d)", int(t))
}
}
// ToTransactionType returns the transaction type for this db enum
func (s TransactionDbType) ToTransactionType() (TransactionType, error) {
if s == TRANSACTION_DB_TYPE_MODIFY_BALANCE {
func (t TransactionDbType) ToTransactionType() (TransactionType, error) {
if t == TRANSACTION_DB_TYPE_MODIFY_BALANCE {
return TRANSACTION_TYPE_MODIFY_BALANCE, nil
} else if s == TRANSACTION_DB_TYPE_EXPENSE {
} else if t == TRANSACTION_DB_TYPE_EXPENSE {
return TRANSACTION_TYPE_EXPENSE, nil
} else if s == TRANSACTION_DB_TYPE_INCOME {
} else if t == TRANSACTION_DB_TYPE_INCOME {
return TRANSACTION_TYPE_INCOME, nil
} else if s == TRANSACTION_DB_TYPE_TRANSFER_OUT {
} else if t == TRANSACTION_DB_TYPE_TRANSFER_OUT {
return TRANSACTION_TYPE_TRANSFER, nil
} else if s == TRANSACTION_DB_TYPE_TRANSFER_IN {
} else if t == TRANSACTION_DB_TYPE_TRANSFER_IN {
return TRANSACTION_TYPE_TRANSFER, nil
} else {
return 0, errs.ErrTransactionTypeInvalid
@@ -156,7 +171,7 @@ type TransactionImportProcessRequest struct {
// TransactionCountRequest represents transaction count request
type TransactionCountRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
@@ -169,7 +184,7 @@ type TransactionCountRequest struct {
// TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request
type TransactionListByMaxTimeRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
@@ -191,7 +206,7 @@ type TransactionListByMaxTimeRequest struct {
type TransactionListInMonthByPageRequest struct {
Year int32 `form:"year" binding:"required,min=1"`
Month int32 `form:"month" binding:"required,min=1"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
+43 -11
View File
@@ -80,11 +80,22 @@ func (s *TransactionService) GetAllTransactionsByMaxTime(c core.Context, uid int
}
// GetTransactionsByMaxTime returns transactions before given time
func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) {
func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return nil, err
}
}
if page < 0 {
return nil, errs.ErrPageIndexInvalid
} else if page == 0 {
@@ -96,7 +107,6 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
}
var transactions []*models.Transaction
var err error
actualCount := count
@@ -104,7 +114,7 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
actualCount++
}
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated)
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -114,11 +124,22 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
}
// GetTransactionsInMonthByPage returns all transactions in given year and month
func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) {
func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return nil, err
}
}
minTransactionTime, maxTransactionTime, err := utils.GetTransactionTimeRangeByYearMonth(year, month)
if err != nil {
@@ -127,7 +148,7 @@ func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid in
var transactions []*models.Transaction
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -176,12 +197,23 @@ func (s *TransactionService) GetAllTransactionCount(c core.Context, uid int64) (
}
// GetTransactionCount returns count of transactions
func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) {
func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) {
if uid <= 0 {
return 0, errs.ErrUserIdInvalid
}
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return 0, err
}
}
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -1868,7 +1900,7 @@ func (s *TransactionService) doCreateTransaction(c core.Context, database *datas
return err
}
func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) {
func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionDbType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) {
condition := "uid=? AND deleted=?"
conditionParams := make([]any, 0, 16)
conditionParams = append(conditionParams, uid)
@@ -1896,10 +1928,10 @@ func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransa
accountIdConditionParams = append(accountIdConditionParams, accountIds[i])
}
if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionType && transactionType <= models.TRANSACTION_DB_TYPE_EXPENSE {
if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionDbType && transactionDbType <= models.TRANSACTION_DB_TYPE_EXPENSE {
condition = condition + " AND type=?"
conditionParams = append(conditionParams, transactionType)
} else if transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_IN {
conditionParams = append(conditionParams, transactionDbType)
} else if transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_IN {
if len(accountIds) == 0 {
condition = condition + " AND type=?"
conditionParams = append(conditionParams, models.TRANSACTION_DB_TYPE_TRANSFER_OUT)