fix(security): harden markdown rendering and sync safety

This commit is contained in:
sck_0
2026-03-15 09:21:51 +01:00
parent 078847f681
commit c0c118e223
8 changed files with 246 additions and 3 deletions

View File

@@ -49,6 +49,21 @@ function git(cmd) {
return execSync(`git ${cmd}`, { cwd: ROOT_DIR, encoding: 'utf-8', stdio: ['pipe', 'pipe', 'pipe'] }).trim();
}
function isAllowedDevOrigin(req) {
const host = req.headers?.host;
const origin = req.headers?.origin;
if (!host || !origin) {
return false;
}
try {
return new URL(origin).host === host;
} catch {
return false;
}
}
/** Ensure the upstream remote exists. */
function ensureUpstream() {
const remotes = git('remote');
@@ -250,6 +265,19 @@ export default function refreshSkillsPlugin() {
server.middlewares.use('/api/refresh-skills', async (req, res) => {
res.setHeader('Content-Type', 'application/json');
if (req.method !== 'POST') {
res.statusCode = 405;
res.setHeader('Allow', 'POST');
res.end(JSON.stringify({ success: false, error: 'Method not allowed' }));
return;
}
if (!isAllowedDevOrigin(req)) {
res.statusCode = 403;
res.end(JSON.stringify({ success: false, error: 'Forbidden origin' }));
return;
}
try {
let result;
@@ -287,3 +315,5 @@ export default function refreshSkillsPlugin() {
}
};
}
export { isAllowedDevOrigin };

View File

@@ -0,0 +1,98 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
const execSync = vi.fn((command) => {
if (command === 'git --version') return '';
if (command === 'git rev-parse --git-dir') return '.git';
if (command === 'git remote') return 'origin\nupstream\n';
if (command === 'git rev-parse HEAD') return 'abc123';
if (command === 'git fetch upstream main') return '';
if (command === 'git rev-parse upstream/main') return 'abc123';
return '';
});
vi.mock('child_process', async (importOriginal) => {
const actual = await importOriginal();
return {
...actual,
execSync,
default: {
...actual,
execSync,
},
};
});
function createResponse() {
return {
statusCode: 200,
headers: {},
body: '',
setHeader(name, value) {
this.headers[name] = value;
},
end(payload) {
this.body = payload;
},
};
}
async function loadRefreshHandler() {
const { default: refreshSkillsPlugin } = await import('../../refresh-skills-plugin.js');
const registrations = [];
const server = {
middlewares: {
use(pathOrHandler, maybeHandler) {
if (typeof pathOrHandler === 'string') {
registrations.push({ path: pathOrHandler, handler: maybeHandler });
return;
}
registrations.push({ path: null, handler: pathOrHandler });
},
},
};
refreshSkillsPlugin().configureServer(server);
const registration = registrations.find((item) => item.path === '/api/refresh-skills');
if (!registration) {
throw new Error('refresh-skills handler not registered');
}
return registration.handler;
}
describe('refresh-skills plugin security', () => {
beforeEach(() => {
execSync.mockClear();
});
it('rejects GET requests for the sync endpoint', async () => {
const handler = await loadRefreshHandler();
const req = {
method: 'GET',
headers: {
host: 'localhost:5173',
origin: 'http://localhost:5173',
},
};
const res = createResponse();
await handler(req, res);
expect(res.statusCode).toBe(405);
});
it('rejects cross-origin POST requests for the sync endpoint', async () => {
const handler = await loadRefreshHandler();
const req = {
method: 'POST',
headers: {
host: 'localhost:5173',
origin: 'http://evil.test',
},
};
const res = createResponse();
await handler(req, res);
expect(res.statusCode).toBe(403);
});
});

View File

@@ -74,7 +74,7 @@ export function Home(): React.ReactElement {
setSyncing(true);
setSyncMsg(null);
try {
const res = await fetch('/api/refresh-skills');
const res = await fetch('/api/refresh-skills', { method: 'POST' });
const data = await res.json();
if (data.success) {
if (data.upToDate) {

View File

@@ -5,7 +5,6 @@ import { SkillStarButton } from '../components/SkillStarButton';
import { useSkills } from '../context/SkillContext';
import remarkGfm from 'remark-gfm';
import rehypeHighlight from 'rehype-highlight';
import rehypeRaw from 'rehype-raw';
// Lazy load heavy markdown component
const Markdown = lazy(() => import('react-markdown'));
@@ -261,7 +260,7 @@ export function SkillDetail(): React.ReactElement {
<Suspense fallback={<div className="h-24 animate-pulse bg-slate-100 dark:bg-slate-800 rounded-lg"></div>}>
<Markdown
remarkPlugins={[remarkGfm]}
rehypePlugins={[rehypeHighlight, rehypeRaw]}
rehypePlugins={[rehypeHighlight]}
>
{markdownBody}
</Markdown>

View File

@@ -0,0 +1,69 @@
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
import { waitFor } from '@testing-library/react';
import { SkillDetail } from '../SkillDetail';
import { renderWithRouter } from '../../utils/testUtils';
import { createMockSkill } from '../../factories/skill';
import { useSkills } from '../../context/SkillContext';
let capturedRehypePlugins: unknown[] | undefined;
vi.mock('../../components/SkillStarButton', () => ({
SkillStarButton: () => <button data-testid="star-button">0 Upvotes</button>,
}));
vi.mock('../../context/SkillContext', async (importOriginal) => {
const actual = await importOriginal<any>();
return {
...actual,
useSkills: vi.fn(),
};
});
vi.mock('react-markdown', () => ({
default: ({ children, rehypePlugins }: { children: string; rehypePlugins?: unknown[] }) => {
capturedRehypePlugins = rehypePlugins;
return <div data-testid="markdown-content">{children}</div>;
},
}));
describe('SkillDetail security', () => {
beforeEach(() => {
vi.clearAllMocks();
capturedRehypePlugins = undefined;
});
it('does not enable raw HTML rendering for skill markdown', async () => {
const mockSkill = createMockSkill({
id: 'unsafe-skill',
name: 'unsafe-skill',
description: 'Skill with embedded html',
});
(useSkills as Mock).mockReturnValue({
skills: [mockSkill],
stars: {},
loading: false,
});
global.fetch = vi.fn().mockResolvedValue({
ok: true,
text: async () => '# Demo\n\n<img src=x onerror=alert(1) />',
});
renderWithRouter(<SkillDetail />, {
route: '/skill/unsafe-skill',
path: '/skill/:id',
useProvider: false,
});
await waitFor(() => {
expect(capturedRehypePlugins).toBeDefined();
});
const pluginNames = (capturedRehypePlugins ?? []).map((plugin) =>
typeof plugin === 'function' ? plugin.name : String(plugin),
);
expect(pluginNames).not.toContain('rehypeRaw');
});
});

View File

@@ -3,6 +3,7 @@ import json
import re
import sys
from collections.abc import Mapping
from datetime import date, datetime
import yaml
from _project_paths import find_repo_root
@@ -13,6 +14,15 @@ if sys.platform == 'win32':
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
def normalize_yaml_value(value):
if isinstance(value, Mapping):
return {key: normalize_yaml_value(val) for key, val in value.items()}
if isinstance(value, list):
return [normalize_yaml_value(item) for item in value]
if isinstance(value, (date, datetime)):
return value.isoformat()
return value
def parse_frontmatter(content):
"""
Parses YAML frontmatter, sanitizing unquoted values containing @.
@@ -43,6 +53,7 @@ def parse_frontmatter(content):
try:
parsed = yaml.safe_load(sanitized_yaml) or {}
parsed = normalize_yaml_value(parsed)
if not isinstance(parsed, Mapping):
print("⚠️ YAML frontmatter must be a mapping/object")
return {}

View File

@@ -38,6 +38,31 @@ class FrontmatterParsingSecurityTests(unittest.TestCase):
self.assertIsNone(metadata)
self.assertTrue(any("mapping" in error.lower() for error in errors))
def test_validate_skills_normalizes_unquoted_yaml_dates(self):
content = "---\nname: demo\ndescription: ok\ndate_added: 2026-03-15\n---\nbody\n"
metadata, errors = validate_skills.parse_frontmatter(content)
self.assertEqual(errors, [])
self.assertEqual(metadata["date_added"], "2026-03-15")
def test_generate_index_serializes_unquoted_yaml_dates(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
skills_dir = root / "skills"
skill_dir = skills_dir / "demo"
output_file = root / "skills_index.json"
skill_dir.mkdir(parents=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: demo\ndescription: ok\ndate_added: 2026-03-15\n---\nBody\n",
encoding="utf-8",
)
skills = generate_index.generate_index(str(skills_dir), str(output_file))
self.assertEqual(skills[0]["date_added"], "2026-03-15")
self.assertIn('"date_added": "2026-03-15"', output_file.read_text(encoding="utf-8"))
def test_generate_index_ignores_symlinked_skill_markdown(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)

View File

@@ -5,6 +5,7 @@ import sys
import io
import yaml
from collections.abc import Mapping
from datetime import date, datetime
from _project_paths import find_repo_root
@@ -38,6 +39,15 @@ WHEN_TO_USE_PATTERNS = [
def has_when_to_use_section(content):
return any(pattern.search(content) for pattern in WHEN_TO_USE_PATTERNS)
def normalize_yaml_value(value):
if isinstance(value, Mapping):
return {key: normalize_yaml_value(val) for key, val in value.items()}
if isinstance(value, list):
return [normalize_yaml_value(item) for item in value]
if isinstance(value, (date, datetime)):
return value.isoformat()
return value
def parse_frontmatter(content, rel_path=None):
"""
Parse frontmatter using PyYAML for robustness.
@@ -51,6 +61,7 @@ def parse_frontmatter(content, rel_path=None):
fm_errors = []
try:
metadata = yaml.safe_load(fm_text) or {}
metadata = normalize_yaml_value(metadata)
if not isinstance(metadata, Mapping):
return None, ["Frontmatter must be a YAML mapping/object."]