code refactor

This commit is contained in:
MaysWind
2025-06-30 22:38:53 +08:00
parent b37cde5a8c
commit 3c100b2543
2 changed files with 70 additions and 23 deletions
+27 -12
View File
@@ -19,6 +19,21 @@ const (
TRANSACTION_TYPE_TRANSFER TransactionType = 4 TRANSACTION_TYPE_TRANSFER TransactionType = 4
) )
// ToTransactionDbType returns the transaction db type for this enum
func (t TransactionType) ToTransactionDbType() (TransactionDbType, error) {
if t == TRANSACTION_TYPE_MODIFY_BALANCE {
return TRANSACTION_DB_TYPE_MODIFY_BALANCE, nil
} else if t == TRANSACTION_TYPE_EXPENSE {
return TRANSACTION_DB_TYPE_EXPENSE, nil
} else if t == TRANSACTION_TYPE_INCOME {
return TRANSACTION_DB_TYPE_INCOME, nil
} else if t == TRANSACTION_TYPE_TRANSFER {
return TRANSACTION_DB_TYPE_TRANSFER_OUT, nil
} else {
return 0, errs.ErrTransactionTypeInvalid
}
}
// TransactionDbType represents transaction type in database // TransactionDbType represents transaction type in database
type TransactionDbType byte type TransactionDbType byte
@@ -32,8 +47,8 @@ const (
) )
// String returns a textual representation of the transaction types for db enum // String returns a textual representation of the transaction types for db enum
func (s TransactionDbType) String() string { func (t TransactionDbType) String() string {
switch s { switch t {
case TRANSACTION_DB_TYPE_MODIFY_BALANCE: case TRANSACTION_DB_TYPE_MODIFY_BALANCE:
return "Modify Balance" return "Modify Balance"
case TRANSACTION_DB_TYPE_INCOME: case TRANSACTION_DB_TYPE_INCOME:
@@ -45,21 +60,21 @@ func (s TransactionDbType) String() string {
case TRANSACTION_DB_TYPE_TRANSFER_IN: case TRANSACTION_DB_TYPE_TRANSFER_IN:
return "Transfer In" return "Transfer In"
default: default:
return fmt.Sprintf("Invalid(%d)", int(s)) return fmt.Sprintf("Invalid(%d)", int(t))
} }
} }
// ToTransactionType returns the transaction type for this db enum // ToTransactionType returns the transaction type for this db enum
func (s TransactionDbType) ToTransactionType() (TransactionType, error) { func (t TransactionDbType) ToTransactionType() (TransactionType, error) {
if s == TRANSACTION_DB_TYPE_MODIFY_BALANCE { if t == TRANSACTION_DB_TYPE_MODIFY_BALANCE {
return TRANSACTION_TYPE_MODIFY_BALANCE, nil return TRANSACTION_TYPE_MODIFY_BALANCE, nil
} else if s == TRANSACTION_DB_TYPE_EXPENSE { } else if t == TRANSACTION_DB_TYPE_EXPENSE {
return TRANSACTION_TYPE_EXPENSE, nil return TRANSACTION_TYPE_EXPENSE, nil
} else if s == TRANSACTION_DB_TYPE_INCOME { } else if t == TRANSACTION_DB_TYPE_INCOME {
return TRANSACTION_TYPE_INCOME, nil return TRANSACTION_TYPE_INCOME, nil
} else if s == TRANSACTION_DB_TYPE_TRANSFER_OUT { } else if t == TRANSACTION_DB_TYPE_TRANSFER_OUT {
return TRANSACTION_TYPE_TRANSFER, nil return TRANSACTION_TYPE_TRANSFER, nil
} else if s == TRANSACTION_DB_TYPE_TRANSFER_IN { } else if t == TRANSACTION_DB_TYPE_TRANSFER_IN {
return TRANSACTION_TYPE_TRANSFER, nil return TRANSACTION_TYPE_TRANSFER, nil
} else { } else {
return 0, errs.ErrTransactionTypeInvalid return 0, errs.ErrTransactionTypeInvalid
@@ -156,7 +171,7 @@ type TransactionImportProcessRequest struct {
// TransactionCountRequest represents transaction count request // TransactionCountRequest represents transaction count request
type TransactionCountRequest struct { type TransactionCountRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"` Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"` CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"` AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"` TagIds string `form:"tag_ids"`
@@ -169,7 +184,7 @@ type TransactionCountRequest struct {
// TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request // TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request
type TransactionListByMaxTimeRequest struct { type TransactionListByMaxTimeRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"` Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"` CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"` AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"` TagIds string `form:"tag_ids"`
@@ -191,7 +206,7 @@ type TransactionListByMaxTimeRequest struct {
type TransactionListInMonthByPageRequest struct { type TransactionListInMonthByPageRequest struct {
Year int32 `form:"year" binding:"required,min=1"` Year int32 `form:"year" binding:"required,min=1"`
Month int32 `form:"month" binding:"required,min=1"` Month int32 `form:"month" binding:"required,min=1"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"` Type TransactionType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"` CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"` AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"` TagIds string `form:"tag_ids"`
+43 -11
View File
@@ -80,11 +80,22 @@ func (s *TransactionService) GetAllTransactionsByMaxTime(c core.Context, uid int
} }
// GetTransactionsByMaxTime returns transactions before given time // GetTransactionsByMaxTime returns transactions before given time
func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) {
if uid <= 0 { if uid <= 0 {
return nil, errs.ErrUserIdInvalid return nil, errs.ErrUserIdInvalid
} }
var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return nil, err
}
}
if page < 0 { if page < 0 {
return nil, errs.ErrPageIndexInvalid return nil, errs.ErrPageIndexInvalid
} else if page == 0 { } else if page == 0 {
@@ -96,7 +107,6 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
} }
var transactions []*models.Transaction var transactions []*models.Transaction
var err error
actualCount := count actualCount := count
@@ -104,7 +114,7 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
actualCount++ actualCount++
} }
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated) condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -114,11 +124,22 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
} }
// GetTransactionsInMonthByPage returns all transactions in given year and month // GetTransactionsInMonthByPage returns all transactions in given year and month
func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) { func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) {
if uid <= 0 { if uid <= 0 {
return nil, errs.ErrUserIdInvalid return nil, errs.ErrUserIdInvalid
} }
var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return nil, err
}
}
minTransactionTime, maxTransactionTime, err := utils.GetTransactionTimeRangeByYearMonth(year, month) minTransactionTime, maxTransactionTime, err := utils.GetTransactionTimeRangeByYearMonth(year, month)
if err != nil { if err != nil {
@@ -127,7 +148,7 @@ func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid in
var transactions []*models.Transaction var transactions []*models.Transaction
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -176,12 +197,23 @@ func (s *TransactionService) GetAllTransactionCount(c core.Context, uid int64) (
} }
// GetTransactionCount returns count of transactions // GetTransactionCount returns count of transactions
func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) { func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) {
if uid <= 0 { if uid <= 0 {
return 0, errs.ErrUserIdInvalid return 0, errs.ErrUserIdInvalid
} }
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) var err error
var transactionDbType models.TransactionDbType = 0
if transactionType > 0 {
transactionDbType, err = transactionType.ToTransactionDbType()
if err != nil {
return 0, err
}
}
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)
@@ -1868,7 +1900,7 @@ func (s *TransactionService) doCreateTransaction(c core.Context, database *datas
return err return err
} }
func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) { func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionDbType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) {
condition := "uid=? AND deleted=?" condition := "uid=? AND deleted=?"
conditionParams := make([]any, 0, 16) conditionParams := make([]any, 0, 16)
conditionParams = append(conditionParams, uid) conditionParams = append(conditionParams, uid)
@@ -1896,10 +1928,10 @@ func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransa
accountIdConditionParams = append(accountIdConditionParams, accountIds[i]) accountIdConditionParams = append(accountIdConditionParams, accountIds[i])
} }
if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionType && transactionType <= models.TRANSACTION_DB_TYPE_EXPENSE { if models.TRANSACTION_DB_TYPE_MODIFY_BALANCE <= transactionDbType && transactionDbType <= models.TRANSACTION_DB_TYPE_EXPENSE {
condition = condition + " AND type=?" condition = condition + " AND type=?"
conditionParams = append(conditionParams, transactionType) conditionParams = append(conditionParams, transactionDbType)
} else if transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionType == models.TRANSACTION_DB_TYPE_TRANSFER_IN { } else if transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transactionDbType == models.TRANSACTION_DB_TYPE_TRANSFER_IN {
if len(accountIds) == 0 { if len(accountIds) == 0 {
condition = condition + " AND type=?" condition = condition + " AND type=?"
conditionParams = append(conditionParams, models.TRANSACTION_DB_TYPE_TRANSFER_OUT) conditionParams = append(conditionParams, models.TRANSACTION_DB_TYPE_TRANSFER_OUT)