diff --git a/cmd/webserver.go b/cmd/webserver.go index 90adfe2f..3c6eb9c8 100644 --- a/cmd/webserver.go +++ b/cmd/webserver.go @@ -316,6 +316,7 @@ func startWebServer(c *core.CliContext) error { apiV1Route := apiRoute.Group("/v1") apiV1Route.Use(bindMiddleware(middlewares.JWTAuthorization(config))) + apiV1Route.Use(bindMiddleware(middlewares.APITokenIpLimit(config))) { // Tokens apiV1Route.GET("/tokens/list.json", bindApi(api.Tokens.TokenListHandler)) diff --git a/conf/ezbookkeeping.ini b/conf/ezbookkeeping.ini index de2fb8fe..74e1dc5c 100644 --- a/conf/ezbookkeeping.ini +++ b/conf/ezbookkeeping.ini @@ -296,6 +296,9 @@ password_reset_token_expired_time = 3600 # Set to true to enable API token generation enable_api_token = false +# Allowed remote IPs for using the API token, a comma-separated list of allowed remote IPs (asterisk * for any addresses, e.g. 192.168.1.* means any IPs in the 192.168.1.x subnet), leave blank to allow all remote IPs +api_token_allowed_remote_ips = + # Maximum count of password / token check failures (0 - 4294967295) per IP per minute (use the above duplicate checker), default is 5, set to 0 to disable max_failures_per_ip_per_minute = 5 diff --git a/pkg/middlewares/api_token_ip_limit.go b/pkg/middlewares/api_token_ip_limit.go new file mode 100644 index 00000000..0c81aa56 --- /dev/null +++ b/pkg/middlewares/api_token_ip_limit.go @@ -0,0 +1,39 @@ +package middlewares + +import ( + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/settings" + "github.com/mayswind/ezbookkeeping/pkg/utils" +) + +// APITokenIpLimit limits API token access based on IP address +func APITokenIpLimit(config *settings.Config) core.MiddlewareHandlerFunc { + return func(c *core.WebContext) { + claims := c.GetTokenClaims() + + if claims == nil { + c.Next() + return + } + + if claims.Type != core.USER_TOKEN_TYPE_API { + c.Next() + return + } + + if len(config.APITokenAllowedRemoteIPs) < 1 { + c.Next() + return + } + + for i := 0; i < len(config.APITokenAllowedRemoteIPs); i++ { + if config.APITokenAllowedRemoteIPs[i].Match(c.ClientIP()) { + c.Next() + return + } + } + + utils.PrintJsonErrorResult(c, errs.ErrIPForbidden) + } +} diff --git a/pkg/settings/setting.go b/pkg/settings/setting.go index 8d9ed33e..ebce2265 100644 --- a/pkg/settings/setting.go +++ b/pkg/settings/setting.go @@ -370,6 +370,7 @@ type Config struct { PasswordResetTokenExpiredTime uint32 PasswordResetTokenExpiredTimeDuration time.Duration EnableAPIToken bool + APITokenAllowedRemoteIPs []*core.IPPattern MaxFailuresPerIpPerMinute uint32 MaxFailuresPerUserPerMinute uint32 @@ -667,29 +668,13 @@ func loadServerConfiguration(config *Config, configFile *ini.File, sectionName s } func loadMCPServerConfiguration(config *Config, configFile *ini.File, sectionName string) error { + var err error + config.EnableMCPServer = getConfigItemBoolValue(configFile, sectionName, "enable_mcp", false) - mcpAllowedRemoteIps := getConfigItemStringValue(configFile, sectionName, "mcp_allowed_remote_ips", "") + config.MCPAllowedRemoteIPs, err = getIPPatterns(configFile, sectionName, "mcp_allowed_remote_ips", "") - if mcpAllowedRemoteIps != "" { - remoteIPs := strings.Split(mcpAllowedRemoteIps, ",") - config.MCPAllowedRemoteIPs = make([]*core.IPPattern, 0, len(remoteIPs)) - - for i := 0; i < len(remoteIPs); i++ { - ip := strings.TrimSpace(remoteIPs[i]) - pattern, err := core.ParseIPPattern(ip) - - if err != nil { - return err - } - - if pattern == nil { - continue - } - - config.MCPAllowedRemoteIPs = append(config.MCPAllowedRemoteIPs, pattern) - } - } else { - config.MCPAllowedRemoteIPs = nil + if err != nil { + return err } return nil @@ -976,6 +961,8 @@ func loadCronConfiguration(config *Config, configFile *ini.File, sectionName str } func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName string) error { + var err error + config.SecretKeyNoSet = !getConfigItemIsSet(configFile, sectionName, "secret_key") config.SecretKey = getConfigItemStringValue(configFile, sectionName, "secret_key", defaultSecretKey) @@ -1018,6 +1005,11 @@ func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName config.PasswordResetTokenExpiredTimeDuration = time.Duration(config.PasswordResetTokenExpiredTime) * time.Second config.EnableAPIToken = getConfigItemBoolValue(configFile, sectionName, "enable_api_token", false) + config.APITokenAllowedRemoteIPs, err = getIPPatterns(configFile, sectionName, "api_token_allowed_remote_ips", "") + + if err != nil { + return err + } config.MaxFailuresPerIpPerMinute = getConfigItemUint32Value(configFile, sectionName, "max_failures_per_ip_per_minute", defaultMaxFailuresPerIpPerMinute) config.MaxFailuresPerUserPerMinute = getConfigItemUint32Value(configFile, sectionName, "max_failures_per_user_per_minute", defaultMaxFailuresPerUserPerMinute) @@ -1260,6 +1252,34 @@ func getFinalPath(workingPath, p string) (string, error) { return p, err } +func getIPPatterns(configFile *ini.File, sectionName string, itemName string, defaultValue string) ([]*core.IPPattern, error) { + configValue := getConfigItemStringValue(configFile, sectionName, itemName, defaultValue) + + if configValue == "" { + return nil, nil + } + + remoteIPs := strings.Split(configValue, ",") + ipPatterns := make([]*core.IPPattern, 0, len(remoteIPs)) + + for i := 0; i < len(remoteIPs); i++ { + ip := strings.TrimSpace(remoteIPs[i]) + pattern, err := core.ParseIPPattern(ip) + + if err != nil { + return nil, err + } + + if pattern == nil { + continue + } + + ipPatterns = append(ipPatterns, pattern) + } + + return ipPatterns, nil +} + func getMultiLanguageContentConfig(configFile *ini.File, sectionName string, enableKey string, contentKey string) MultiLanguageContentConfig { config := MultiLanguageContentConfig{ Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),