diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index 7ae0a4f8..850a8a1d 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -1158,6 +1158,8 @@ func (a *TransactionsApi) TransactionImportHandler(c *core.WebContext) (any, *er } } + newTransactionTagIdsMap := make(map[int][]int64, len(transactionImportReq.Transactions)) + for i := 0; i < len(transactionImportReq.Transactions); i++ { transactionCreateReq := transactionImportReq.Transactions[i] tagIds, err := utils.StringArrayToInt64Array(transactionCreateReq.TagIds) @@ -1193,6 +1195,8 @@ func (a *TransactionsApi) TransactionImportHandler(c *core.WebContext) (any, *er log.Warnf(c, "[transactions.TransactionImportHandler] non-transfer transaction \"index:%d\" destination amount cannot be set", i) return nil, errs.ErrTransactionDestinationAmountCannotBeSet } + + newTransactionTagIdsMap[i] = tagIds } user, err := a.users.GetUserById(c, uid) @@ -1219,7 +1223,7 @@ func (a *TransactionsApi) TransactionImportHandler(c *core.WebContext) (any, *er newTransactions[i] = transaction } - err = a.transactions.BatchCreateTransactions(c, user.Uid, newTransactions) + err = a.transactions.BatchCreateTransactions(c, user.Uid, newTransactions, newTransactionTagIdsMap) count := len(newTransactions) if err != nil { diff --git a/pkg/cli/user_data.go b/pkg/cli/user_data.go index 2d5b2502..60703009 100644 --- a/pkg/cli/user_data.go +++ b/pkg/cli/user_data.go @@ -693,14 +693,14 @@ func (l *UserDataCli) ImportTransaction(c *core.CliContext, username string, fil return err } - newTransactions, newAccounts, newCategories, newTags, err := dataImporter.ParseImportedData(c, user, data, utils.GetTimezoneOffsetMinutes(time.Local), accountMap, categoryMap, tagMap) + parsedTransactions, newAccounts, newCategories, newTags, err := dataImporter.ParseImportedData(c, user, data, utils.GetTimezoneOffsetMinutes(time.Local), accountMap, categoryMap, tagMap) if err != nil { log.BootErrorf(c, "[user_data.ImportTransaction] failed to parse imported data for \"%s\", because %s", username, err.Error()) return err } - if len(newTransactions) < 1 { + if len(parsedTransactions) < 1 { log.BootErrorf(c, "[user_data.ImportTransaction] there are no transactions in import file") return errs.ErrOperationFailed } @@ -720,7 +720,15 @@ func (l *UserDataCli) ImportTransaction(c *core.CliContext, username string, fil return errs.ErrOperationFailed } - err = l.transactions.BatchCreateTransactions(c, user.Uid, newTransactions.ToTransactionsList()) + newTransactions := parsedTransactions.ToTransactionsList() + newTransactionTagIdsMap, err := parsedTransactions.ToTransactionTagIdsMap() + + if err != nil { + log.BootErrorf(c, "[user_data.ImportTransaction] failed to get transaction tag ids map, because %s", err.Error()) + return errs.ErrOperationFailed + } + + err = l.transactions.BatchCreateTransactions(c, user.Uid, newTransactions, newTransactionTagIdsMap) if err != nil { log.BootErrorf(c, "[user_data.ImportTransaction] failed to create transaction, because %s", err.Error()) diff --git a/pkg/models/imported_transaction.go b/pkg/models/imported_transaction.go index bb977ff0..25167cdb 100644 --- a/pkg/models/imported_transaction.go +++ b/pkg/models/imported_transaction.go @@ -115,7 +115,7 @@ func (s ImportedTransactionSlice) Less(i, j int) bool { return s[i].TransactionTime < s[j].TransactionTime } -// ToTransactionsList returns the a list of transactions +// ToTransactionsList returns a list of transaction models func (s ImportedTransactionSlice) ToTransactionsList() []*Transaction { transactions := make([]*Transaction, s.Len()) @@ -126,6 +126,23 @@ func (s ImportedTransactionSlice) ToTransactionsList() []*Transaction { return transactions } +// ToTransactionTagIdsMap returns a list of transaction tag ids +func (s ImportedTransactionSlice) ToTransactionTagIdsMap() (map[int][]int64, error) { + transactionTagIdsMap := make(map[int][]int64, s.Len()) + + for i := 0; i < s.Len(); i++ { + tagIds, err := utils.StringArrayToInt64Array(s[i].TagIds) + + if err != nil { + return nil, err + } + + transactionTagIdsMap[i] = tagIds + } + + return transactionTagIdsMap, nil +} + // ToImportTransactionResponseList returns the a list of view-objects according to imported transaction data func (s ImportedTransactionSlice) ToImportTransactionResponseList() []*ImportTransactionResponse { transactionResps := make([]*ImportTransactionResponse, 0, s.Len()) diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index 96b3643a..29be7b6b 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -272,9 +272,10 @@ func (s *TransactionService) CreateTransaction(c core.Context, transaction *mode } // BatchCreateTransactions saves new transactions to database -func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64, transactions []*models.Transaction) error { +func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64, transactions []*models.Transaction, allTagIds map[int][]int64) error { now := time.Now().Unix() - needUuidCount := uint16(0) + needTransactionUuidCount := uint16(0) + needTagIndexUuidCount := uint16(0) for i := 0; i < len(transactions); i++ { transaction := transactions[i] @@ -291,9 +292,9 @@ func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64, } if transaction.Type == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transaction.Type == models.TRANSACTION_DB_TYPE_TRANSFER_IN { - needUuidCount += 2 + needTransactionUuidCount += 2 } else { - needUuidCount++ + needTransactionUuidCount++ } transaction.TransactionTime = utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)) @@ -302,33 +303,78 @@ func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64, transaction.UpdatedUnixTime = now } - if needUuidCount > uint16(65535) { + for index, tagIds := range allTagIds { + if index < 0 || index >= len(transactions) { + return errs.ErrOperationFailed + } + + uniqueTagIds := utils.ToUniqueInt64Slice(tagIds) + needTagIndexUuidCount += uint16(len(uniqueTagIds)) + } + + if needTransactionUuidCount > uint16(65535) || needTagIndexUuidCount > uint16(65535) { return errs.ErrImportTooManyTransaction } - uuids := s.GenerateUuids(uuid.UUID_TYPE_TRANSACTION, needUuidCount) - uuidIndex := 0 + transactionUuids := s.GenerateUuids(uuid.UUID_TYPE_TRANSACTION, needTransactionUuidCount) + transactionUuidIndex := 0 - if len(uuids) < int(needUuidCount) { + if len(transactionUuids) < int(needTransactionUuidCount) { return errs.ErrSystemIsBusy } for i := 0; i < len(transactions); i++ { transaction := transactions[i] - transaction.TransactionId = uuids[uuidIndex] - uuidIndex++ + transaction.TransactionId = transactionUuids[transactionUuidIndex] + transactionUuidIndex++ if transaction.Type == models.TRANSACTION_DB_TYPE_TRANSFER_OUT || transaction.Type == models.TRANSACTION_DB_TYPE_TRANSFER_IN { - transaction.RelatedId = uuids[uuidIndex] - uuidIndex++ + transaction.RelatedId = transactionUuids[transactionUuidIndex] + transactionUuidIndex++ } } + tagIndexUuids := s.GenerateUuids(uuid.UUID_TYPE_TAG_INDEX, needTagIndexUuidCount) + tagIndexUuidIndex := 0 + + if len(tagIndexUuids) < int(needTagIndexUuidCount) { + return errs.ErrSystemIsBusy + } + + allTransactionTagIndexes := make(map[int64][]*models.TransactionTagIndex) + allTransactionTagIds := make(map[int64][]int64) + + for index, tagIds := range allTagIds { + transaction := transactions[index] + uniqueTagIds := utils.ToUniqueInt64Slice(tagIds) + + transactionTagIndexes := make([]*models.TransactionTagIndex, len(uniqueTagIds)) + + for i := 0; i < len(uniqueTagIds); i++ { + transactionTagIndexes[i] = &models.TransactionTagIndex{ + TagIndexId: tagIndexUuids[tagIndexUuidIndex], + Uid: transaction.Uid, + Deleted: false, + TagId: uniqueTagIds[i], + TransactionId: transaction.TransactionId, + CreatedUnixTime: now, + UpdatedUnixTime: now, + } + + tagIndexUuidIndex++ + } + + allTransactionTagIndexes[transaction.TransactionId] = transactionTagIndexes + allTransactionTagIds[transaction.TransactionId] = uniqueTagIds + } + return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error { for i := 0; i < len(transactions); i++ { transaction := transactions[i] - err := s.doCreateTransaction(sess, transaction, nil, nil, nil, nil) + transactionTagIndexes := allTransactionTagIndexes[transaction.TransactionId] + transactionTagIds := allTransactionTagIds[transaction.TransactionId] + err := s.doCreateTransaction(sess, transaction, transactionTagIndexes, transactionTagIds, nil, nil) if err != nil { transactionUnixTime := utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)