diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index f60c0e4c..57161d4d 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -135,20 +135,11 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, return errs.ErrUserIdInvalid } - if transaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE || - transaction.Type == models.TRANSACTION_TYPE_INCOME || - transaction.Type == models.TRANSACTION_TYPE_EXPENSE { - if transaction.SourceAccountId != transaction.DestinationAccountId { - return errs.ErrTransactionSourceAndDestinationIdNotEqual - } else if transaction.SourceAmount != transaction.DestinationAmount { - return errs.ErrTransactionSourceAndDestinationAmountNotEqual - } - } else if transaction.Type == models.TRANSACTION_TYPE_TRANSFER { - if transaction.SourceAccountId == transaction.DestinationAccountId { - return errs.ErrTransactionSourceAndDestinationIdCannotBeEqual - } - } else { - return errs.ErrTransactionTypeInvalid + // Check whether account id is valid + err := s.isAccountIdValid(transaction) + + if err != nil { + return err } now := time.Now().Unix() @@ -174,30 +165,14 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, return s.UserDataDB(transaction.Uid).DoTransaction(func(sess *xorm.Session) error { // Get and verify source and destination account - sourceAccount := &models.Account{} - destinationAccount := &models.Account{} - has, err := sess.ID(transaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(sourceAccount) + sourceAccount, destinationAccount, err := s.getAccountModels(sess, transaction) if err != nil { return err - } else if !has { - return errs.ErrSourceAccountNotFound - } else if sourceAccount.Hidden { - return errs.ErrCannotAddTransactionToHiddenAccount } - if transaction.DestinationAccountId == transaction.SourceAccountId { - destinationAccount = sourceAccount - } else { - has, err = sess.ID(transaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(destinationAccount) - - if err != nil { - return err - } else if !has { - return errs.ErrDestinationAccountNotFound - } else if destinationAccount.Hidden { - return errs.ErrCannotAddTransactionToHiddenAccount - } + if sourceAccount.Hidden || destinationAccount.Hidden { + return errs.ErrCannotAddTransactionToHiddenAccount } if sourceAccount.Currency == destinationAccount.Currency && transaction.SourceAmount != transaction.DestinationAmount { @@ -205,48 +180,17 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, } // Get and verify category - category := &models.TransactionCategory{} + err = s.isCategoryValid(sess, transaction) - if transaction.Type != models.TRANSACTION_TYPE_MODIFY_BALANCE { - has, err = sess.ID(transaction.CategoryId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(category) - - if err != nil { - return err - } else if !has { - return errs.ErrTransactionCategoryNotFound - } - - if category.ParentCategoryId < 1 { - return errs.ErrCannotUsePrimaryCategoryForTransaction - } - } - - if (transaction.Type == models.TRANSACTION_TYPE_INCOME && category.Type != models.CATEGORY_TYPE_INCOME) || - (transaction.Type == models.TRANSACTION_TYPE_EXPENSE && category.Type != models.CATEGORY_TYPE_EXPENSE) || - (transaction.Type == models.TRANSACTION_TYPE_TRANSFER && category.Type != models.CATEGORY_TYPE_TRANSFER) { - return errs.ErrTransactionCategoryTypeInvalid + if err != nil { + return err } // Get and verify tags - if len(transactionTagIndexs) > 0 { - var tags []*models.TransactionTag - err := sess.Where("uid=?", transaction.Uid).In("tag_id", tagIds).Find(&tags) + err = s.isTagsValid(sess, transaction, transactionTagIndexs, tagIds) - if err != nil { - return err - } - - tagMap := make(map[int64]*models.TransactionTag) - - for i := 0; i < len(tags); i++ { - tagMap[tags[i].TagId] = tags[i] - } - - for i := 0; i < len(transactionTagIndexs); i++ { - if _, exists := tagMap[transactionTagIndexs[i].TagId]; !exists { - return errs.ErrTransactionTagNotFound - } - } + if err != nil { + return err } // Verify balance modification transaction and calculate real amount @@ -270,7 +214,7 @@ func (s *TransactionService) CreateTransaction(transaction *models.Transaction, minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) - has, err = sess.Where("uid=? AND deleted=? AND transaction_time>=? AND transaction_time<=?", transaction.Uid, false, minTransactionTime, maxTransactionTime).OrderBy("transaction_time desc").Limit(1).Get(sameSecondLatestTransaction) + has, err := sess.Where("uid=? AND deleted=? AND transaction_time>=? AND transaction_time<=?", transaction.Uid, false, minTransactionTime, maxTransactionTime).OrderBy("transaction_time desc").Limit(1).Get(sameSecondLatestTransaction) if err != nil { return err @@ -394,108 +338,50 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction, return errs.ErrTransactionNotFound } + // Cannot change transaction type if transaction.Type != oldTransaction.Type { return errs.ErrCannotModifyTransactionType } - if oldTransaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE || - oldTransaction.Type == models.TRANSACTION_TYPE_INCOME || - oldTransaction.Type == models.TRANSACTION_TYPE_EXPENSE { - if transaction.SourceAccountId != transaction.DestinationAccountId { - return errs.ErrTransactionSourceAndDestinationIdNotEqual - } else if transaction.SourceAmount != transaction.DestinationAmount { - return errs.ErrTransactionSourceAndDestinationAmountNotEqual - } - } else if oldTransaction.Type == models.TRANSACTION_TYPE_TRANSFER { - if transaction.SourceAccountId == transaction.DestinationAccountId { - return errs.ErrTransactionSourceAndDestinationIdCannotBeEqual - } - } else { - return errs.ErrTransactionTypeInvalid - } - - // Get and verify source and destination account (if necessary) - sourceAccount := &models.Account{} - destinationAccount := &models.Account{} - oldSourceAccount := &models.Account{} - oldDestinationAccount := &models.Account{} - has, err = sess.ID(transaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(sourceAccount) + // Check whether account id is valid + err = s.isAccountIdValid(transaction) if err != nil { return err - } else if !has { - return errs.ErrSourceAccountNotFound - } else if sourceAccount.Hidden { + } + + // Get and verify source and destination account (if necessary) + sourceAccount, destinationAccount, err := s.getAccountModels(sess, transaction) + + if err != nil { + return err + } + + if sourceAccount.Hidden || destinationAccount.Hidden { return errs.ErrCannotModifyTransactionInHiddenAccount } - if transaction.DestinationAccountId == transaction.SourceAccountId { - destinationAccount = sourceAccount - } else { - has, err = sess.ID(transaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(destinationAccount) - - if err != nil { - return err - } else if !has { - return errs.ErrDestinationAccountNotFound - } else if destinationAccount.Hidden { - return errs.ErrCannotModifyTransactionInHiddenAccount - } + if sourceAccount.Currency == destinationAccount.Currency && transaction.SourceAmount != transaction.DestinationAmount { + return errs.ErrTransactionSourceAndDestinationAmountNotEqual } - if transaction.SourceAccountId == oldTransaction.SourceAccountId { - oldSourceAccount = sourceAccount - } else { - has, err = sess.ID(oldTransaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(oldSourceAccount) + oldSourceAccount, oldDestinationAccount, err := s.getOldAccountModels(sess, transaction, oldTransaction, sourceAccount, destinationAccount) - if err != nil { - return err - } else if !has { - return errs.ErrSourceAccountNotFound - } else if oldSourceAccount.Hidden { - return errs.ErrCannotModifyTransactionInHiddenAccount - } + if err != nil { + return err } - if transaction.DestinationAccountId == oldTransaction.DestinationAccountId { - oldDestinationAccount = destinationAccount - } else { - has, err = sess.ID(oldTransaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(oldDestinationAccount) - - if err != nil { - return err - } else if !has { - return errs.ErrDestinationAccountNotFound - } else if oldDestinationAccount.Hidden { - return errs.ErrCannotModifyTransactionInHiddenAccount - } + if oldSourceAccount.Hidden || oldDestinationAccount.Hidden { + return errs.ErrCannotAddTransactionToHiddenAccount } // Append modified columns and verify if transaction.CategoryId != oldTransaction.CategoryId { - if oldTransaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE { - if transaction.CategoryId > 0 { - return errs.ErrBalanceModificationTransactionCannotSetCategory - } - } else { - category := &models.TransactionCategory{} - has, err = sess.ID(transaction.CategoryId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(category) + // Get and verify category + err = s.isCategoryValid(sess, transaction) - if err != nil { - return err - } else if !has { - return errs.ErrTransactionCategoryNotFound - } - - if category.ParentCategoryId < 1 { - return errs.ErrCannotUsePrimaryCategoryForTransaction - } - - if (oldTransaction.Type == models.TRANSACTION_TYPE_INCOME && category.Type != models.CATEGORY_TYPE_INCOME) || - (oldTransaction.Type == models.TRANSACTION_TYPE_EXPENSE && category.Type != models.CATEGORY_TYPE_EXPENSE) || - (oldTransaction.Type == models.TRANSACTION_TYPE_TRANSFER && category.Type != models.CATEGORY_TYPE_TRANSFER) { - return errs.ErrTransactionCategoryTypeInvalid - } + if err != nil { + return err } updateCols = append(updateCols, "category_id") @@ -547,25 +433,10 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction, } // Get and verify tags - if len(transactionTagIndexs) > 0 { - var tags []*models.TransactionTag - err := sess.Where("uid=?", transaction.Uid).In("tag_id", addTagIds).Find(&tags) + err = s.isTagsValid(sess, transaction, transactionTagIndexs, addTagIds) - if err != nil { - return err - } - - tagMap := make(map[int64]*models.TransactionTag) - - for i := 0; i < len(tags); i++ { - tagMap[tags[i].TagId] = tags[i] - } - - for i := 0; i < len(transactionTagIndexs); i++ { - if _, exists := tagMap[transactionTagIndexs[i].TagId]; !exists { - return errs.ErrTransactionTagNotFound - } - } + if err != nil { + return err } // Update transaction row @@ -696,30 +567,14 @@ func (s *TransactionService) DeleteTransaction(uid int64, transactionId int64) e } // Get and verify source and destination account - sourceAccount := &models.Account{} - destinationAccount := &models.Account{} - has, err = sess.ID(oldTransaction.SourceAccountId).Where("uid=? AND deleted=?", oldTransaction.Uid, false).Get(sourceAccount) + sourceAccount, destinationAccount, err := s.getAccountModels(sess, oldTransaction) if err != nil { return err - } else if !has { - return errs.ErrSourceAccountNotFound - } else if sourceAccount.Hidden { - return errs.ErrCannotDeleteTransactionInHiddenAccount } - if oldTransaction.DestinationAccountId == oldTransaction.SourceAccountId { - destinationAccount = sourceAccount - } else { - has, err = sess.ID(oldTransaction.DestinationAccountId).Where("uid=? AND deleted=?", oldTransaction.Uid, false).Get(destinationAccount) - - if err != nil { - return err - } else if !has { - return errs.ErrDestinationAccountNotFound - } else if destinationAccount.Hidden { - return errs.ErrCannotDeleteTransactionInHiddenAccount - } + if sourceAccount.Hidden || destinationAccount.Hidden { + return errs.ErrCannotDeleteTransactionInHiddenAccount } // Update transaction row to deleted @@ -782,3 +637,133 @@ func (s *TransactionService) DeleteTransaction(uid int64, transactionId int64) e return err }) } + +func (s *TransactionService) isAccountIdValid(transaction *models.Transaction) error { + if transaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE || + transaction.Type == models.TRANSACTION_TYPE_INCOME || + transaction.Type == models.TRANSACTION_TYPE_EXPENSE { + if transaction.SourceAccountId != transaction.DestinationAccountId { + return errs.ErrTransactionSourceAndDestinationIdNotEqual + } else if transaction.SourceAmount != transaction.DestinationAmount { + return errs.ErrTransactionSourceAndDestinationAmountNotEqual + } + } else if transaction.Type == models.TRANSACTION_TYPE_TRANSFER { + if transaction.SourceAccountId == transaction.DestinationAccountId { + return errs.ErrTransactionSourceAndDestinationIdCannotBeEqual + } + } else { + return errs.ErrTransactionTypeInvalid + } + + return nil +} + +func (s *TransactionService) getAccountModels(sess *xorm.Session, transaction *models.Transaction) (sourceAccount *models.Account, destinationAccount *models.Account, err error) { + sourceAccount = &models.Account{} + destinationAccount = &models.Account{} + + has, err := sess.ID(transaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(sourceAccount) + + if err != nil { + return nil, nil, err + } else if !has { + return nil, nil, errs.ErrSourceAccountNotFound + } + + if transaction.DestinationAccountId == transaction.SourceAccountId { + destinationAccount = sourceAccount + } else { + has, err = sess.ID(transaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(destinationAccount) + + if err != nil { + return nil, nil, err + } else if !has { + return nil, nil, errs.ErrDestinationAccountNotFound + } + } + return sourceAccount, destinationAccount, nil +} + +func (s *TransactionService) getOldAccountModels(sess *xorm.Session, transaction *models.Transaction, oldTransaction *models.Transaction, sourceAccount *models.Account, destinationAccount *models.Account) (oldSourceAccount *models.Account, oldDestinationAccount *models.Account, err error) { + oldSourceAccount = &models.Account{} + oldDestinationAccount = &models.Account{} + + if transaction.SourceAccountId == oldTransaction.SourceAccountId { + oldSourceAccount = sourceAccount + } else { + has, err := sess.ID(oldTransaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(oldSourceAccount) + + if err != nil { + return nil, nil, err + } else if !has { + return nil, nil, errs.ErrSourceAccountNotFound + } + } + + if transaction.DestinationAccountId == oldTransaction.DestinationAccountId { + oldDestinationAccount = destinationAccount + } else { + has, err := sess.ID(oldTransaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(oldDestinationAccount) + + if err != nil { + return nil, nil, err + } else if !has { + return nil, nil, errs.ErrDestinationAccountNotFound + } + } + return oldSourceAccount, oldDestinationAccount, nil +} + +func (s *TransactionService) isCategoryValid(sess *xorm.Session, transaction *models.Transaction) error { + if transaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE { + if transaction.CategoryId > 0 { + return errs.ErrBalanceModificationTransactionCannotSetCategory + } + } else { + category := &models.TransactionCategory{} + has, err := sess.ID(transaction.CategoryId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(category) + + if err != nil { + return err + } else if !has { + return errs.ErrTransactionCategoryNotFound + } + + if category.ParentCategoryId < 1 { + return errs.ErrCannotUsePrimaryCategoryForTransaction + } + + if (transaction.Type == models.TRANSACTION_TYPE_INCOME && category.Type != models.CATEGORY_TYPE_INCOME) || + (transaction.Type == models.TRANSACTION_TYPE_EXPENSE && category.Type != models.CATEGORY_TYPE_EXPENSE) || + (transaction.Type == models.TRANSACTION_TYPE_TRANSFER && category.Type != models.CATEGORY_TYPE_TRANSFER) { + return errs.ErrTransactionCategoryTypeInvalid + } + } + + return nil +} + +func (s *TransactionService) isTagsValid(sess *xorm.Session, transaction *models.Transaction, transactionTagIndexs []*models.TransactionTagIndex, tagIds []int64) error { + if len(transactionTagIndexs) > 0 { + var tags []*models.TransactionTag + err := sess.Where("uid=?", transaction.Uid).In("tag_id", tagIds).Find(&tags) + + if err != nil { + return err + } + + tagMap := make(map[int64]*models.TransactionTag) + + for i := 0; i < len(tags); i++ { + tagMap[tags[i].TagId] = tags[i] + } + + for i := 0; i < len(transactionTagIndexs); i++ { + if _, exists := tagMap[transactionTagIndexs[i].TagId]; !exists { + return errs.ErrTransactionTagNotFound + } + } + } + + return nil +}