diff --git a/pkg/mcp/query_transactions_tool_handler.go b/pkg/mcp/query_transactions_tool_handler.go index 25094694..7ff031ab 100644 --- a/pkg/mcp/query_transactions_tool_handler.go +++ b/pkg/mcp/query_transactions_tool_handler.go @@ -14,6 +14,8 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/utils" ) +const pageCountForLoadTransactions = 1000 + // MCPQueryTransactionsRequest represents all parameters of the query transactions request type MCPQueryTransactionsRequest struct { StartTime string `json:"start_time" jsonschema:"format=date-time" jsonschema_description:"Start time for the query in RFC 3339 format (e.g. 2023-01-01T12:00:00Z)"` @@ -160,7 +162,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq return nil, nil, err } - transactions, err := services.GetTransactionService().GetTransactionsByMaxTime(c, uid, maxTransactionTime, minTransactionTime, transactionType, filterCategoryIds, filterAccountIds, nil, false, "", queryTransactionsRequest.Keyword, false, queryTransactionsRequest.Page, queryTransactionsRequest.Count, false, true) + transactions, err := services.GetTransactionService().GetTransactionsByMaxTimeUpToCount(c, uid, maxTransactionTime, minTransactionTime, transactionType, filterCategoryIds, filterAccountIds, nil, false, "", queryTransactionsRequest.Keyword, false, queryTransactionsRequest.Page, queryTransactionsRequest.Count, pageCountForLoadTransactions, false, true) structuredResponse, response, err := h.createNewMCPQueryTransactionsResponse(c, &queryTransactionsRequest, transactions, totalCount, services.GetAccountService().GetAccountMapByList(allAccounts), services.GetTransactionCategoryService().GetCategoryMapByList(allCategories)) if err != nil { diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index c602f4b6..f370b96c 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -317,12 +317,107 @@ func (s *TransactionService) GetAllAccountsDailyOpeningAndClosingBalance(c core. return accountDailyBalances, nil } +// GetTransactionsByMaxTimeUpToCount returns transactions before given time and up to given count +func (s *TransactionService) GetTransactionsByMaxTimeUpToCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagFilters []*models.TransactionTagFilter, noTags bool, amountFilter string, keyword string, mustHavePictures bool, page int32, count int32, pageCount int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { + if maxTransactionTime <= 0 { + maxTransactionTime = utils.GetMaxTransactionTimeFromUnixTime(time.Now().Unix()) + } + + if page < 0 { + return nil, errs.ErrPageIndexInvalid + } else if page == 0 { + page = 1 + } + + if count < 1 { + return nil, errs.ErrPageCountInvalid + } + + finalExpectedCount := int(count) + + if needOneMoreItem { + finalExpectedCount++ + } + + var allTransactions []*models.Transaction + startOffset := int((page - 1) * count) + firstFetchCount := int(pageCount) + + if finalExpectedCount < firstFetchCount { + firstFetchCount = finalExpectedCount + } + + transactions, err := s.getTransactionsByMaxTimeWithOffset(c, uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagFilters, noTags, amountFilter, keyword, mustHavePictures, startOffset, firstFetchCount, noDuplicated) + + if err != nil { + return nil, err + } + + allTransactions = append(allTransactions, transactions...) + + if len(transactions) < firstFetchCount { + return allTransactions, nil + } + + maxTransactionTime = transactions[len(transactions)-1].TransactionTime - 1 + + for len(allTransactions) < finalExpectedCount && maxTransactionTime > 0 { + remainingCount := finalExpectedCount - len(allTransactions) + fetchCount := int(pageCount) + + if remainingCount < fetchCount { + fetchCount = remainingCount + } + + transactions, err := s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagFilters, noTags, amountFilter, keyword, mustHavePictures, 1, int32(fetchCount), false, noDuplicated) + + if err != nil { + return nil, err + } + + allTransactions = append(allTransactions, transactions...) + + if len(transactions) < fetchCount { + break + } + + maxTransactionTime = transactions[len(transactions)-1].TransactionTime - 1 + } + + return allTransactions, nil +} + // GetTransactionsByMaxTime returns transactions before given time func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagFilters []*models.TransactionTagFilter, noTags bool, amountFilter string, keyword string, mustHavePictures bool, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } + if page < 0 { + return nil, errs.ErrPageIndexInvalid + } else if page == 0 { + page = 1 + } + + if count < 1 { + return nil, errs.ErrPageCountInvalid + } + + finalCount := int(count) + + if needOneMoreItem { + finalCount++ + } + + return s.getTransactionsByMaxTimeWithOffset(c, uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagFilters, noTags, amountFilter, keyword, mustHavePictures, int(count*(page-1)), finalCount, noDuplicated) +} + +// getTransactionsByMaxTimeWithOffset returns transactions before given time with explicit offset and limit +func (s *TransactionService) getTransactionsByMaxTimeWithOffset(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionType, categoryIds []int64, accountIds []int64, tagFilters []*models.TransactionTagFilter, noTags bool, amountFilter string, keyword string, mustHavePictures bool, offset int, limit int, noDuplicated bool) ([]*models.Transaction, error) { + if uid <= 0 { + return nil, errs.ErrUserIdInvalid + } + var err error var transactionDbType models.TransactionDbType = 0 @@ -334,30 +429,14 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, } } - if page < 0 { - return nil, errs.ErrPageIndexInvalid - } else if page == 0 { - page = 1 - } - - if count < 1 { - return nil, errs.ErrPageCountInvalid - } - var transactions []*models.Transaction - actualCount := count - - if needOneMoreItem { - actualCount++ - } - condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionDbType, categoryIds, accountIds, tagFilters, amountFilter, keyword, noDuplicated) sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagFilters, noTags) sess = s.appendFilterPicturesConditionToQuery(sess, uid, mustHavePictures) - err = sess.Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions) + err = sess.Limit(limit, offset).OrderBy("transaction_time desc").Find(&transactions) return transactions, err }