diff --git a/pkg/services/accounts.go b/pkg/services/accounts.go index a77f83ff..573d1459 100644 --- a/pkg/services/accounts.go +++ b/pkg/services/accounts.go @@ -267,6 +267,7 @@ func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { now := time.Now().Unix() updateModel := &models.Account{ + Balance: 0, Deleted: true, DeletedUnixTime: now, } @@ -288,7 +289,7 @@ func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { } var relatedTransactionsByAccount []*models.Transaction - err = sess.Cols("uid", "deleted", "account_id", "type").Where("uid=? AND deleted=?", uid, false).In("account_id", accountAndSubAccountIds).Limit(len(accountAndSubAccounts) + 1).Find(&relatedTransactionsByAccount) + err = sess.Cols("transaction_id", "uid", "deleted", "account_id", "type").Where("uid=? AND deleted=?", uid, false).In("account_id", accountAndSubAccountIds).Limit(len(accountAndSubAccounts) + 1).Find(&relatedTransactionsByAccount) if err != nil { return err @@ -310,7 +311,7 @@ func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { } } - deletedRows, err := sess.Cols("deleted", "deleted_unix_time").Where("uid=? AND deleted=?", uid, false).In("account_id", accountAndSubAccountIds).Update(updateModel) + deletedRows, err := sess.Cols("balance", "deleted", "deleted_unix_time").Where("uid=? AND deleted=?", uid, false).In("account_id", accountAndSubAccountIds).Update(updateModel) if err != nil { return err @@ -318,6 +319,27 @@ func (s *AccountService) DeleteAccount(uid int64, accountId int64) error { return errs.ErrAccountNotFound } + if len(relatedTransactionsByAccount) > 0 { + updateTransaction := &models.Transaction{ + Deleted: true, + DeletedUnixTime: now, + } + + transactionIds := make([]int64, len(relatedTransactionsByAccount)) + + for i := 0; i < len(relatedTransactionsByAccount); i++ { + transactionIds[i] = relatedTransactionsByAccount[i].TransactionId + } + + deletedTransactionRows, err := sess.Cols("deleted", "deleted_unix_time").Where("uid=? AND deleted=?", uid, false).In("transaction_id", transactionIds).Update(updateTransaction) + + if err != nil { + return err + } else if deletedTransactionRows < int64(len(transactionIds)) { + return errs.ErrDatabaseOperationFailed + } + } + return err }) }