diff --git a/conf/ezbookkeeping.ini b/conf/ezbookkeeping.ini index 7408e690..fc21ba18 100644 --- a/conf/ezbookkeeping.ini +++ b/conf/ezbookkeeping.ini @@ -244,6 +244,7 @@ max_user_avatar_size = 1048576 # 10: Export Transactions # 11: Clear All Data # 12: Sync Application Settings +# 13: MCP (Model Context Protocol) Access default_feature_restrictions = [data] diff --git a/pkg/api/model_context_protocols.go b/pkg/api/model_context_protocols.go index ec6091e5..90d01a1f 100644 --- a/pkg/api/model_context_protocols.go +++ b/pkg/api/model_context_protocols.go @@ -7,6 +7,7 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/mcp" "github.com/mayswind/ezbookkeeping/pkg/services" "github.com/mayswind/ezbookkeeping/pkg/settings" @@ -52,6 +53,18 @@ func (a *ModelContextProtocolAPI) InitializeHandler(c *core.WebContext, jsonRPCR return nil, errs.ErrIncompleteOrIncorrectSubmission } + uid := c.GetCurrentUid() + user, err := a.users.GetUserById(c, uid) + + if err != nil { + log.Warnf(c, "[model_context_protocols.InitializeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) + return nil, errs.ErrUserNotFound + } + + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, errs.ErrNotPermittedToPerformThisAction + } + protocolVersion := mcp.MCPProtocolVersion(initRequest.ProtocolVersion) _, exists := mcp.SupportedMCPVersion[protocolVersion] @@ -78,6 +91,18 @@ func (a *ModelContextProtocolAPI) InitializeHandler(c *core.WebContext, jsonRPCR // ListResourcesHandler returns the list of resources for model context protocol func (a *ModelContextProtocolAPI) ListResourcesHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + uid := c.GetCurrentUid() + user, err := a.users.GetUserById(c, uid) + + if err != nil { + log.Warnf(c, "[model_context_protocols.ListResourcesHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) + return nil, errs.ErrUserNotFound + } + + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, errs.ErrNotPermittedToPerformThisAction + } + listResourcesResp := mcp.MCPListResourcesResponse{ Resources: make([]*mcp.MCPResource, 0), } @@ -97,11 +122,35 @@ func (a *ModelContextProtocolAPI) ReadResourceHandler(c *core.WebContext, jsonRP return nil, errs.ErrIncompleteOrIncorrectSubmission } + uid := c.GetCurrentUid() + user, err := a.users.GetUserById(c, uid) + + if err != nil { + log.Warnf(c, "[model_context_protocols.ReadResourceHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) + return nil, errs.ErrUserNotFound + } + + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, errs.ErrNotPermittedToPerformThisAction + } + return nil, errs.ErrApiNotFound } // ListToolsHandler returns the list of tools for model context protocol func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + uid := c.GetCurrentUid() + user, err := a.users.GetUserById(c, uid) + + if err != nil { + log.Warnf(c, "[model_context_protocols.ListToolsHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) + return nil, errs.ErrUserNotFound + } + + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, errs.ErrNotPermittedToPerformThisAction + } + mcpVersion := a.getMCPVersion(c) toolsInfo := mcp.Container.GetMCPTools() finalToolsInfos := make([]*mcp.MCPTool, len(toolsInfo)) @@ -128,6 +177,18 @@ func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRe // CallToolHandler returns the result of calling a specific tool for model context protocol func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + uid := c.GetCurrentUid() + user, err := a.users.GetUserById(c, uid) + + if err != nil { + log.Warnf(c, "[model_context_protocols.CallToolHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error()) + return nil, errs.ErrUserNotFound + } + + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, errs.ErrNotPermittedToPerformThisAction + } + var callToolReq mcp.MCPCallToolRequest if jsonRPCRequest.Params != nil { @@ -138,10 +199,10 @@ func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCReq return nil, errs.ErrIncompleteOrIncorrectSubmission } - result, err := mcp.Container.HandleTool(c, &callToolReq, a.CurrentConfig(), a) + result, err := mcp.Container.HandleTool(c, &callToolReq, user, a.CurrentConfig(), a) if err != nil { - return nil, err + return nil, errs.Or(err, errs.ErrOperationFailed) } return result, nil diff --git a/pkg/api/tokens.go b/pkg/api/tokens.go index e6595232..c0cbed5b 100644 --- a/pkg/api/tokens.go +++ b/pkg/api/tokens.go @@ -99,6 +99,10 @@ func (a *TokensApi) TokenGenerateMCPHandler(c *core.WebContext) (any, *errs.Erro return nil, errs.ErrUserNotFound } + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return false, errs.ErrNotPermittedToPerformThisAction + } + if !a.users.IsPasswordEqualsUserPassword(generateMCPTokenReq.Password, user) { return nil, errs.ErrUserPasswordWrong } diff --git a/pkg/cli/user_data.go b/pkg/cli/user_data.go index 13c8417c..0976da7f 100644 --- a/pkg/cli/user_data.go +++ b/pkg/cli/user_data.go @@ -422,6 +422,10 @@ func (l *UserDataCli) CreateNewUserToken(c *core.CliContext, username string, to var tokenRecord *models.TokenRecord if tokenType == "mcp" { + if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) { + return nil, "", errs.ErrNotPermittedToPerformThisAction + } + token, tokenRecord, err = l.tokens.CreateMCPTokenViaCli(c, user) } else if tokenType == "normal" { token, tokenRecord, err = l.tokens.CreateTokenViaCli(c, user) diff --git a/pkg/core/user_feature_restriction.go b/pkg/core/user_feature_restriction.go index 3285b0a0..cec37c2b 100644 --- a/pkg/core/user_feature_restriction.go +++ b/pkg/core/user_feature_restriction.go @@ -88,10 +88,11 @@ const ( USER_FEATURE_RESTRICTION_TYPE_EXPORT_TRANSACTION UserFeatureRestrictionType = 10 USER_FEATURE_RESTRICTION_TYPE_CLEAR_ALL_DATA UserFeatureRestrictionType = 11 USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS UserFeatureRestrictionType = 12 + USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS UserFeatureRestrictionType = 13 ) const userFeatureRestrictionTypeMinValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_UPDATE_PASSWORD -const userFeatureRestrictionTypeMaxValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS +const userFeatureRestrictionTypeMaxValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS // String returns a textual representation of the restriction type of user features func (t UserFeatureRestrictionType) String() string { @@ -120,6 +121,8 @@ func (t UserFeatureRestrictionType) String() string { return "Clear All Data" case USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS: return "Sync Application Settings" + case USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS: + return "MCP (Model Context Protocol) Access" default: return fmt.Sprintf("Invalid(%d)", int(t)) } diff --git a/pkg/mcp/add_transaction.go b/pkg/mcp/add_transaction.go index 596613fd..44e3dd57 100644 --- a/pkg/mcp/add_transaction.go +++ b/pkg/mcp/add_transaction.go @@ -63,7 +63,7 @@ func (h *mcpAddTransactionToolHandler) OutputType() reflect.Type { } // Handle processes the MCP call tool request and returns the response -func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { +func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { var addTransactionRequest MCPAddTransactionRequest if callToolReq.Arguments != nil { @@ -84,22 +84,12 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M return nil, nil, errs.ErrTransactionHasTooManyTags } - uid := c.GetCurrentUid() - user, err := services.GetUserService().GetUserById(c, uid) - - if err != nil { - if !errs.IsCustomError(err) { - log.Errorf(c, "[add_transaction.Handle] failed to get user, because %s", err.Error()) - } - - return nil, nil, errs.ErrUserNotFound - } - + uid := user.Uid allAccounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid) if err != nil { log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts) @@ -128,7 +118,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M if err != nil { log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories) @@ -146,7 +136,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M if err != nil { log.Warnf(c, "[add_transaction.Handle] get transaction tag ids error, because %s", err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } tagMaps := services.GetTransactionTagService().GetTagNameMapByList(allTags) @@ -173,7 +163,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M if err != nil { log.Errorf(c, "[add_transaction.Handle] failed to create transaction \"id:%d\" for user \"uid:%d\", because %s", transaction.TransactionId, uid, err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } log.Infof(c, "[add_transaction.Handle] user \"uid:%d\" has created a new transaction \"id:%d\" successfully", uid, transaction.TransactionId) @@ -193,7 +183,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, false) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil @@ -215,7 +205,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, true) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil diff --git a/pkg/mcp/handler.go b/pkg/mcp/handler.go index 06510000..6d955a12 100644 --- a/pkg/mcp/handler.go +++ b/pkg/mcp/handler.go @@ -4,7 +4,7 @@ import ( "reflect" "github.com/mayswind/ezbookkeeping/pkg/core" - "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/services" "github.com/mayswind/ezbookkeeping/pkg/settings" ) @@ -33,5 +33,5 @@ type MCPToolHandler[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPRe OutputType() reflect.Type // Handle processes the MCP call tool request and returns the response - Handle(*core.WebContext, *MCPCallToolRequest, *settings.Config, MCPAvailableServices) (any, []*T, *errs.Error) + Handle(*core.WebContext, *MCPCallToolRequest, *models.User, *settings.Config, MCPAvailableServices) (any, []*T, error) } diff --git a/pkg/mcp/mcp_container.go b/pkg/mcp/mcp_container.go index 11389ccc..cf149c54 100644 --- a/pkg/mcp/mcp_container.go +++ b/pkg/mcp/mcp_container.go @@ -6,6 +6,7 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/settings" ) @@ -34,25 +35,25 @@ func (c *MCPContainer) GetMCPTools() []*MCPTool { } // HandleTool returns the result of the MCP tool handler based on the tool name -func (c *MCPContainer) HandleTool(ctx *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, *errs.Error) { +func (c *MCPContainer) HandleTool(ctx *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, error) { if handler, exists := c.mcpTextContentTools.Get(callToolReq.Name); exists { - return handleTool(ctx, handler, currentConfig, services, callToolReq) + return handleTool(ctx, handler, currentConfig, services, callToolReq, user) } if handler, exists := c.mcpImageContentTools.Get(callToolReq.Name); exists { - return handleTool(ctx, handler, currentConfig, services, callToolReq) + return handleTool(ctx, handler, currentConfig, services, callToolReq, user) } if handler, exists := c.mcpAudioContentTools.Get(callToolReq.Name); exists { - return handleTool(ctx, handler, currentConfig, services, callToolReq) + return handleTool(ctx, handler, currentConfig, services, callToolReq, user) } if handler, exists := c.mcpResourceLinkTools.Get(callToolReq.Name); exists { - return handleTool(ctx, handler, currentConfig, services, callToolReq) + return handleTool(ctx, handler, currentConfig, services, callToolReq, user) } if handler, exists := c.mcpEmbeddedResourceTools.Get(callToolReq.Name); exists { - return handleTool(ctx, handler, currentConfig, services, callToolReq) + return handleTool(ctx, handler, currentConfig, services, callToolReq, user) } return nil, errs.ErrApiNotFound @@ -109,8 +110,8 @@ func registerMCPToolHandler[T MCPTextContent | MCPImageContent | MCPAudioContent c.mcpTools = append(c.mcpTools, createNewMCPToolInfo(handler.Name(), handler)) } -func handleTool[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource](ctx *core.WebContext, handler MCPToolHandler[T], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { - structuredResponse, result, err := handler.Handle(ctx, callToolReq, currentConfig, services) +func handleTool[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource](ctx *core.WebContext, handler MCPToolHandler[T], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest, user *models.User) (any, error) { + structuredResponse, result, err := handler.Handle(ctx, callToolReq, user, currentConfig, services) if err != nil { return nil, errs.Or(err, errs.ErrOperationFailed) diff --git a/pkg/mcp/query_all_accounts.go b/pkg/mcp/query_all_accounts.go index 1a1639f7..a9e54db9 100644 --- a/pkg/mcp/query_all_accounts.go +++ b/pkg/mcp/query_all_accounts.go @@ -5,7 +5,6 @@ import ( "reflect" "github.com/mayswind/ezbookkeeping/pkg/core" - "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/settings" @@ -49,19 +48,19 @@ func (h *mcpQueryAllAccountsToolHandler) OutputType() reflect.Type { } // Handle processes the MCP call tool request and returns the response -func (h *mcpQueryAllAccountsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { - uid := c.GetCurrentUid() +func (h *mcpQueryAllAccountsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { + uid := user.Uid accounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid) if err != nil { log.Errorf(c, "[query_all_accounts.Handle] failed to get all accounts for user \"uid:%d\", because %s", uid, err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } structuredResponse, response, err := h.createNewMCPQueryAllAccountsResponse(c, accounts) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil diff --git a/pkg/mcp/query_all_transaction_categories.go b/pkg/mcp/query_all_transaction_categories.go index b01f5a54..6c8febc5 100644 --- a/pkg/mcp/query_all_transaction_categories.go +++ b/pkg/mcp/query_all_transaction_categories.go @@ -5,7 +5,6 @@ import ( "reflect" "github.com/mayswind/ezbookkeeping/pkg/core" - "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/settings" @@ -43,19 +42,19 @@ func (h *mcpQueryAllTransactionCategoriesToolHandler) OutputType() reflect.Type } // Handle processes the MCP call tool request and returns the response -func (h *mcpQueryAllTransactionCategoriesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { - uid := c.GetCurrentUid() +func (h *mcpQueryAllTransactionCategoriesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { + uid := user.Uid categories, err := services.GetTransactionCategoryService().GetAllCategoriesByUid(c, uid, 0, -1) if err != nil { log.Errorf(c, "[query_all_transaction_categories.Handle] failed to get categories for user \"uid:%d\", because %s", uid, err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } structuredResponse, response, err := h.createNewMCPQueryAllTransactionCategoriesResponse(c, categories) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil diff --git a/pkg/mcp/query_all_transaction_tags.go b/pkg/mcp/query_all_transaction_tags.go index 08cf1e07..ee46881d 100644 --- a/pkg/mcp/query_all_transaction_tags.go +++ b/pkg/mcp/query_all_transaction_tags.go @@ -5,8 +5,8 @@ import ( "reflect" "github.com/mayswind/ezbookkeeping/pkg/core" - "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/settings" ) @@ -40,13 +40,13 @@ func (h *mcpQueryAllTransactionTagsToolHandler) OutputType() reflect.Type { } // Handle processes the MCP call tool request and returns the response -func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { - uid := c.GetCurrentUid() +func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { + uid := user.Uid tags, err := services.GetTransactionTagService().GetAllTagsByUid(c, uid) if err != nil { log.Errorf(c, "[query_all_transaction_tags.Handle] failed to get tags for user \"uid:%d\", because %s", uid, err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } tagNames := make([]string, len(tags)) @@ -62,7 +62,7 @@ func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callT content, err := json.Marshal(response) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return response, []*MCPTextContent{ diff --git a/pkg/mcp/query_latest_exchange_rates.go b/pkg/mcp/query_latest_exchange_rates.go index 8f1099dc..a90fe32a 100644 --- a/pkg/mcp/query_latest_exchange_rates.go +++ b/pkg/mcp/query_latest_exchange_rates.go @@ -57,7 +57,7 @@ func (h *mcpQueryLatestExchangeRatesToolHandler) OutputType() reflect.Type { } // Handle processes the MCP call tool request and returns the response -func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { +func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { var exchangeRatesRequest MCPQueryExchangeRatesRequest if callToolReq.Arguments != nil { @@ -74,16 +74,16 @@ func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, call return nil, nil, errs.ErrInvalidExchangeRatesDataSource } - exchangeRateResponse, err := dataSource.GetLatestExchangeRates(c, c.GetCurrentUid(), currentConfig) + exchangeRateResponse, err := dataSource.GetLatestExchangeRates(c, user.Uid, currentConfig) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } structuredResponse, response, err := h.createNewMCPQueryExchangeRatesResponse(exchangeRatesRequest.Currencies, exchangeRateResponse) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil diff --git a/pkg/mcp/query_transactions.go b/pkg/mcp/query_transactions.go index 7959e7f3..38a36742 100644 --- a/pkg/mcp/query_transactions.go +++ b/pkg/mcp/query_transactions.go @@ -74,7 +74,7 @@ func (h *mcpQueryTransactionsToolHandler) OutputType() reflect.Type { } // Handle processes the MCP call tool request and returns the response -func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { +func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) { var queryTransactionsRequest MCPQueryTransactionsRequest if callToolReq.Arguments != nil { @@ -85,7 +85,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq return nil, nil, errs.ErrIncompleteOrIncorrectSubmission } - uid := c.GetCurrentUid() + uid := user.Uid maxTime, err := utils.ParseFromLongDateTimeWithTimezoneRFC3339Format(queryTransactionsRequest.EndTime) if err != nil { @@ -123,7 +123,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq if err != nil { log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts) @@ -141,7 +141,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq if err != nil { log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories) @@ -159,14 +159,14 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq if err != nil { log.Errorf(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } transactions, err := services.GetTransactionService().GetTransactionsByMaxTime(c, uid, maxTransactionTime, minTransactionTime, transactionType, filterCategoryIds, filterAccountIds, nil, false, models.TRANSACTION_TAG_FILTER_HAS_ANY, "", queryTransactionsRequest.Keyword, queryTransactionsRequest.Page, queryTransactionsRequest.Count, false, true) structuredResponse, response, err := h.createNewMCPQueryTransactionsResponse(c, &queryTransactionsRequest, transactions, totalCount, services.GetAccountService().GetAccountMapByList(allAccounts), services.GetTransactionCategoryService().GetCategoryMapByList(allCategories)) if err != nil { - return nil, nil, errs.Or(err, errs.ErrOperationFailed) + return nil, nil, err } return structuredResponse, response, nil