From 8dce0f2d6a9d9a6cf70d65c6ff824a92e6e42542 Mon Sep 17 00:00:00 2001 From: MaysWind Date: Sun, 6 Jul 2025 03:02:19 +0800 Subject: [PATCH] add mcp (Model Context Protocol) support --- cmd/webserver.go | 60 +++++++ conf/ezbookkeeping.ini | 7 + go.mod | 1 + go.sum | 2 + pkg/api/model_context_protocols.go | 161 ++++++++++++++++++ pkg/core/handler.go | 3 + pkg/core/ip_pattern.go | 177 ++++++++++++++++++++ pkg/core/ip_pattern_test.go | 135 +++++++++++++++ pkg/core/json_rpc.go | 95 +++++++++++ pkg/errs/global.go | 1 + pkg/errs/setting.go | 1 + pkg/mcp/all_handlers.go | 12 ++ pkg/mcp/exchange_rate.go | 124 ++++++++++++++ pkg/mcp/handler.go | 187 +++++++++++++++++++++ pkg/mcp/model_context_protocol.go | 218 +++++++++++++++++++++++++ pkg/middlewares/mcp_server_ip_limit.go | 27 +++ pkg/settings/setting.go | 39 +++++ pkg/utils/api.go | 41 +++++ pkg/utils/datetimes.go | 41 ++++- pkg/utils/datetimes_test.go | 28 ++++ src/locales/de.json | 3 +- src/locales/en.json | 3 +- src/locales/es.json | 3 +- src/locales/it.json | 3 +- src/locales/ja.json | 3 +- src/locales/pt_BR.json | 3 +- src/locales/ru.json | 3 +- src/locales/uk.json | 3 +- src/locales/vi.json | 3 +- src/locales/zh_Hans.json | 3 +- src/locales/zh_Hant.json | 3 +- third-party-dependencies.json | 6 + 32 files changed, 1379 insertions(+), 20 deletions(-) create mode 100644 pkg/api/model_context_protocols.go create mode 100644 pkg/core/ip_pattern.go create mode 100644 pkg/core/ip_pattern_test.go create mode 100644 pkg/core/json_rpc.go create mode 100644 pkg/mcp/all_handlers.go create mode 100644 pkg/mcp/exchange_rate.go create mode 100644 pkg/mcp/handler.go create mode 100644 pkg/mcp/model_context_protocol.go create mode 100644 pkg/middlewares/mcp_server_ip_limit.go diff --git a/cmd/webserver.go b/cmd/webserver.go index 97d6e2d0..a14ce483 100644 --- a/cmd/webserver.go +++ b/cmd/webserver.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "net/http" "path/filepath" "time" @@ -212,6 +213,27 @@ func startWebServer(c *core.CliContext) error { qrCodeRoute.GET("/mobile_url.png", bindCachedImage(api.QrCodes.MobileUrlQrCodeHandler, qrCodeCacheStore)) } + if config.EnableMCPServer { + mcpRoute := router.Group("/mcp") + mcpRoute.Use(bindMiddleware(middlewares.RequestId(config))) + mcpRoute.Use(bindMiddleware(middlewares.RequestLog)) + mcpRoute.Use(bindMiddleware(middlewares.MCPServerIpLimit(config))) + mcpRoute.Use(bindMiddleware(middlewares.JWTAuthorization)) + { + mcpRoute.POST("", bindJSONRPCApi(map[string]core.JSONRPCApiHandlerFunc{ + "initialize": api.ModelContextProtocols.InitializeHandler, + "resources/list": api.ModelContextProtocols.ListResourcesHandler, + "resources/read": api.ModelContextProtocols.ReadResourceHandler, + "tools/list": api.ModelContextProtocols.ListToolsHandler, + "tools/call": api.ModelContextProtocols.CallToolHandler, + "ping": api.ModelContextProtocols.PingHandler, + }, map[string]int{ + "notifications/initialized": http.StatusAccepted, + })) + mcpRoute.GET("", bindApi(api.Default.MethodNotAllowed)) + } + } + apiRoute := router.Group("/api") apiRoute.Use(bindMiddleware(middlewares.RequestId(config))) @@ -432,6 +454,44 @@ func bindApiWithTokenUpdate(fn core.ApiHandlerFunc, config *settings.Config) gin } } +func bindJSONRPCApi(fns map[string]core.JSONRPCApiHandlerFunc, skipMethods map[string]int) gin.HandlerFunc { + return func(ginCtx *gin.Context) { + c := core.WrapWebContext(ginCtx) + + var jsonRPCRequest core.JSONRPCRequest + reqErr := c.ShouldBindBodyWithJSON(&jsonRPCRequest) + + if reqErr != nil { + utils.PrintJSONRPCErrorResult(c, nil, errs.NewIncompleteOrIncorrectSubmissionError(reqErr)) + return + } + + if skipMethods != nil { + httpStatusCode, exists := skipMethods[jsonRPCRequest.Method] + + if exists { + c.AbortWithStatus(httpStatusCode) + return + } + } + + fn, exists := fns[jsonRPCRequest.Method] + + if !exists { + utils.PrintJSONRPCErrorResult(c, &jsonRPCRequest, errs.ErrApiNotFound) + return + } + + result, err := fn(c, &jsonRPCRequest) + + if err != nil { + utils.PrintJSONRPCErrorResult(c, &jsonRPCRequest, err) + } else { + utils.PrintJSONRPCSuccessResult(c, &jsonRPCRequest, result) + } + } +} + func bindEventStreamApi(fn core.EventStreamApiHandlerFunc) gin.HandlerFunc { return func(ginCtx *gin.Context) { c := core.WrapWebContext(ginCtx) diff --git a/conf/ezbookkeeping.ini b/conf/ezbookkeeping.ini index ae734f8e..7408e690 100644 --- a/conf/ezbookkeeping.ini +++ b/conf/ezbookkeeping.ini @@ -37,6 +37,13 @@ enable_gzip = false # Set to true to log each request and execution time log_request = true +[mcp] +# Set to true to enable MCP (Model Context Protocol) server (via http / https web server) for AI/LLM access +enable_mcp = false + +# MCP server allowed remote IPs, a comma-separated list of allowed remote IPs (asterisk * for any addresses, e.g. 192.168.1.* means any IPs in the 192.168.1.x subnet), leave blank to allow all remote IPs +mcp_allowed_remote_ips = + [database] # Either "mysql", "postgres" or "sqlite3" type = sqlite3 diff --git a/go.mod b/go.mod index 51cb9c38..4bf9bb8c 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/gomodule/redigo v1.9.2 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jonboulle/clockwork v0.5.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index b6173c21..4ec754af 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= diff --git a/pkg/api/model_context_protocols.go b/pkg/api/model_context_protocols.go new file mode 100644 index 00000000..4925026e --- /dev/null +++ b/pkg/api/model_context_protocols.go @@ -0,0 +1,161 @@ +package api + +import ( + "encoding/json" + + "github.com/gin-gonic/gin" + + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/mcp" + "github.com/mayswind/ezbookkeeping/pkg/services" + "github.com/mayswind/ezbookkeeping/pkg/settings" +) + +const mcpServerName = "ezBookkeeping-mcp" + +// ModelContextProtocolAPI represents model context protocol api +type ModelContextProtocolAPI struct { + ApiUsingConfig + transactions *services.TransactionService + transactionCategories *services.TransactionCategoryService + transactionTags *services.TransactionTagService + accounts *services.AccountService + users *services.UserService + userCustomExchangeRates *services.UserCustomExchangeRatesService +} + +// Initialize a model context protocol api singleton instance +var ( + ModelContextProtocols = &ModelContextProtocolAPI{ + ApiUsingConfig: ApiUsingConfig{ + container: settings.Container, + }, + transactions: services.Transactions, + transactionCategories: services.TransactionCategories, + transactionTags: services.TransactionTags, + accounts: services.Accounts, + users: services.Users, + userCustomExchangeRates: services.UserCustomExchangeRates, + } +) + +// InitializeHandler returns the initialize response for model context protocol +func (a *ModelContextProtocolAPI) InitializeHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + var initRequest mcp.MCPInitializeRequest + + if jsonRPCRequest.Params != nil { + if err := json.Unmarshal(jsonRPCRequest.Params, &initRequest); err != nil { + return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) + } + } else { + return nil, errs.ErrIncompleteOrIncorrectSubmission + } + + protocolVersion := mcp.MCPProtocolVersion(initRequest.ProtocolVersion) + _, exists := mcp.SupportedMCPVersion[protocolVersion] + + if !exists { + protocolVersion = mcp.LatestSupportedMCPVersion + } + + initResp := mcp.MCPInitializeResponse{ + ProtocolVersion: string(protocolVersion), + Capabilities: &mcp.MCPCapabilities{ + Tools: &mcp.MCPToolCapabilities{ + ListChanged: false, + }, + }, + ServerInfo: &mcp.MCPImplementation{ + Name: mcpServerName, + Title: a.CurrentConfig().AppName, + Version: settings.Version, + }, + } + + return initResp, nil +} + +// ListResourcesHandler returns the list of resources for model context protocol +func (a *ModelContextProtocolAPI) ListResourcesHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + listResourcesResp := mcp.MCPListResourcesResponse{ + Resources: make([]*mcp.MCPResource, 0), + } + + return listResourcesResp, nil +} + +// ReadResourceHandler returns the resource details for a specific resource in model context protocol +func (a *ModelContextProtocolAPI) ReadResourceHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + var readResourceReq mcp.MCPReadResourceRequest + + if jsonRPCRequest.Params != nil { + if err := json.Unmarshal(jsonRPCRequest.Params, &readResourceReq); err != nil { + return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) + } + } else { + return nil, errs.ErrIncompleteOrIncorrectSubmission + } + + return nil, errs.ErrApiNotFound +} + +// ListToolsHandler returns the list of tools for model context protocol +func (a *ModelContextProtocolAPI) ListToolsHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + listToolsResp := mcp.MCPListToolsResponse{ + Tools: mcp.AllMCPToolInfos, + } + + return listToolsResp, nil +} + +// CallToolHandler returns the result of calling a specific tool for model context protocol +func (a *ModelContextProtocolAPI) CallToolHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + var callToolReq mcp.MCPCallToolRequest + + if jsonRPCRequest.Params != nil { + if err := json.Unmarshal(jsonRPCRequest.Params, &callToolReq); err != nil { + return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) + } + } else { + return nil, errs.ErrIncompleteOrIncorrectSubmission + } + + result, err := mcp.MCPToolHandle(c, &callToolReq, a.CurrentConfig(), a) + + if err != nil { + return nil, err + } + + return result, nil +} + +// PingHandler return the ping response for model context protocol +func (a *ModelContextProtocolAPI) PingHandler(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest) (any, *errs.Error) { + return gin.H{}, nil +} + +// GetTransactionService implements the MCPAvailableServices interface +func (a *ModelContextProtocolAPI) GetTransactionService() *services.TransactionService { + return a.transactions +} + +// GetUserCustomExchangeRatesService implements the MCPAvailableServices interface +func (a *ModelContextProtocolAPI) GetTransactionCategoryService() *services.TransactionCategoryService { + return a.transactionCategories +} + +// GetTransactionTagService implements the MCPAvailableServices interface +func (a *ModelContextProtocolAPI) GetTransactionTagService() *services.TransactionTagService { + return a.transactionTags +} + +// GetAccountService implements the MCPAvailableServices interface +func (a *ModelContextProtocolAPI) GetAccountService() *services.AccountService { + return a.accounts +} + +// GetUserCustomExchangeRatesService implements the MCPAvailableServices interface +func (a *ModelContextProtocolAPI) GetUserService() *services.UserService { + return a.users +} diff --git a/pkg/core/handler.go b/pkg/core/handler.go index 2f6a1358..4e363ef2 100644 --- a/pkg/core/handler.go +++ b/pkg/core/handler.go @@ -15,6 +15,9 @@ type MiddlewareHandlerFunc func(*WebContext) // ApiHandlerFunc represents the api handler function type ApiHandlerFunc func(*WebContext) (any, *errs.Error) +// JSONRPCApiHandlerFunc represents the api handler function +type JSONRPCApiHandlerFunc func(*WebContext, *JSONRPCRequest) (any, *errs.Error) + // EventStreamApiHandlerFunc represents the event stream api handler function type EventStreamApiHandlerFunc func(*WebContext) *errs.Error diff --git a/pkg/core/ip_pattern.go b/pkg/core/ip_pattern.go new file mode 100644 index 00000000..6c85b83f --- /dev/null +++ b/pkg/core/ip_pattern.go @@ -0,0 +1,177 @@ +package core + +import ( + "regexp" + "strconv" + "strings" + + "github.com/mayswind/ezbookkeeping/pkg/errs" +) + +// IPPattern represents a pattern for matching IP addresses, either IPv4 or IPv6 +type IPPattern struct { + Pattern string + regex *regexp.Regexp +} + +// Match returns if the given IP address matches the pattern +func (p *IPPattern) Match(ip string) bool { + if p.regex == nil { + return false + } + + return p.regex.MatchString(ip) +} + +// GobEncode returns the encoded data for this IP pattern +func (p *IPPattern) GobEncode() ([]byte, error) { + return []byte(p.Pattern), nil +} + +// GobDecode decodes the data into the IP pattern +func (p *IPPattern) GobDecode(data []byte) error { + pattern := string(data) + + if pattern == "" { + p.Pattern = "" + p.regex = nil + return nil + } + + newPattern, err := ParseIPPattern(pattern) + + if err != nil { + return err + } + + p.Pattern = newPattern.Pattern + p.regex = newPattern.regex + return nil +} + +// ParseIPPattern parses the given IP address pattern and returns an IPPattern object +func ParseIPPattern(ipPattern string) (*IPPattern, error) { + if ipPattern == "" { + return nil, nil + } + + hasDot := false + hasSemicolon := false + + for i := 0; i < len(ipPattern); i++ { + ch := rune(ipPattern[i]) + + if ch == '.' { // may be IPv4 + if hasSemicolon { + return nil, errs.ErrInvalidIpAddressPattern + } + hasDot = true + } else if ch == ':' { // may be IPv6 + if hasDot { + return nil, errs.ErrInvalidIpAddressPattern + } + hasSemicolon = true + } + } + + if hasDot { + return ParseIPv4Pattern(ipPattern) + } else if hasSemicolon { + return ParseIPv6Pattern(ipPattern) + } else { + return nil, errs.ErrInvalidIpAddressPattern + } +} + +// ParseIPv4Pattern parses the given IPv4 address pattern and returns an IPPattern object +func ParseIPv4Pattern(ipPattern string) (*IPPattern, error) { + items := strings.Split(ipPattern, ".") + + if len(items) != 4 { + return nil, errs.ErrInvalidIpAddressPattern + } + + regexBuilder := strings.Builder{} + regexBuilder.WriteRune('^') + + for i := 0; i < len(items); i++ { + item := strings.TrimSpace(items[i]) + + if item == "*" { + regexBuilder.WriteString("[0-9]{1,3}") + } else if item == "" { + return nil, errs.ErrInvalidIpAddressPattern + } else { + num, err := strconv.Atoi(item) + + if err != nil || num < 0 || num > 255 { + return nil, errs.ErrInvalidIpAddressPattern + } + + regexBuilder.WriteString(item) + } + + if i < len(items)-1 { + regexBuilder.WriteRune('\\') + regexBuilder.WriteRune('.') + } + } + + regexBuilder.WriteRune('$') + regex, err := regexp.Compile(regexBuilder.String()) + + if err != nil { + return nil, errs.ErrInvalidIpAddressPattern + } + + return &IPPattern{ + Pattern: ipPattern, + regex: regex, + }, nil +} + +// ParseIPv6Pattern parses the given IPv6 address pattern and returns an IPPattern object +func ParseIPv6Pattern(ipPattern string) (*IPPattern, error) { + items := strings.Split(ipPattern, ":") + + if len(items) < 2 || len(items) > 8 { + return nil, errs.ErrInvalidIpAddressPattern + } + + regexBuilder := strings.Builder{} + regexBuilder.WriteRune('^') + + for i := 0; i < len(items); i++ { + item := strings.TrimSpace(items[i]) + + if item == "*" { + regexBuilder.WriteString("[0-9a-fA-F]{1,4}") + } else if i < len(items)-1 && item == "" { + // Do Nothing + } else { + num, err := strconv.ParseInt(item, 16, 32) + + if err != nil || num < 0 || num > 0xFFFF { + return nil, errs.ErrInvalidIpAddressPattern + } + + regexBuilder.WriteString(item) + } + + if i < len(items)-1 { + regexBuilder.WriteRune(':') + } + } + + regexBuilder.WriteRune('$') + regex, err := regexp.Compile(regexBuilder.String()) + + if err != nil { + return nil, errs.ErrInvalidIpAddressPattern + } + + return &IPPattern{ + Pattern: ipPattern, + regex: regex, + }, nil +} diff --git a/pkg/core/ip_pattern_test.go b/pkg/core/ip_pattern_test.go new file mode 100644 index 00000000..36bb83a3 --- /dev/null +++ b/pkg/core/ip_pattern_test.go @@ -0,0 +1,135 @@ +package core + +import ( + "bytes" + "encoding/gob" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mayswind/ezbookkeeping/pkg/errs" +) + +func TestIPPattern_GobEncode(t *testing.T) { + pattern, err := ParseIPPattern("192.168.1.*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + + var buf bytes.Buffer + err = gob.NewEncoder(&buf).Encode(pattern) + assert.Nil(t, err) + + newPattern := &IPPattern{} + err = gob.NewDecoder(bytes.NewBuffer(buf.Bytes())).Decode(newPattern) + assert.Nil(t, err) + assert.NotNil(t, newPattern) + + assert.Equal(t, pattern.Pattern, newPattern.Pattern) + assert.Equal(t, pattern.regex.String(), newPattern.regex.String()) + + assert.True(t, newPattern.Match("192.168.1.1")) + assert.True(t, newPattern.Match("192.168.1.255")) +} + +func TestParseIPPattern(t *testing.T) { + pattern, err := ParseIPPattern("") + assert.Nil(t, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPPattern("invalid") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPPattern("192.1:2:3.4") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPPattern("0:0:0:0:0:0:1.2.3.4") // not support IPv6 with embedded IPv4 + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPPattern("192.168.1.*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("192.168.1.1")) + assert.True(t, pattern.Match("192.168.1.255")) + assert.False(t, pattern.Match("192.168.2.1")) + + pattern, err = ParseIPPattern("2001:db8::*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("2001:db8::1")) + assert.True(t, pattern.Match("2001:db8::ffff")) + assert.False(t, pattern.Match("2001:db9::1")) +} + +func TestParseIPv4Pattern(t *testing.T) { + pattern, err := ParseIPv4Pattern("192.168.1.1") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("192.168.1.1")) + assert.False(t, pattern.Match("192.168.1.2")) + + pattern, err = ParseIPv4Pattern("192.168.*.1") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("192.168.1.1")) + assert.True(t, pattern.Match("192.168.255.1")) + assert.False(t, pattern.Match("192.168.1.2")) + + pattern, err = ParseIPv4Pattern("*.*.*.*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("0.0.0.0")) + assert.True(t, pattern.Match("255.255.255.255")) + + pattern, err = ParseIPv4Pattern("256.256.256.256") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPv4Pattern("1.2.3") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPv4Pattern("1.2.3.4.5") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPv4Pattern("a.b.c.d") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) +} + +func TestParseIPv6Pattern(t *testing.T) { + pattern, err := ParseIPv6Pattern("2001:db8:85a3:8d3:1319:8a2e:370:7348") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("2001:db8:85a3:8d3:1319:8a2e:370:7348")) + assert.False(t, pattern.Match("2001:db8:85a3:8d3:1319:8a2e:370:7349")) + + pattern, err = ParseIPv6Pattern("2001:db8::*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("2001:db8::0")) + assert.True(t, pattern.Match("2001:db8::ffff")) + assert.False(t, pattern.Match("2001:db9::0")) + + pattern, err = ParseIPv6Pattern("::*") + assert.Nil(t, err) + assert.NotNil(t, pattern) + assert.True(t, pattern.Match("::1")) + assert.True(t, pattern.Match("::2")) + assert.False(t, pattern.Match(":1:1")) + + pattern, err = ParseIPv6Pattern("2001:db8:85a3:8d3:1319:8a2e:370:7348:extra") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPv6Pattern("g001:db8:85a3:8d3") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) + + pattern, err = ParseIPv6Pattern("2001:db8:") + assert.Equal(t, errs.ErrInvalidIpAddressPattern, err) + assert.Nil(t, pattern) +} diff --git a/pkg/core/json_rpc.go b/pkg/core/json_rpc.go new file mode 100644 index 00000000..ce554866 --- /dev/null +++ b/pkg/core/json_rpc.go @@ -0,0 +1,95 @@ +package core + +import "encoding/json" + +// JSONRPCVersion defines the version of JSON-RPC protocol +const JSONRPCVersion = "2.0" + +// JSONRPCRequest represents the JSON-RPC 2.0 request +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` + ID any `json:"id,omitempty"` +} + +// JSONRPCResponse represents the JSON-RPC 2.0 response +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + Result any `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` + ID any `json:"id,omitempty"` +} + +// JSONRPCError represents the JSON-RPC 2.0 error object +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +// JSONRPCParseError represents the "Parse error" in JSON-RPC 2.0 +var JSONRPCParseError = &JSONRPCError{ + Code: -32700, + Message: "Parse error", + Data: nil, +} + +// JSONRPCMethodNotFoundError represents the "Method not found" error in JSON-RPC 2.0 +var JSONRPCMethodNotFoundError = &JSONRPCError{ + Code: -32601, + Message: "Method not found", + Data: nil, +} + +// JSONRPCInvalidParamsError represents the "Invalid params" error in JSON-RPC 2.0 +var JSONRPCInvalidParamsError = &JSONRPCError{ + Code: -32602, + Message: "Invalid params", + Data: nil, +} + +// JSONRPCInternalError represents the "Internal error" in JSON-RPC 2.0 +var JSONRPCInternalError = &JSONRPCError{ + Code: -32603, + Message: "Internal error", + Data: nil, +} + +// NewJSONRPCResponse creates a new JSON-RPC response with the result +func NewJSONRPCResponse(id any, result any) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + Result: result, + Error: nil, + ID: id, + } +} + +// NewJSONRPCErrorResponse creates a new JSON-RPC error response +func NewJSONRPCErrorResponse(id any, err *JSONRPCError) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + Result: nil, + Error: &JSONRPCError{ + Code: err.Code, + Message: err.Message, + Data: nil, + }, + ID: id, + } +} + +// NewJSONRPCErrorResponseWithCause creates a new JSON-RPC error response +func NewJSONRPCErrorResponseWithCause(id any, err *JSONRPCError, cause string) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + Result: nil, + Error: &JSONRPCError{ + Code: err.Code, + Message: err.Message, + Data: cause, + }, + ID: id, + } +} diff --git a/pkg/errs/global.go b/pkg/errs/global.go index 4b0cd616..9d2690a4 100644 --- a/pkg/errs/global.go +++ b/pkg/errs/global.go @@ -27,6 +27,7 @@ var ( ErrExceedMaxUploadFileSize = NewNormalError(NormalSubcategoryGlobal, 17, http.StatusBadRequest, "uploaded file size exceeds the maximum allowed size") ErrFailureCountLimitReached = NewNormalError(NormalSubcategoryGlobal, 18, http.StatusBadRequest, "failure count exceeded maximum limit") ErrRepeatedRequest = NewNormalError(NormalSubcategoryGlobal, 19, http.StatusBadRequest, "repeated request") + ErrIPForbidden = NewNormalError(NormalSubcategoryGlobal, 20, http.StatusBadRequest, "ip address is forbidden to access this resource") ) // GetParameterInvalidMessage returns specific error message for invalid parameter error diff --git a/pkg/errs/setting.go b/pkg/errs/setting.go index 7dcc0c9d..cfda4b12 100644 --- a/pkg/errs/setting.go +++ b/pkg/errs/setting.go @@ -23,4 +23,5 @@ var ( ErrInvalidAmapSecurityVerificationMethod = NewSystemError(SystemSubcategorySetting, 16, http.StatusInternalServerError, "invalid amap security verification method") ErrInvalidPasswordResetTokenExpiredTime = NewSystemError(SystemSubcategorySetting, 17, http.StatusInternalServerError, "invalid password reset token expired time") ErrInvalidExchangeRatesDataSource = NewSystemError(SystemSubcategorySetting, 18, http.StatusInternalServerError, "invalid exchange rates data source") + ErrInvalidIpAddressPattern = NewSystemError(SystemSubcategorySetting, 19, http.StatusInternalServerError, "invalid ip address pattern") ) diff --git a/pkg/mcp/all_handlers.go b/pkg/mcp/all_handlers.go new file mode 100644 index 00000000..7ce653cc --- /dev/null +++ b/pkg/mcp/all_handlers.go @@ -0,0 +1,12 @@ +package mcp + +var mcpTextContentTools = map[string]MCPToolHandler[MCPTextContent]{ + "query_latest_exchange_rates": MCPQueryLatestExchangeRatesRequestToolHandler, +} + +var mcpImageContentTools = map[string]MCPToolHandler[MCPImageContent]{} +var mcpAudioContentTools = map[string]MCPToolHandler[MCPAudioContent]{} +var mcpResourceLinkTools = map[string]MCPToolHandler[MCPResourceLink]{} +var mcpEmbeddedResourceTools = map[string]MCPToolHandler[MCPEmbeddedResource]{} + +var AllMCPToolInfos = GetAllMCPToolInfos() diff --git a/pkg/mcp/exchange_rate.go b/pkg/mcp/exchange_rate.go new file mode 100644 index 00000000..5a506dd0 --- /dev/null +++ b/pkg/mcp/exchange_rate.go @@ -0,0 +1,124 @@ +package mcp + +import ( + "encoding/json" + "reflect" + "strings" + "time" + + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/exchangerates" + "github.com/mayswind/ezbookkeeping/pkg/models" + "github.com/mayswind/ezbookkeeping/pkg/settings" + "github.com/mayswind/ezbookkeeping/pkg/utils" +) + +// MCPQueryExchangeRatesRequest represents all parameters of the query exchange rates request +type MCPQueryExchangeRatesRequest struct { + Currencies string `json:"currencies" jsonschema:"required,description=Comma-separated list of currencies to query exchange rates for (e.g. USD,CNY,EUR)"` +} + +// MCPQueryExchangeRatesResponse represents the response structure for querying exchange rates +type MCPQueryExchangeRatesResponse struct { + BaseCurrency string `json:"base_currency" jsonschema_description:"Base currency code (e.g. USD)"` + UpdateTime string `json:"update_time" jsonschema_description:"Last update time of the exchange rates in RFC 3339 format (e.g. '2023-01-01T12:00:00Z')"` + Rates []*MCPQueryExchangeRateInfo `json:"rates" jsonschema_description:"Exchange rates for the specified currencies"` +} + +// MCPQueryExchangeRateInfo defines the structure of exchange rate information for a specific currency +type MCPQueryExchangeRateInfo struct { + Currency string `json:"currency" jsonschema_description:"Currency code (e.g. USD)"` + Rate string `json:"rate" jsonschema_description:"The amount of the base currency that can be exchanged for 1 of this currency"` +} + +type mcpQueryLatestExchangeRatesRequestToolHandler struct{} + +var MCPQueryLatestExchangeRatesRequestToolHandler = &mcpQueryLatestExchangeRatesRequestToolHandler{} + +// Description returns the description of the MCP tool +func (h *mcpQueryLatestExchangeRatesRequestToolHandler) Description() string { + return "Query latest exchange rates with specified currencies." +} + +// InputType returns the input type for the MCP tool request +func (h *mcpQueryLatestExchangeRatesRequestToolHandler) InputType() reflect.Type { + return reflect.TypeOf(&MCPQueryExchangeRatesRequest{}) +} + +// OutputType returns the output type for the MCP tool response +func (h *mcpQueryLatestExchangeRatesRequestToolHandler) OutputType() reflect.Type { + return reflect.TypeOf(&MCPQueryExchangeRatesResponse{}) +} + +// Handle processes the MCP call tool request and returns the response +func (h *mcpQueryLatestExchangeRatesRequestToolHandler) Handle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) ([]*MCPTextContent, *errs.Error) { + var exchangeRatesRequest MCPQueryExchangeRatesRequest + + if callToolReq.Arguments != nil { + if err := json.Unmarshal(callToolReq.Arguments, &exchangeRatesRequest); err != nil { + return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) + } + } else { + return nil, errs.ErrIncompleteOrIncorrectSubmission + } + + dataSource := exchangerates.Container.Current + + if dataSource == nil { + return nil, errs.ErrInvalidExchangeRatesDataSource + } + + exchangeRateResponse, err := dataSource.GetLatestExchangeRates(c, c.GetCurrentUid(), currentConfig) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + response, err := h.createNewMCPQueryExchangeRatesResponse(exchangeRatesRequest.Currencies, exchangeRateResponse) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + return response, nil +} + +func (h *mcpQueryLatestExchangeRatesRequestToolHandler) createNewMCPQueryExchangeRatesResponse(currencies string, exchangeRatesResp *models.LatestExchangeRateResponse) ([]*MCPTextContent, error) { + queryCurrencies := make(map[string]bool) + + for _, currency := range strings.Split(currencies, ",") { + currency = strings.TrimSpace(currency) + + if currency != "" { + queryCurrencies[currency] = true + } + } + + response := &MCPQueryExchangeRatesResponse{ + BaseCurrency: exchangeRatesResp.BaseCurrency, + UpdateTime: utils.FormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format(exchangeRatesResp.UpdateTime, time.UTC), + Rates: make([]*MCPQueryExchangeRateInfo, 0, len(exchangeRatesResp.ExchangeRates)), + } + + for _, rate := range exchangeRatesResp.ExchangeRates { + if _, exists := queryCurrencies[rate.Currency]; rate.Currency != exchangeRatesResp.BaseCurrency && !exists { + continue + } + + response.Rates = append(response.Rates, &MCPQueryExchangeRateInfo{ + Currency: rate.Currency, + Rate: rate.Rate, + }) + } + + content, err := json.Marshal(response) + + if err != nil { + return nil, err + } + + return []*MCPTextContent{ + NewMCPTextContent(string(content)), + }, nil +} diff --git a/pkg/mcp/handler.go b/pkg/mcp/handler.go new file mode 100644 index 00000000..c849d9d2 --- /dev/null +++ b/pkg/mcp/handler.go @@ -0,0 +1,187 @@ +package mcp + +import ( + "reflect" + + "github.com/invopop/jsonschema" + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/services" + "github.com/mayswind/ezbookkeeping/pkg/settings" +) + +// MCPAvailableServices holds the services available for MCP tools +type MCPAvailableServices interface { + GetTransactionService() *services.TransactionService + GetTransactionCategoryService() *services.TransactionCategoryService + GetTransactionTagService() *services.TransactionTagService + GetAccountService() *services.AccountService + GetUserService() *services.UserService +} + +// MCPToolHandler defines the MCP tool handler +type MCPToolHandler[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource] interface { + // Description returns the description of the MCP tool + Description() string + + // InputType returns the input type for the MCP tool request + InputType() reflect.Type + + // OutputType returns the output type for the MCP tool response + OutputType() reflect.Type + + // Handle processes the MCP call tool request and returns the response + Handle(*core.WebContext, *MCPCallToolRequest, *settings.Config, MCPAvailableServices) ([]*T, *errs.Error) +} + +// GetAllMCPToolInfos returns all available MCP tool information +func GetAllMCPToolInfos() []*MCPTool { + toolInfos := make([]*MCPTool, 0) + + for name, handler := range mcpTextContentTools { + toolInfos = append(toolInfos, getMCPToolInfo(name, handler)) + } + + for name, handler := range mcpImageContentTools { + toolInfos = append(toolInfos, getMCPToolInfo(name, handler)) + } + + for name, handler := range mcpAudioContentTools { + toolInfos = append(toolInfos, getMCPToolInfo(name, handler)) + } + + for name, handler := range mcpResourceLinkTools { + toolInfos = append(toolInfos, getMCPToolInfo(name, handler)) + } + + for name, handler := range mcpEmbeddedResourceTools { + toolInfos = append(toolInfos, getMCPToolInfo(name, handler)) + } + + return toolInfos +} + +// MCPToolHandle handles the MCP tool request based on the tool name +func MCPToolHandle(c *core.WebContext, callToolReq *MCPCallToolRequest, currentConfig *settings.Config, services MCPAvailableServices) (any, *errs.Error) { + if handler, exists := mcpTextContentTools[callToolReq.Name]; exists { + return mcpTextContentToolHandle(c, handler, currentConfig, services, callToolReq) + } + + if handler, exists := mcpImageContentTools[callToolReq.Name]; exists { + return mcpImageContentToolHandle(c, handler, currentConfig, services, callToolReq) + } + + if handler, exists := mcpAudioContentTools[callToolReq.Name]; exists { + return mcpAudioContentToolHandle(c, handler, currentConfig, services, callToolReq) + } + + if handler, exists := mcpResourceLinkTools[callToolReq.Name]; exists { + return mcpResourceLinkToolHandle(c, handler, currentConfig, services, callToolReq) + } + + if handler, exists := mcpEmbeddedResourceTools[callToolReq.Name]; exists { + return mcpEmbeddedResourceToolHandle(c, handler, currentConfig, services, callToolReq) + } + + return nil, errs.ErrApiNotFound +} + +func mcpTextContentToolHandle(c *core.WebContext, handler MCPToolHandler[MCPTextContent], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { + result, err := handler.Handle(c, callToolReq, currentConfig, services) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + callToolResp := MCPCallToolResponse[MCPTextContent]{ + Content: result, + IsError: false, + } + + return callToolResp, nil +} + +func mcpImageContentToolHandle(c *core.WebContext, handler MCPToolHandler[MCPImageContent], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { + result, err := handler.Handle(c, callToolReq, currentConfig, services) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + callToolResp := MCPCallToolResponse[MCPImageContent]{ + Content: result, + IsError: false, + } + + return callToolResp, nil +} + +func mcpAudioContentToolHandle(c *core.WebContext, handler MCPToolHandler[MCPAudioContent], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { + result, err := handler.Handle(c, callToolReq, currentConfig, services) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + callToolResp := MCPCallToolResponse[MCPAudioContent]{ + Content: result, + IsError: false, + } + + return callToolResp, nil +} + +func mcpResourceLinkToolHandle(c *core.WebContext, handler MCPToolHandler[MCPResourceLink], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { + result, err := handler.Handle(c, callToolReq, currentConfig, services) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + callToolResp := MCPCallToolResponse[MCPResourceLink]{ + Content: result, + IsError: false, + } + + return callToolResp, nil +} + +func mcpEmbeddedResourceToolHandle(c *core.WebContext, handler MCPToolHandler[MCPEmbeddedResource], currentConfig *settings.Config, services MCPAvailableServices, callToolReq *MCPCallToolRequest) (any, *errs.Error) { + result, err := handler.Handle(c, callToolReq, currentConfig, services) + + if err != nil { + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + callToolResp := MCPCallToolResponse[MCPEmbeddedResource]{ + Content: result, + IsError: false, + } + + return callToolResp, nil +} + +func getMCPToolInfo[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource](name string, handler MCPToolHandler[T]) *MCPTool { + mcpTool := &MCPTool{ + Name: name, + Description: handler.Description(), + } + + schemeGenerator := jsonschema.Reflector{ + Anonymous: true, + DoNotReference: true, + ExpandedStruct: true, + } + + if handler.InputType() != nil { + schema := schemeGenerator.ReflectFromType(handler.InputType()) + mcpTool.InputSchema = schema + } + + if handler.OutputType() != nil { + schema := schemeGenerator.ReflectFromType(handler.OutputType()) + mcpTool.OutputSchema = schema + } + + return mcpTool +} diff --git a/pkg/mcp/model_context_protocol.go b/pkg/mcp/model_context_protocol.go new file mode 100644 index 00000000..7996b862 --- /dev/null +++ b/pkg/mcp/model_context_protocol.go @@ -0,0 +1,218 @@ +package mcp + +import ( + "encoding/base64" + "encoding/json" + + "github.com/invopop/jsonschema" +) + +// MCPProtocolVersion defines the type for Model Context Protocol (MCP) version +type MCPProtocolVersion string + +// MCP Protocol Versions +const ( + MCPProtocolVersion20250618 MCPProtocolVersion = "2025-06-18" + MCPProtocolVersion20250326 MCPProtocolVersion = "2025-03-26" + MCPProtocolVersion20241105 MCPProtocolVersion = "2024-11-05" +) + +// LatestSupportedMCPVersion defines the latest supported version of Model Context Protocol (MCP) +const LatestSupportedMCPVersion = MCPProtocolVersion20250618 + +// SupportedMCPVersion defines a map of supported MCP versions +var SupportedMCPVersion = map[MCPProtocolVersion]bool{ + MCPProtocolVersion20250618: true, + MCPProtocolVersion20250326: true, + MCPProtocolVersion20241105: true, +} + +// MCPInitializeRequest defines the request structure for initializing the MCP connection +type MCPInitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo *MCPImplementation `json:"clientInfo"` +} + +// MCPInitializeResponse defines the response structure for the MCP initialization request +type MCPInitializeResponse struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities *MCPCapabilities `json:"capabilities"` + ServerInfo *MCPImplementation `json:"serverInfo"` +} + +// MCPCapabilities defines the capabilities of the MCP server +type MCPCapabilities struct { + Resources *MCPResourceCapabilities `json:"resources,omitempty"` + Tools *MCPToolCapabilities `json:"tools,omitempty"` + Prompts *MCPPromptCapabilities `json:"prompts,omitempty"` +} + +// MCPImplementation defines the client/server information structure sent in the MCP initialization request/response +type MCPImplementation struct { + Name string `json:"name"` + Title string `json:"title,omitempty"` + Version string `json:"version"` +} + +// MCPResourceCapabilities defines the capabilities related to resources in the MCP +type MCPResourceCapabilities struct { + Subscribe bool `json:"subscribe"` + ListChanged bool `json:"listChanged"` +} + +// MCPToolCapabilities defines the capabilities related to tools in the MCP +type MCPToolCapabilities struct { + ListChanged bool `json:"listChanged"` +} + +// MCPPromptCapabilities defines the capabilities related to prompts in the MCP +type MCPPromptCapabilities struct { + ListChanged bool `json:"listChanged"` +} + +// MCPListResourcesResponse defines the response structure for listing resources in the MCP +type MCPListResourcesResponse struct { + Resources []*MCPResource `json:"resources"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// MCPResource defines the structure of a resource in the MCP +type MCPResource struct { + URI string `json:"uri"` + Name string `json:"name"` + Size int `json:"size,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` +} + +// MCPReadResourceRequest defines the request structure for reading a resource in the MCP +type MCPReadResourceRequest struct { + URI string `json:"uri"` +} + +// MCPReadResourceResponse defines the response structure for reading a resource in the MCP +type MCPReadResourceResponse[T MCPTextResourceContents | MCPBlobResourceContents] struct { + Contents []*T `json:"contents"` +} + +// MCPTextResourceContents defines the text contents structure of a resource in the MCP +type MCPTextResourceContents struct { + URI string `json:"uri"` + Text string `json:"text"` + MimeType string `json:"mimeType,omitempty"` +} + +// MCPBlobResourceContents defines the blob contents structure of a resource in the MCP +type MCPBlobResourceContents struct { + URI string `json:"uri"` + Blob string `json:"blob"` // Base64 encoded content of the resource + MimeType string `json:"mimeType,omitempty"` +} + +// MCPListToolsResponse defines the response structure for listing tools in the MCP +type MCPListToolsResponse struct { + Tools []*MCPTool `json:"tools"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// MCPTool defines the structure of a tool in the MCP +type MCPTool struct { + Name string `json:"name"` + InputSchema *jsonschema.Schema `json:"inputSchema"` + OutputSchema *jsonschema.Schema `json:"outputSchema"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` +} + +// MCPCallToolRequest defines the request structure for listing tools in the MCP +type MCPCallToolRequest struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments,omitempty"` +} + +// MCPCallToolResponse defines the response structure for calling a tool in the MCP +type MCPCallToolResponse[T MCPTextContent | MCPImageContent | MCPAudioContent | MCPResourceLink | MCPEmbeddedResource] struct { + Content []*T `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// MCPTextContent defines the text content structure used in MCP +type MCPTextContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// MCPImageContent defines the image content structure used in MCP +type MCPImageContent struct { + Type string `json:"type"` + MimeType string `json:"mimeType"` + Data string `json:"data"` // Base64 encoded content for binary data +} + +// MCPAudioContent defines the audio content structure used in MCP +type MCPAudioContent struct { + Type string `json:"type"` + MimeType string `json:"mimeType"` + Data string `json:"data"` // Base64 encoded content for binary data +} + +// MCPResourceLink defines the resource link content structure used in MCP +type MCPResourceLink struct { + URI string `json:"uri"` + Type string `json:"type"` + Name string `json:"name"` + Size int `json:"size,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` +} + +// MCPEmbeddedResource defines the embedded resource content structure used in MCP +type MCPEmbeddedResource struct { + Type string `json:"type"` + Resource any `json:"resource"` +} + +// NewMCPTextContent creates a new instance of MCPTextContent with the given text +func NewMCPTextContent(text string) *MCPTextContent { + return &MCPTextContent{ + Type: "text", + Text: text, + } +} + +// NewMCPImageContent creates a new instance of MCPImageContent with the given data and MIME type +func NewMCPImageContent(data []byte, mimeType string) *MCPImageContent { + return &MCPImageContent{ + Type: "image", + MimeType: mimeType, + Data: base64.StdEncoding.EncodeToString(data), + } +} + +// NewMCPAudioContent creates a new instance of MCPAudioContent with the given data and MIME type +func NewMCPAudioContent(data []byte, mimeType string) *MCPAudioContent { + return &MCPAudioContent{ + Type: "audio", + MimeType: mimeType, + Data: base64.StdEncoding.EncodeToString(data), + } +} + +// NewMCPResourceLink creates a new instance of MCPResourceLink with the given parameters +func NewMCPResourceLink(uri string, name string) *MCPResourceLink { + return &MCPResourceLink{ + URI: uri, + Type: "resource_link", + Name: name, + } +} + +// NewMCPEmbeddedResource creates a new instance of MCPEmbeddedResource with the given resource +func NewMCPEmbeddedResource[T MCPTextResourceContents | MCPBlobResourceContents](resource *T) *MCPEmbeddedResource { + return &MCPEmbeddedResource{ + Type: "resource", + Resource: resource, + } +} diff --git a/pkg/middlewares/mcp_server_ip_limit.go b/pkg/middlewares/mcp_server_ip_limit.go new file mode 100644 index 00000000..f17ca2ae --- /dev/null +++ b/pkg/middlewares/mcp_server_ip_limit.go @@ -0,0 +1,27 @@ +package middlewares + +import ( + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/settings" + "github.com/mayswind/ezbookkeeping/pkg/utils" +) + +// MCPServerIpLimit limits access to the MCP server based on IP address. +func MCPServerIpLimit(config *settings.Config) core.MiddlewareHandlerFunc { + return func(c *core.WebContext) { + if len(config.MCPAllowedRemoteIPs) < 1 { + c.Next() + return + } + + for i := 0; i < len(config.MCPAllowedRemoteIPs); i++ { + if config.MCPAllowedRemoteIPs[i].Match(c.ClientIP()) { + c.Next() + return + } + } + + utils.PrintJsonErrorResult(c, errs.ErrIPForbidden) + } +} diff --git a/pkg/settings/setting.go b/pkg/settings/setting.go index 152d3796..b8abbcb3 100644 --- a/pkg/settings/setting.go +++ b/pkg/settings/setting.go @@ -234,6 +234,10 @@ type Config struct { EnableGZip bool EnableRequestLog bool + // MCP + EnableMCPServer bool + MCPAllowedRemoteIPs []*core.IPPattern + // Database DatabaseConfig *DatabaseConfig EnableQueryLog bool @@ -377,6 +381,12 @@ func LoadConfiguration(configFilePath string) (*Config, error) { return nil, err } + err = loadMCPServerConfiguration(config, cfgFile, "mcp") + + if err != nil { + return nil, err + } + err = loadDatabaseConfiguration(config, cfgFile, "database") if err != nil { @@ -540,6 +550,35 @@ func loadServerConfiguration(config *Config, configFile *ini.File, sectionName s return nil } +func loadMCPServerConfiguration(config *Config, configFile *ini.File, sectionName string) error { + config.EnableMCPServer = getConfigItemBoolValue(configFile, sectionName, "enable_mcp", false) + mcpAllowedRemoteIps := getConfigItemStringValue(configFile, sectionName, "mcp_allowed_remote_ips", "") + + if mcpAllowedRemoteIps != "" { + remoteIPs := strings.Split(mcpAllowedRemoteIps, ",") + config.MCPAllowedRemoteIPs = make([]*core.IPPattern, 0, len(remoteIPs)) + + for i := 0; i < len(remoteIPs); i++ { + ip := strings.TrimSpace(remoteIPs[i]) + pattern, err := core.ParseIPPattern(ip) + + if err != nil { + return err + } + + if pattern == nil { + continue + } + + config.MCPAllowedRemoteIPs = append(config.MCPAllowedRemoteIPs, pattern) + } + } else { + config.MCPAllowedRemoteIPs = nil + } + + return nil +} + func loadDatabaseConfiguration(config *Config, configFile *ini.File, sectionName string) error { dbConfig := &DatabaseConfig{} diff --git a/pkg/utils/api.go b/pkg/utils/api.go index be82a23b..1edf41c1 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -60,6 +60,47 @@ func PrintJsonErrorResult(c *core.WebContext, err *errs.Error) { c.AbortWithStatusJSON(err.HttpStatusCode, result) } +// PrintJSONRPCSuccessResult writes success response in JSON-RPC format to current http context +func PrintJSONRPCSuccessResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest, result any) { + c.JSON(http.StatusOK, core.NewJSONRPCResponse(jsonRPCRequest.ID, result)) +} + +// PrintJSONRPCErrorResult writes error response in JSON-RPC format to current http context +func PrintJSONRPCErrorResult(c *core.WebContext, jsonRPCRequest *core.JSONRPCRequest, err *errs.Error) { + c.SetResponseError(err) + + errorMessage := err.Error() + + if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() && len(err.BaseError) > 0 { + validationErrors, ok := err.BaseError[0].(validator.ValidationErrors) + + if ok { + for _, err := range validationErrors { + errorMessage = getValidationErrorText(err) + break + } + } + } + + var id any + + if jsonRPCRequest != nil { + id = jsonRPCRequest.ID + } + + jsonRPCError := core.JSONRPCInternalError + + if err.Code() == errs.ErrIncompleteOrIncorrectSubmission.Code() { + jsonRPCError = core.JSONRPCParseError + } else if err.Code() == errs.ErrApiNotFound.Code() { + jsonRPCError = core.JSONRPCMethodNotFoundError + } else if err.Code() == errs.ErrParameterInvalid.Code() { + jsonRPCError = core.JSONRPCInvalidParamsError + } + + c.AbortWithStatusJSON(err.HttpStatusCode, core.NewJSONRPCErrorResponseWithCause(id, jsonRPCError, errorMessage)) +} + // PrintDataErrorResult writes error response in custom content type to current http context func PrintDataErrorResult(c *core.WebContext, contentType string, err *errs.Error) { c.SetResponseError(err) diff --git a/pkg/utils/datetimes.go b/pkg/utils/datetimes.go index 08ca9bcc..7a4d5304 100644 --- a/pkg/utils/datetimes.go +++ b/pkg/utils/datetimes.go @@ -9,15 +9,16 @@ import ( ) const ( - longDateFormat = "2006-01-02" - longDateTimeFormat = "2006-01-02 15:04:05" - longDateTimeWithTimezoneFormat = "2006-01-02 15:04:05Z07:00" - longDateTimeWithTimezoneFormat2 = "2006-01-02 15:04:05 Z0700" - longDateTimeWithoutSecondFormat = "2006-01-02 15:04" - shortDateTimeFormat = "2006-1-2 15:4:5" - yearMonthDateTimeFormat = "2006-01" - westernmostTimezoneUtcOffset = -720 // Etc/GMT+12 (UTC-12:00) - easternmostTimezoneUtcOffset = 840 // Pacific/Kiritimati (UTC+14:00) + longDateFormat = "2006-01-02" + longDateTimeFormat = "2006-01-02 15:04:05" + longDateTimeWithTimezoneFormat = "2006-01-02 15:04:05Z07:00" + longDateTimeWithTimezoneFormat2 = "2006-01-02 15:04:05 Z0700" + longDateTimeWithTimezoneRFC3389Format = "2006-01-02T15:04:05Z07:00" + longDateTimeWithoutSecondFormat = "2006-01-02 15:04" + shortDateTimeFormat = "2006-1-2 15:4:5" + yearMonthDateTimeFormat = "2006-01" + westernmostTimezoneUtcOffset = -720 // Etc/GMT+12 (UTC-12:00) + easternmostTimezoneUtcOffset = 840 // Pacific/Kiritimati (UTC+14:00) ) // ParseNumericYearMonth returns numeric year and month from textual content @@ -65,6 +66,28 @@ func FormatUnixTimeToLongDateTime(unixTime int64, timezone *time.Location) strin return t.Format(longDateTimeFormat) } +// FormatUnixTimeToLongDateTimeWithTimezone returns a textual representation of the unix time formatted by long date time with timezone format +func FormatUnixTimeToLongDateTimeWithTimezone(unixTime int64, timezone *time.Location) string { + t := parseFromUnixTime(unixTime) + + if timezone != nil { + t = t.In(timezone) + } + + return t.Format(longDateTimeWithTimezoneFormat) +} + +// FormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format returns a textual representation of the unix time formatted by long date time with timezone RFC 3389 format +func FormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format(unixTime int64, timezone *time.Location) string { + t := parseFromUnixTime(unixTime) + + if timezone != nil { + t = t.In(timezone) + } + + return t.Format(longDateTimeWithTimezoneRFC3389Format) +} + func FormatYearMonthDayToLongDateTime(year string, month string, day string) (string, error) { if len(year) == 2 { yearLast2Digits, err := StringToInt(year) diff --git a/pkg/utils/datetimes_test.go b/pkg/utils/datetimes_test.go index d695b3d0..f2bd25c3 100644 --- a/pkg/utils/datetimes_test.go +++ b/pkg/utils/datetimes_test.go @@ -32,6 +32,34 @@ func TestFormatUnixTimeToLongDate(t *testing.T) { assert.Equal(t, expectedValue, actualValue) } +func TestFormatUnixTimeToLongDateTimeWithTimezone(t *testing.T) { + unixTime := int64(1617228083) + utcTimezone := time.FixedZone("Test Timezone", 0) // UTC + utc8Timezone := time.FixedZone("Test Timezone", 28800) // UTC+8 + + expectedValue := "2021-03-31 22:01:23Z" + actualValue := FormatUnixTimeToLongDateTimeWithTimezone(unixTime, utcTimezone) + assert.Equal(t, expectedValue, actualValue) + + expectedValue = "2021-04-01 06:01:23+08:00" + actualValue = FormatUnixTimeToLongDateTimeWithTimezone(unixTime, utc8Timezone) + assert.Equal(t, expectedValue, actualValue) +} + +func TestFormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format(t *testing.T) { + unixTime := int64(1617228083) + utcTimezone := time.FixedZone("Test Timezone", 0) // UTC + utc8Timezone := time.FixedZone("Test Timezone", 28800) // UTC+8 + + expectedValue := "2021-03-31T22:01:23Z" + actualValue := FormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format(unixTime, utcTimezone) + assert.Equal(t, expectedValue, actualValue) + + expectedValue = "2021-04-01T06:01:23+08:00" + actualValue = FormatUnixTimeToLongDateTimeWithTimezoneRFC3389Format(unixTime, utc8Timezone) + assert.Equal(t, expectedValue, actualValue) +} + func TestFormatUnixTimeToLongDateTime(t *testing.T) { unixTime := int64(1617228083) utcTimezone := time.FixedZone("Test Timezone", 0) // UTC diff --git a/src/locales/de.json b/src/locales/de.json index a8330531..8802d49d 100644 --- a/src/locales/de.json +++ b/src/locales/de.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Hochgeladene Datei ist leer", "uploaded file size exceeds the maximum allowed size": "Hochgeladene Datei überschreitet die maximal zulässige Größe", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/en.json b/src/locales/en.json index 7853d4dd..2c7e4567 100644 --- a/src/locales/en.json +++ b/src/locales/en.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Uploaded file is empty", "uploaded file size exceeds the maximum allowed size": "Uploaded file size exceeds the maximum allowed size", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/es.json b/src/locales/es.json index 55f831e9..82bd2be8 100644 --- a/src/locales/es.json +++ b/src/locales/es.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "El archivo subido está vacío", "uploaded file size exceeds the maximum allowed size": "El tamaño del archivo cargado excede el tamaño máximo permitido", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "IDENTIFICACIÓN", diff --git a/src/locales/it.json b/src/locales/it.json index 80029c36..bf64ab3f 100644 --- a/src/locales/it.json +++ b/src/locales/it.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Il file caricato è vuoto", "uploaded file size exceeds the maximum allowed size": "La dimensione del file caricato supera la dimensione massima consentita", "failure count exceeded maximum limit": "Il conteggio dei fallimenti ha superato il limite massimo, riprova più tardi", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/ja.json b/src/locales/ja.json index f0343476..5b66a858 100644 --- a/src/locales/ja.json +++ b/src/locales/ja.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "アップロードされたファイルは空です", "uploaded file size exceeds the maximum allowed size": "アップロードされたファイルが最大許容サイズを超えています", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/pt_BR.json b/src/locales/pt_BR.json index fc405268..1467f35d 100644 --- a/src/locales/pt_BR.json +++ b/src/locales/pt_BR.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Arquivo enviado está vazio", "uploaded file size exceeds the maximum allowed size": "O tamanho do arquivo enviado excede o tamanho máximo permitido", "failure count exceeded maximum limit": "Contagem de falhas excedeu o limite máximo, por favor tente novamente mais tarde", - "repeated request": "Pedido Repetido" + "repeated request": "Pedido Repetido", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/ru.json b/src/locales/ru.json index efcbe126..02252dd9 100644 --- a/src/locales/ru.json +++ b/src/locales/ru.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Загруженный файл пуст", "uploaded file size exceeds the maximum allowed size": "Размер загруженного файла превышает максимально допустимый размер", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/uk.json b/src/locales/uk.json index 7cf38df2..8a7b2430 100644 --- a/src/locales/uk.json +++ b/src/locales/uk.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Завантажений файл порожній", "uploaded file size exceeds the maximum allowed size": "Розмір завантаженого файлу перевищує максимально допустимий", "failure count exceeded maximum limit": "Кількість невдали спроб перевищила допустимий ліміт, спробуйте пізніше", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/vi.json b/src/locales/vi.json index 874c61d6..69aa93f7 100644 --- a/src/locales/vi.json +++ b/src/locales/vi.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "Tệp đã tải lên trống", "uploaded file size exceeds the maximum allowed size": "Kích thước tệp đã tải lên vượt quá kích thước tối đa cho phép", "failure count exceeded maximum limit": "Failure count exceeded maximum limit, please try again after some time", - "repeated request": "Repeated Request" + "repeated request": "Repeated Request", + "ip address is forbidden to access this resource": "IP address is forbidden to access this resource" }, "parameter": { "id": "ID", diff --git a/src/locales/zh_Hans.json b/src/locales/zh_Hans.json index 6f76df03..dc365f11 100644 --- a/src/locales/zh_Hans.json +++ b/src/locales/zh_Hans.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "上传的文件为空", "uploaded file size exceeds the maximum allowed size": "上传的文件大小超出了允许的最大大小", "failure count exceeded maximum limit": "失败次数超出最大限制,请稍后重试", - "repeated request": "重复的请求" + "repeated request": "重复的请求", + "ip address is forbidden to access this resource": "IP 地址被禁止访问该资源" }, "parameter": { "id": "ID", diff --git a/src/locales/zh_Hant.json b/src/locales/zh_Hant.json index 1fa4f931..159a5e45 100644 --- a/src/locales/zh_Hant.json +++ b/src/locales/zh_Hant.json @@ -1198,7 +1198,8 @@ "uploaded file is empty": "上傳的檔案為空", "uploaded file size exceeds the maximum allowed size": "上傳的檔案大小超出了允許的最大大小", "failure count exceeded maximum limit": "失敗次數超出最大限制,請稍後重試", - "repeated request": "重複的請求" + "repeated request": "重複的請求", + "ip address is forbidden to access this resource": "IP 地址被禁止訪問此資源" }, "parameter": { "id": "ID", diff --git a/third-party-dependencies.json b/third-party-dependencies.json index 7638b896..3a9dc992 100644 --- a/third-party-dependencies.json +++ b/third-party-dependencies.json @@ -134,6 +134,12 @@ "url": "https://github.com/extrame/xls", "licenseUrl": "https://github.com/extrame/xls/blob/4a6cf263071b975a90abf74ca3e804b48243be28/LICENSE" }, + { + "name": "jsonschema", + "copyright": "Copyright (C) 2014 Alec Thomas", + "url": "https://github.com/invopop/jsonschema", + "licenseUrl": "https://github.com/invopop/jsonschema/blob/v0.13.0/COPYING" + }, { "name": "go-ordered-map", "url": "https://github.com/wk8/go-ordered-map",