feature restriction supports mcp

This commit is contained in:
MaysWind
2025-07-07 01:21:09 +08:00
parent 6215f489f2
commit 5cb129311a
13 changed files with 118 additions and 56 deletions
+1
View File
@@ -244,6 +244,7 @@ max_user_avatar_size = 1048576
# 10: Export Transactions # 10: Export Transactions
# 11: Clear All Data # 11: Clear All Data
# 12: Sync Application Settings # 12: Sync Application Settings
# 13: MCP (Model Context Protocol) Access
default_feature_restrictions = default_feature_restrictions =
[data] [data]
+63 -2
View File
@@ -7,6 +7,7 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/mcp" "github.com/mayswind/ezbookkeeping/pkg/mcp"
"github.com/mayswind/ezbookkeeping/pkg/services" "github.com/mayswind/ezbookkeeping/pkg/services"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
@@ -52,6 +53,18 @@ func (a *ModelContextProtocolAPI) InitializeHandler(c *core.WebContext, jsonRPCR
return nil, errs.ErrIncompleteOrIncorrectSubmission return nil, errs.ErrIncompleteOrIncorrectSubmission
} }
uid := c.GetCurrentUid()
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Warnf(c, "[model_context_protocols.InitializeHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, errs.ErrNotPermittedToPerformThisAction
}
protocolVersion := mcp.MCPProtocolVersion(initRequest.ProtocolVersion) protocolVersion := mcp.MCPProtocolVersion(initRequest.ProtocolVersion)
_, exists := mcp.SupportedMCPVersion[protocolVersion] _, exists := mcp.SupportedMCPVersion[protocolVersion]
@@ -78,6 +91,18 @@ func (a *ModelContextProtocolAPI) InitializeHandler(c *core.WebContext, jsonRPCR
// ListResourcesHandler returns the list of resources for model context protocol // ListResourcesHandler returns the list of resources for model context protocol
func (a *ModelContextProtocolAPI) ListResourcesHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { func (a *ModelContextProtocolAPI) ListResourcesHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) {
uid := c.GetCurrentUid()
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Warnf(c, "[model_context_protocols.ListResourcesHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, errs.ErrNotPermittedToPerformThisAction
}
listResourcesResp := mcp.MCPListResourcesResponse{ listResourcesResp := mcp.MCPListResourcesResponse{
Resources: make([]*mcp.MCPResource, 0), Resources: make([]*mcp.MCPResource, 0),
} }
@@ -97,11 +122,35 @@ func (a *ModelContextProtocolAPI) ReadResourceHandler(c *core.WebContext, jsonRP
return nil, errs.ErrIncompleteOrIncorrectSubmission return nil, errs.ErrIncompleteOrIncorrectSubmission
} }
uid := c.GetCurrentUid()
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Warnf(c, "[model_context_protocols.ReadResourceHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, errs.ErrNotPermittedToPerformThisAction
}
return nil, errs.ErrApiNotFound return nil, errs.ErrApiNotFound
} }
// ListToolsHandler returns the list of tools for model context protocol // ListToolsHandler returns the list of tools for model context protocol
func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) {
uid := c.GetCurrentUid()
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Warnf(c, "[model_context_protocols.ListToolsHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, errs.ErrNotPermittedToPerformThisAction
}
mcpVersion := a.getMCPVersion(c) mcpVersion := a.getMCPVersion(c)
toolsInfo := mcp.Container.GetMCPTools() toolsInfo := mcp.Container.GetMCPTools()
finalToolsInfos := make([]*mcp.MCPTool, len(toolsInfo)) finalToolsInfos := make([]*mcp.MCPTool, len(toolsInfo))
@@ -128,6 +177,18 @@ func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRe
// CallToolHandler returns the result of calling a specific tool for model context protocol // CallToolHandler returns the result of calling a specific tool for model context protocol
func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) {
uid := c.GetCurrentUid()
user, err := a.users.GetUserById(c, uid)
if err != nil {
log.Warnf(c, "[model_context_protocols.CallToolHandler] failed to get user \"uid:%d\" info, because %s", uid, err.Error())
return nil, errs.ErrUserNotFound
}
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, errs.ErrNotPermittedToPerformThisAction
}
var callToolReq mcp.MCPCallToolRequest var callToolReq mcp.MCPCallToolRequest
if jsonRPCRequest.Params != nil { if jsonRPCRequest.Params != nil {
@@ -138,10 +199,10 @@ func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCReq
return nil, errs.ErrIncompleteOrIncorrectSubmission return nil, errs.ErrIncompleteOrIncorrectSubmission
} }
result, err := mcp.Container.HandleTool(c, &callToolReq, a.CurrentConfig(), a) result, err := mcp.Container.HandleTool(c, &callToolReq, user, a.CurrentConfig(), a)
if err != nil { if err != nil {
return nil, err return nil, errs.Or(err, errs.ErrOperationFailed)
} }
return result, nil return result, nil
+4
View File
@@ -99,6 +99,10 @@ func (a *TokensApi) TokenGenerateMCPHandler(c *core.WebContext) (any, *errs.Erro
return nil, errs.ErrUserNotFound return nil, errs.ErrUserNotFound
} }
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return false, errs.ErrNotPermittedToPerformThisAction
}
if !a.users.IsPasswordEqualsUserPassword(generateMCPTokenReq.Password, user) { if !a.users.IsPasswordEqualsUserPassword(generateMCPTokenReq.Password, user) {
return nil, errs.ErrUserPasswordWrong return nil, errs.ErrUserPasswordWrong
} }
+4
View File
@@ -422,6 +422,10 @@ func (l *UserDataCli) CreateNewUserToken(c *core.CliContext, username string, to
var tokenRecord *models.TokenRecord var tokenRecord *models.TokenRecord
if tokenType == "mcp" { if tokenType == "mcp" {
if user.FeatureRestriction.Contains(core.USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS) {
return nil, "", errs.ErrNotPermittedToPerformThisAction
}
token, tokenRecord, err = l.tokens.CreateMCPTokenViaCli(c, user) token, tokenRecord, err = l.tokens.CreateMCPTokenViaCli(c, user)
} else if tokenType == "normal" { } else if tokenType == "normal" {
token, tokenRecord, err = l.tokens.CreateTokenViaCli(c, user) token, tokenRecord, err = l.tokens.CreateTokenViaCli(c, user)
+4 -1
View File
@@ -88,10 +88,11 @@ const (
USER_FEATURE_RESTRICTION_TYPE_EXPORT_TRANSACTION UserFeatureRestrictionType = 10 USER_FEATURE_RESTRICTION_TYPE_EXPORT_TRANSACTION UserFeatureRestrictionType = 10
USER_FEATURE_RESTRICTION_TYPE_CLEAR_ALL_DATA UserFeatureRestrictionType = 11 USER_FEATURE_RESTRICTION_TYPE_CLEAR_ALL_DATA UserFeatureRestrictionType = 11
USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS UserFeatureRestrictionType = 12 USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS UserFeatureRestrictionType = 12
USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS UserFeatureRestrictionType = 13
) )
const userFeatureRestrictionTypeMinValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_UPDATE_PASSWORD const userFeatureRestrictionTypeMinValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_UPDATE_PASSWORD
const userFeatureRestrictionTypeMaxValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS const userFeatureRestrictionTypeMaxValue UserFeatureRestrictionType = USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS
// String returns a textual representation of the restriction type of user features // String returns a textual representation of the restriction type of user features
func (t UserFeatureRestrictionType) String() string { func (t UserFeatureRestrictionType) String() string {
@@ -120,6 +121,8 @@ func (t UserFeatureRestrictionType) String() string {
return "Clear All Data" return "Clear All Data"
case USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS: case USER_FEATURE_RESTRICTION_TYPE_SYNC_APPLICATION_SETTINGS:
return "Sync Application Settings" return "Sync Application Settings"
case USER_FEATURE_RESTRICTION_TYPE_MCP_ACCESS:
return "MCP (Model Context Protocol) Access"
default: default:
return fmt.Sprintf("Invalid(%d)", int(t)) return fmt.Sprintf("Invalid(%d)", int(t))
} }
+8 -18
View File
@@ -63,7 +63,7 @@ func (h *mcpAddTransactionToolHandler) OutputType() reflect.Type {
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
var addTransactionRequest MCPAddTransactionRequest var addTransactionRequest MCPAddTransactionRequest
if callToolReq.Arguments != nil { if callToolReq.Arguments != nil {
@@ -84,22 +84,12 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
return nil, nil, errs.ErrTransactionHasTooManyTags return nil, nil, errs.ErrTransactionHasTooManyTags
} }
uid := c.GetCurrentUid() uid := user.Uid
user, err := services.GetUserService().GetUserById(c, uid)
if err != nil {
if !errs.IsCustomError(err) {
log.Errorf(c, "[add_transaction.Handle] failed to get user, because %s", err.Error())
}
return nil, nil, errs.ErrUserNotFound
}
allAccounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid) allAccounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid)
if err != nil { if err != nil {
log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error()) log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts) accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts)
@@ -128,7 +118,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
if err != nil { if err != nil {
log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error()) log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories) categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories)
@@ -146,7 +136,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
if err != nil { if err != nil {
log.Warnf(c, "[add_transaction.Handle] get transaction tag ids error, because %s", err.Error()) log.Warnf(c, "[add_transaction.Handle] get transaction tag ids error, because %s", err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
tagMaps := services.GetTransactionTagService().GetTagNameMapByList(allTags) tagMaps := services.GetTransactionTagService().GetTagNameMapByList(allTags)
@@ -173,7 +163,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
if err != nil { if err != nil {
log.Errorf(c, "[add_transaction.Handle] failed to create transaction \"id:%d\" for user \"uid:%d\", because %s", transaction.TransactionId, uid, err.Error()) log.Errorf(c, "[add_transaction.Handle] failed to create transaction \"id:%d\" for user \"uid:%d\", because %s", transaction.TransactionId, uid, err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
log.Infof(c, "[add_transaction.Handle] user \"uid:%d\" has created a new transaction \"id:%d\" successfully", uid, transaction.TransactionId) log.Infof(c, "[add_transaction.Handle] user \"uid:%d\" has created a new transaction \"id:%d\" successfully", uid, transaction.TransactionId)
@@ -193,7 +183,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, false) structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, false)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil
@@ -215,7 +205,7 @@ func (h *mcpAddTransactionToolHandler) Handle(c *core.WebContext, callToolReq *M
structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, true) structuredResponse, response, err := h.createNewMCPAddTransactionResponse(c, transaction, newAccounts, true)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil
+2 -2
View File
@@ -4,7 +4,7 @@ import (
"reflect" "reflect"
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/services" "github.com/mayswind/ezbookkeeping/pkg/services"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
) )
@@ -33,5 +33,5 @@ type MCPToolHandler[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPRe
OutputType() reflect.Type OutputType() reflect.Type
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
Handle(*core.WebContext, *MCPCallToolRequest, *settings.Config, MCPAvailableServices) (any, []*T, *errs.Error) Handle(*core.WebContext, *MCPCallToolRequest, *models.User, *settings.Config, MCPAvailableServices) (any, []*T, error)
} }
+9 -8
View File
@@ -6,6 +6,7 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
) )
@@ -34,25 +35,25 @@ func (c *MCPContainer) GetMCPTools() []*MCPTool {
} }
// HandleTool returns the result of the MCP tool handler based on the tool name // HandleTool returns the result of the MCP tool handler based on the tool name
func (c *MCPContainer) HandleTool(ctx *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, *errs.Error) { func (c *MCPContainer) HandleTool(ctx *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, error) {
if handler, exists := c.mcpTextContentTools.Get(callToolReq.Name); exists { if handler, exists := c.mcpTextContentTools.Get(callToolReq.Name); exists {
return handleTool(ctx, handler, currentConfig, services, callToolReq) return handleTool(ctx, handler, currentConfig, services, callToolReq, user)
} }
if handler, exists := c.mcpImageContentTools.Get(callToolReq.Name); exists { if handler, exists := c.mcpImageContentTools.Get(callToolReq.Name); exists {
return handleTool(ctx, handler, currentConfig, services, callToolReq) return handleTool(ctx, handler, currentConfig, services, callToolReq, user)
} }
if handler, exists := c.mcpAudioContentTools.Get(callToolReq.Name); exists { if handler, exists := c.mcpAudioContentTools.Get(callToolReq.Name); exists {
return handleTool(ctx, handler, currentConfig, services, callToolReq) return handleTool(ctx, handler, currentConfig, services, callToolReq, user)
} }
if handler, exists := c.mcpResourceLinkTools.Get(callToolReq.Name); exists { if handler, exists := c.mcpResourceLinkTools.Get(callToolReq.Name); exists {
return handleTool(ctx, handler, currentConfig, services, callToolReq) return handleTool(ctx, handler, currentConfig, services, callToolReq, user)
} }
if handler, exists := c.mcpEmbeddedResourceTools.Get(callToolReq.Name); exists { if handler, exists := c.mcpEmbeddedResourceTools.Get(callToolReq.Name); exists {
return handleTool(ctx, handler, currentConfig, services, callToolReq) return handleTool(ctx, handler, currentConfig, services, callToolReq, user)
} }
return nil, errs.ErrApiNotFound return nil, errs.ErrApiNotFound
@@ -109,8 +110,8 @@ func registerMCPToolHandler[T MCPTextContent | MCPImageContent | MCPAudioContent
c.mcpTools = append(c.mcpTools, createNewMCPToolInfo(handler.Name(), handler)) c.mcpTools = append(c.mcpTools, createNewMCPToolInfo(handler.Name(), handler))
} }
func handleTool[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource](ctx *core.WebContext, handler MCPToolHandler[T], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { func handleTool[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource](ctx *core.WebContext, handler MCPToolHandler[T], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest, user *models.User) (any, error) {
structuredResponse, result, err := handler.Handle(ctx, callToolReq, currentConfig, services) structuredResponse, result, err := handler.Handle(ctx, callToolReq, user, currentConfig, services)
if err != nil { if err != nil {
return nil, errs.Or(err, errs.ErrOperationFailed) return nil, errs.Or(err, errs.ErrOperationFailed)
+4 -5
View File
@@ -5,7 +5,6 @@ import (
"reflect" "reflect"
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
@@ -49,19 +48,19 @@ func (h *mcpQueryAllAccountsToolHandler) OutputType() reflect.Type {
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpQueryAllAccountsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpQueryAllAccountsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
uid := c.GetCurrentUid() uid := user.Uid
accounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid) accounts, err := services.GetAccountService().GetAllAccountsByUid(c, uid)
if err != nil { if err != nil {
log.Errorf(c, "[query_all_accounts.Handle] failed to get all accounts for user \"uid:%d\", because %s", uid, err.Error()) log.Errorf(c, "[query_all_accounts.Handle] failed to get all accounts for user \"uid:%d\", because %s", uid, err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
structuredResponse, response, err := h.createNewMCPQueryAllAccountsResponse(c, accounts) structuredResponse, response, err := h.createNewMCPQueryAllAccountsResponse(c, accounts)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil
+4 -5
View File
@@ -5,7 +5,6 @@ import (
"reflect" "reflect"
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/models" "github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
@@ -43,19 +42,19 @@ func (h *mcpQueryAllTransactionCategoriesToolHandler) OutputType() reflect.Type
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpQueryAllTransactionCategoriesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpQueryAllTransactionCategoriesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
uid := c.GetCurrentUid() uid := user.Uid
categories, err := services.GetTransactionCategoryService().GetAllCategoriesByUid(c, uid, 0, -1) categories, err := services.GetTransactionCategoryService().GetAllCategoriesByUid(c, uid, 0, -1)
if err != nil { if err != nil {
log.Errorf(c, "[query_all_transaction_categories.Handle] failed to get categories for user \"uid:%d\", because %s", uid, err.Error()) log.Errorf(c, "[query_all_transaction_categories.Handle] failed to get categories for user \"uid:%d\", because %s", uid, err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
structuredResponse, response, err := h.createNewMCPQueryAllTransactionCategoriesResponse(c, categories) structuredResponse, response, err := h.createNewMCPQueryAllTransactionCategoriesResponse(c, categories)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil
+5 -5
View File
@@ -5,8 +5,8 @@ import (
"reflect" "reflect"
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/settings" "github.com/mayswind/ezbookkeeping/pkg/settings"
) )
@@ -40,13 +40,13 @@ func (h *mcpQueryAllTransactionTagsToolHandler) OutputType() reflect.Type {
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
uid := c.GetCurrentUid() uid := user.Uid
tags, err := services.GetTransactionTagService().GetAllTagsByUid(c, uid) tags, err := services.GetTransactionTagService().GetAllTagsByUid(c, uid)
if err != nil { if err != nil {
log.Errorf(c, "[query_all_transaction_tags.Handle] failed to get tags for user \"uid:%d\", because %s", uid, err.Error()) log.Errorf(c, "[query_all_transaction_tags.Handle] failed to get tags for user \"uid:%d\", because %s", uid, err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
tagNames := make([]string, len(tags)) tagNames := make([]string, len(tags))
@@ -62,7 +62,7 @@ func (h *mcpQueryAllTransactionTagsToolHandler) Handle(c *core.WebContext, callT
content, err := json.Marshal(response) content, err := json.Marshal(response)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return response, []*MCPTextContent{ return response, []*MCPTextContent{
+4 -4
View File
@@ -57,7 +57,7 @@ func (h *mcpQueryLatestExchangeRatesToolHandler) OutputType() reflect.Type {
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
var exchangeRatesRequest MCPQueryExchangeRatesRequest var exchangeRatesRequest MCPQueryExchangeRatesRequest
if callToolReq.Arguments != nil { if callToolReq.Arguments != nil {
@@ -74,16 +74,16 @@ func (h *mcpQueryLatestExchangeRatesToolHandler) Handle(c *core.WebContext, call
return nil, nil, errs.ErrInvalidExchangeRatesDataSource return nil, nil, errs.ErrInvalidExchangeRatesDataSource
} }
exchangeRateResponse, err := dataSource.GetLatestExchangeRates(c, c.GetCurrentUid(), currentConfig) exchangeRateResponse, err := dataSource.GetLatestExchangeRates(c, user.Uid, currentConfig)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
structuredResponse, response, err := h.createNewMCPQueryExchangeRatesResponse(exchangeRatesRequest.Currencies, exchangeRateResponse) structuredResponse, response, err := h.createNewMCPQueryExchangeRatesResponse(exchangeRatesRequest.Currencies, exchangeRateResponse)
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil
+6 -6
View File
@@ -74,7 +74,7 @@ func (h *mcpQueryTransactionsToolHandler) OutputType() reflect.Type {
} }
// Handle processes the MCP call tool request and returns the response // Handle processes the MCP call tool request and returns the response
func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, *errs.Error) { func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, user *models.User, currentConfig *settings.Config, services MCPAvailableServices) (any, []*MCPTextContent, error) {
var queryTransactionsRequest MCPQueryTransactionsRequest var queryTransactionsRequest MCPQueryTransactionsRequest
if callToolReq.Arguments != nil { if callToolReq.Arguments != nil {
@@ -85,7 +85,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq
return nil, nil, errs.ErrIncompleteOrIncorrectSubmission return nil, nil, errs.ErrIncompleteOrIncorrectSubmission
} }
uid := c.GetCurrentUid() uid := user.Uid
maxTime, err := utils.ParseFromLongDateTimeWithTimezoneRFC3339Format(queryTransactionsRequest.EndTime) maxTime, err := utils.ParseFromLongDateTimeWithTimezoneRFC3339Format(queryTransactionsRequest.EndTime)
if err != nil { if err != nil {
@@ -123,7 +123,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq
if err != nil { if err != nil {
log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error()) log.Warnf(c, "[add_transaction.Handle] get account error, because %s", err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts) accountsMap := services.GetAccountService().GetVisibleAccountNameMapByList(allAccounts)
@@ -141,7 +141,7 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq
if err != nil { if err != nil {
log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error()) log.Warnf(c, "[add_transaction.Handle] get transaction category error, because %s", err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories) categoriesMap := services.GetTransactionCategoryService().GetVisibleCategoryNameMapByList(allCategories)
@@ -159,14 +159,14 @@ func (h *mcpQueryTransactionsToolHandler) Handle(c *core.WebContext, callToolReq
if err != nil { if err != nil {
log.Errorf(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) log.Errorf(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error())
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
transactions, err := services.GetTransactionService().GetTransactionsByMaxTime(c, uid, maxTransactionTime, minTransactionTime, transactionType, filterCategoryIds, filterAccountIds, nil, false, models.TRANSACTION_TAG_FILTER_HAS_ANY, "", queryTransactionsRequest.Keyword, queryTransactionsRequest.Page, queryTransactionsRequest.Count, false, true) transactions, err := services.GetTransactionService().GetTransactionsByMaxTime(c, uid, maxTransactionTime, minTransactionTime, transactionType, filterCategoryIds, filterAccountIds, nil, false, models.TRANSACTION_TAG_FILTER_HAS_ANY, "", queryTransactionsRequest.Keyword, queryTransactionsRequest.Page, queryTransactionsRequest.Count, false, true)
structuredResponse, response, err := h.createNewMCPQueryTransactionsResponse(c, &queryTransactionsRequest, transactions, totalCount, services.GetAccountService().GetAccountMapByList(allAccounts), services.GetTransactionCategoryService().GetCategoryMapByList(allCategories)) structuredResponse, response, err := h.createNewMCPQueryTransactionsResponse(c, &queryTransactionsRequest, transactions, totalCount, services.GetAccountService().GetAccountMapByList(allAccounts), services.GetTransactionCategoryService().GetCategoryMapByList(allCategories))
if err != nil { if err != nil {
return nil, nil, errs.Or(err, errs.ErrOperationFailed) return nil, nil, err
} }
return structuredResponse, response, nil return structuredResponse, response, nil