diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index 580fd57b..621a2386 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -369,7 +369,7 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction) sourceAccount := &models.Account{} destinationAccount := &models.Account{} - if transaction.SourceAccountId != oldTransaction.SourceAccountId { + if transaction.SourceAccountId != oldTransaction.SourceAccountId || transaction.SourceAmount != oldTransaction.SourceAmount { has, err := sess.ID(transaction.SourceAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(sourceAccount) if err != nil { @@ -377,23 +377,36 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction) } else if !has { return errs.ErrSourceAccountNotFound } + } + if transaction.DestinationAccountId != oldTransaction.DestinationAccountId || transaction.DestinationAmount != oldTransaction.DestinationAmount { + if transaction.DestinationAccountId == transaction.SourceAccountId { + destinationAccount = sourceAccount + } else { + has, err := sess.ID(transaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(destinationAccount) + + if err != nil { + return err + } else if !has { + return errs.ErrDestinationAccountNotFound + } + } + } + + if transaction.SourceAccountId != oldTransaction.SourceAccountId { updateCols = append(updateCols, "source_account_id") } if transaction.DestinationAccountId != oldTransaction.DestinationAccountId { - has, err := sess.ID(transaction.DestinationAccountId).Where("uid=? AND deleted=?", transaction.Uid, false).Get(destinationAccount) - - if err != nil { - return err - } else if !has { - return errs.ErrDestinationAccountNotFound - } - updateCols = append(updateCols, "destination_account_id") } if transaction.SourceAmount != oldTransaction.SourceAmount { + if oldTransaction.Type == models.TRANSACTION_TYPE_MODIFY_BALANCE { + originalBalance := sourceAccount.Balance - oldTransaction.DestinationAmount + transaction.DestinationAmount = transaction.SourceAmount - originalBalance + } + updateCols = append(updateCols, "source_amount") } @@ -419,10 +432,14 @@ func (s *TransactionService) ModifyTransaction(transaction *models.Transaction) } if transaction.SourceAmount != oldTransaction.SourceAmount { - transaction.DestinationAmount = transaction.SourceAmount - sourceAccount.Balance + sourceAccount.UpdatedUnixTime = time.Now().Unix() + updatedRows, err := sess.ID(sourceAccount.AccountId).SetExpr("balance", fmt.Sprintf("balance-(%d)+(%d)", oldTransaction.DestinationAmount, transaction.DestinationAmount)).Cols("updated_unix_time").Where("uid=? AND deleted=?", sourceAccount.Uid, false).Update(sourceAccount) - // TODO: implement - return errs.ErrNotImplemented + if err != nil { + return err + } else if updatedRows < 1 { + return errs.ErrDatabaseOperationFailed + } } } else if oldTransaction.Type == models.TRANSACTION_TYPE_INCOME { if transaction.SourceAccountId != oldTransaction.SourceAccountId && transaction.DestinationAmount != oldTransaction.DestinationAmount {