From a9a37b0c97706db74d48117bcd1ed3664c316eed Mon Sep 17 00:00:00 2001 From: MaysWind Date: Mon, 17 Feb 2025 00:32:53 +0800 Subject: [PATCH] fix the Postgres database transaction cannot continue to execute after failure (#50) --- pkg/datastore/database.go | 24 ++++++++++++++++++++++- pkg/datastore/datastore_container.go | 3 ++- pkg/services/accounts.go | 21 +++++++++++++++++++- pkg/services/transactions.go | 29 +++++++++++++++++++++++----- 4 files changed, 69 insertions(+), 8 deletions(-) diff --git a/pkg/datastore/database.go b/pkg/datastore/database.go index dd1df009..7336bb8f 100644 --- a/pkg/datastore/database.go +++ b/pkg/datastore/database.go @@ -4,11 +4,13 @@ import ( "xorm.io/xorm" "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/settings" ) // Database represents a database instance type Database struct { - engineGroup *xorm.EngineGroup + databaseType string + engineGroup *xorm.EngineGroup } // NewSession starts a new session with the specified context @@ -41,3 +43,23 @@ func (db *Database) DoTransaction(c core.Context, fn func(sess *xorm.Session) er return nil } + +// SetSavePoint sets a save point in the current transaction for Postgres +func (db *Database) SetSavePoint(sess *xorm.Session, savePointName string) error { + if db.databaseType == settings.PostgresDbType { + _, err := sess.Exec("SAVEPOINT " + savePointName) + return err + } + + return nil +} + +// RollbackToSavePoint rolls back to the specified save point in the current transaction for Postgres +func (db *Database) RollbackToSavePoint(sess *xorm.Session, savePointName string) error { + if db.databaseType == settings.PostgresDbType { + _, err := sess.Exec("ROLLBACK TO SAVEPOINT " + savePointName) + return err + } + + return nil +} diff --git a/pkg/datastore/datastore_container.go b/pkg/datastore/datastore_container.go index 83904823..e025ac86 100644 --- a/pkg/datastore/datastore_container.go +++ b/pkg/datastore/datastore_container.go @@ -104,7 +104,8 @@ func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) { engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second) return &Database{ - engineGroup: engineGroup, + databaseType: dbConfig.DatabaseType, + engineGroup: engineGroup, }, nil } diff --git a/pkg/services/accounts.go b/pkg/services/accounts.go index 589f24a4..110b7048 100644 --- a/pkg/services/accounts.go +++ b/pkg/services/accounts.go @@ -9,6 +9,7 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/datastore" "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/utils" "github.com/mayswind/ezbookkeeping/pkg/uuid" @@ -270,7 +271,9 @@ func (s *AccountService) CreateAccounts(c core.Context, mainAccount *models.Acco } } - return s.UserDataDB(mainAccount.Uid).DoTransaction(c, func(sess *xorm.Session) error { + userDataDb := s.UserDataDB(mainAccount.Uid) + + return userDataDb.DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(allAccounts); i++ { account := allAccounts[i] _, err := sess.Insert(account) @@ -282,9 +285,25 @@ func (s *AccountService) CreateAccounts(c core.Context, mainAccount *models.Acco for i := 0; i < len(allInitTransactions); i++ { transaction := allInitTransactions[i] + + insertTransactionSavePointName := "insert_transaction" + err := userDataDb.SetSavePoint(sess, insertTransactionSavePointName) + + if err != nil { + log.Errorf(c, "[accounts.CreateAccounts] failed to set save point \"%s\", because %s", insertTransactionSavePointName, err.Error()) + return err + } + createdRows, err := sess.Insert(transaction) if err != nil || createdRows < 1 { // maybe another transaction has same time + err = userDataDb.RollbackToSavePoint(sess, insertTransactionSavePointName) + + if err != nil { + log.Errorf(c, "[accounts.CreateAccounts] failed to rollback to save point \"%s\", because %s", insertTransactionSavePointName, err.Error()) + return err + } + sameSecondLatestTransaction := &models.Transaction{} minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index abe34061..9ec00869 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -252,8 +252,10 @@ func (s *TransactionService) CreateTransaction(c core.Context, transaction *mode UpdatedUnixTime: now, } - return s.UserDataDB(transaction.Uid).DoTransaction(c, func(sess *xorm.Session) error { - return s.doCreateTransaction(c, sess, transaction, transactionTagIndexes, tagIds, pictureIds, pictureUpdateModel) + userDataDb := s.UserDataDB(transaction.Uid) + + return userDataDb.DoTransaction(c, func(sess *xorm.Session) error { + return s.doCreateTransaction(c, userDataDb, sess, transaction, transactionTagIndexes, tagIds, pictureIds, pictureUpdateModel) }) } @@ -355,12 +357,14 @@ func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64, allTransactionTagIds[transaction.TransactionId] = uniqueTagIds } - return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { + userDataDb := s.UserDataDB(uid) + + return userDataDb.DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(transactions); i++ { transaction := transactions[i] transactionTagIndexes := allTransactionTagIndexes[transaction.TransactionId] transactionTagIds := allTransactionTagIds[transaction.TransactionId] - err := s.doCreateTransaction(c, sess, transaction, transactionTagIndexes, transactionTagIds, nil, nil) + err := s.doCreateTransaction(c, userDataDb, sess, transaction, transactionTagIndexes, transactionTagIds, nil, nil) if err != nil { transactionUnixTime := utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime) @@ -1562,7 +1566,7 @@ func (s *TransactionService) GetTransactionIds(transactions []*models.Transactio return transactionIds } -func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Session, transaction *models.Transaction, transactionTagIndexes []*models.TransactionTagIndex, tagIds []int64, pictureIds []int64, pictureUpdateModel *models.TransactionPictureInfo) error { +func (s *TransactionService) doCreateTransaction(c core.Context, database *datastore.Database, sess *xorm.Session, transaction *models.Transaction, transactionTagIndexes []*models.TransactionTagIndex, tagIds []int64, pictureIds []int64, pictureUpdateModel *models.TransactionPictureInfo) error { // Get and verify source and destination account sourceAccount, destinationAccount, err := s.getAccountModels(sess, transaction) @@ -1646,6 +1650,14 @@ func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Sess relatedTransaction = s.GetRelatedTransferTransaction(transaction) } + insertTransactionSavePointName := "insert_transaction" + err = database.SetSavePoint(sess, insertTransactionSavePointName) + + if err != nil { + log.Errorf(c, "[transactions.doCreateTransaction] failed to set save point \"%s\", because %s", insertTransactionSavePointName, err.Error()) + return err + } + createdRows, err := sess.Insert(transaction) if err != nil || createdRows < 1 { // maybe another transaction has same time @@ -1655,6 +1667,13 @@ func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Sess log.Warnf(c, "[transactions.doCreateTransaction] cannot create trasaction, regenerate transaction time value") } + err = database.RollbackToSavePoint(sess, insertTransactionSavePointName) + + if err != nil { + log.Errorf(c, "[transactions.doCreateTransaction] failed to rollback to save point \"%s\", because %s", insertTransactionSavePointName, err.Error()) + return err + } + sameSecondLatestTransaction := &models.Transaction{} minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime))