using ObsidianMcp.Config;
using Microsoft.Extensions.Options;
namespace ObsidianMcp.Services;
///
/// 写入门禁——在路径安全(VaultPathResolver)之上再加写入白名单控制。
///
/// 规则优先级(从高到低):
/// 1. 永禁写入:AGENTS.md / PROFILE.md / README.md / CLAUDE.md(任何路径下的同名文件)
/// 2. 永禁前缀:01-Secret/
/// 3. 必须命中写入白名单之一才允许
///
/// 白名单(hardcode):
/// - 前缀 02-ShengquGames/logs/
/// - 前缀 Coding/
/// - 精确匹配 NAS/NAS 待办清单.md
///
/// 白名单可通过 env Vault__WriteWhitelist__N 扩展。
///
public class VaultWriteGuard
{
// 永禁写入的文件名(不含路径,任何目录下的同名文件都禁写)
private static readonly HashSet ForbiddenFileNames =
new(StringComparer.OrdinalIgnoreCase)
{
"AGENTS.md",
"PROFILE.md",
"README.md",
"CLAUDE.md",
};
// 永禁写入的路径前缀(相对路径)
private static readonly string[] ForbiddenPrefixes =
[
"01-Secret/",
"01-Secret\\",
];
// hardcode 写入白名单
// 前缀匹配:以 / 或 \ 结尾表示前缀;精确匹配:其他
private static readonly string[] HardcodeWhitelist =
[
"02-ShengquGames/logs/",
"02-ShengquGames\\logs\\",
"Coding/",
"Coding\\",
"NAS/NAS 待办清单.md",
"NAS\\NAS 待办清单.md",
];
private readonly VaultPathResolver _resolver;
private readonly string[] _extraWhitelist;
public VaultWriteGuard(VaultPathResolver resolver, IOptions opts)
{
_resolver = resolver;
_extraWhitelist = opts.Value.WriteWhitelist ?? [];
}
///
/// 校验相对路径是否允许写入。
/// 通过则返回规范化后的绝对路径;不通过则抛 UnauthorizedAccessException。
///
public string EnsureWritable(string relativePath)
{
// 先过路径安全守卫(防穿越 + 黑名单)
var absPath = _resolver.Resolve(relativePath);
// 规范化相对路径(用于白名单匹配),统一用 /
var normalized = NormalizeRelative(relativePath);
// 1. 永禁文件名
var fileName = Path.GetFileName(absPath);
if (ForbiddenFileNames.Contains(fileName))
throw new UnauthorizedAccessException(
$"禁止写入保护文件:{relativePath}");
// 2. 永禁前缀
foreach (var prefix in ForbiddenPrefixes)
{
if (normalized.StartsWith(NormalizeRelative(prefix), StringComparison.OrdinalIgnoreCase))
throw new UnauthorizedAccessException(
$"禁止写入 01-Secret/ 目录:{relativePath}");
}
// 3. 白名单(hardcode + env 扩展)
if (!IsInWhitelist(normalized))
throw new UnauthorizedAccessException(
$"路径不在写入白名单内:{relativePath}");
return absPath;
}
private bool IsInWhitelist(string normalized)
{
var allWhitelist = HardcodeWhitelist.Concat(_extraWhitelist);
foreach (var entry in allWhitelist)
{
var normalizedEntry = NormalizeRelative(entry);
if (normalizedEntry.EndsWith('/'))
{
// 前缀匹配
if (normalized.StartsWith(normalizedEntry, StringComparison.OrdinalIgnoreCase))
return true;
}
else
{
// 精确匹配
if (string.Equals(normalized, normalizedEntry, StringComparison.OrdinalIgnoreCase))
return true;
}
}
return false;
}
/// 统一用 / 作分隔符,用于白名单匹配。
private static string NormalizeRelative(string path) =>
path.Replace('\\', '/');
}