#!/usr/bin/env python3
"""A/B test prompts against structured test cases.

Supports:
- --input JSON payload or stdin JSON payload
- --prompt-a/--prompt-b or file variants
- --cases-file for test suite JSON
- optional --runner-cmd with {prompt} and {input} placeholders

If runner command is omitted, script performs static prompt quality scoring only.
"""
import argparse
import json
import re
import shlex
import subprocess
import sys
from dataclasses import dataclass, asdict
from pathlib import Path
from statistics import mean
from typing import Any, Dict, List, Optional
class CLIError(Exception):
    """Expected, user-facing CLI failure (bad flags, unreadable files, runner errors)."""
@dataclass
class CaseScore:
    """Per-case scoring result for one prompt variant (A or B)."""

    case_id: str  # identifier taken from the test case's "id" field ("case" when absent)
    prompt_variant: str  # "A" or "B"
    score: float  # heuristic quality score, clamped to [0, 100]
    matched_expected: int  # expected_contains items found in the output
    missed_expected: int  # expected_contains items absent from the output
    forbidden_hits: int  # forbidden_contains items found in the output
    regex_matches: int  # expected_regex patterns that matched the output
    output_length: int  # character length of the raw output
def parse_args() -> argparse.Namespace:
    """Define and evaluate the command-line interface for the A/B tester."""
    cli = argparse.ArgumentParser(description="A/B test prompts against test cases.")
    # Simple single-value flags share one registration loop.
    simple_flags = [
        ("--input", "JSON input file for full payload."),
        ("--prompt-a", "Prompt A text."),
        ("--prompt-b", "Prompt B text."),
        ("--prompt-a-file", "Path to prompt A file."),
        ("--prompt-b-file", "Path to prompt B file."),
        ("--cases-file", "Path to JSON test cases array."),
    ]
    for flag, help_text in simple_flags:
        cli.add_argument(flag, help=help_text)
    cli.add_argument(
        "--runner-cmd",
        help="External command template, e.g. 'llm --prompt {prompt} --input {input}'.",
    )
    cli.add_argument("--format", choices=["text", "json"], default="text", help="Output format.")
    return cli.parse_args()
def read_text_file(path: Optional[str]) -> Optional[str]:
    """Return the UTF-8 contents of *path*, or None when no path was given.

    Raises:
        CLIError: when the file cannot be read or decoded.
    """
    if not path:
        return None
    try:
        return Path(path).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError, ValueError) as exc:
        # Narrowed from a bare ``except Exception`` so genuine programming
        # errors surface instead of being masked as CLI errors. ValueError
        # covers platform quirks such as embedded NUL bytes in the path.
        raise CLIError(f"Failed reading file {path}: {exc}") from exc
def load_payload(args: argparse.Namespace) -> Dict[str, Any]:
    """Assemble the test payload from --input, piped stdin JSON, or individual flags.

    Precedence: an explicit --input file wins; otherwise non-empty piped stdin
    JSON; otherwise the payload is built piecewise from the prompt/cases/runner
    flags (missing pieces are simply omitted).

    Raises:
        CLIError: when a referenced file is unreadable or contains invalid JSON.
    """
    if args.input:
        try:
            return json.loads(Path(args.input).read_text(encoding="utf-8"))
        except (OSError, ValueError, UnicodeDecodeError) as exc:
            # json.JSONDecodeError is a ValueError subclass; OSError covers I/O.
            # Narrowed from ``except Exception`` so unrelated bugs are not masked.
            raise CLIError(f"Failed reading --input payload: {exc}") from exc

    # A non-interactive stdin may carry the whole payload as JSON.
    if not sys.stdin.isatty():
        raw = sys.stdin.read().strip()
        if raw:
            try:
                return json.loads(raw)
            except json.JSONDecodeError as exc:
                raise CLIError(f"Invalid JSON from stdin: {exc}") from exc

    payload: Dict[str, Any] = {}

    prompt_a = args.prompt_a or read_text_file(args.prompt_a_file)
    prompt_b = args.prompt_b or read_text_file(args.prompt_b_file)
    if prompt_a:
        payload["prompt_a"] = prompt_a
    if prompt_b:
        payload["prompt_b"] = prompt_b

    if args.cases_file:
        try:
            payload["cases"] = json.loads(Path(args.cases_file).read_text(encoding="utf-8"))
        except (OSError, ValueError, UnicodeDecodeError) as exc:
            raise CLIError(f"Failed reading --cases-file: {exc}") from exc

    if args.runner_cmd:
        payload["runner_cmd"] = args.runner_cmd

    return payload
def run_runner(runner_cmd: str, prompt: str, case_input: str) -> str:
    """Execute the external runner template and return its stripped stdout.

    The {prompt} and {input} placeholders are shell-quoted before substitution
    so a prompt containing spaces or quotes arrives as a single argv item.
    Previously the raw text was interpolated and then shlex-split, which split
    multi-word prompts into several arguments and allowed argument injection.

    Raises:
        CLIError: when the command cannot be launched or exits non-zero.
    """
    cmd = runner_cmd.format(prompt=shlex.quote(prompt), input=shlex.quote(case_input))
    parts = shlex.split(cmd)
    try:
        proc = subprocess.run(parts, text=True, capture_output=True, check=True)
    except subprocess.CalledProcessError as exc:
        raise CLIError(f"Runner command failed: {exc.stderr.strip()}") from exc
    except OSError as exc:
        # e.g. the runner binary does not exist or is not executable.
        raise CLIError(f"Runner command could not be launched: {exc}") from exc
    return proc.stdout.strip()
def static_output(prompt: str, case_input: str) -> str:
    """Render the prompt template by substituting the {{input}} placeholder."""
    return prompt.replace("{{input}}", case_input)
def score_output(case: Dict[str, Any], output: str, prompt_variant: str) -> CaseScore:
    """Grade one output against a single test case's expectations.

    The score starts at 100, loses 15 per missed expected substring and 25 per
    forbidden substring hit, gains 8 per matching regex, takes a 10-point
    penalty for very long or near-empty output, and is clamped to [0, 100].
    """
    def _clean(key: str) -> List[str]:
        # Stringify entries and drop empties, matching the case schema loosely.
        return [str(x) for x in case.get(key, []) if str(x)]

    expected = _clean("expected_contains")
    forbidden = _clean("forbidden_contains")
    regexes = _clean("expected_regex")

    # Substring checks are case-insensitive; lower the haystack once.
    haystack = output.lower()
    matched_expected = sum(item.lower() in haystack for item in expected)
    missed_expected = len(expected) - matched_expected
    forbidden_hits = sum(item.lower() in haystack for item in forbidden)

    regex_matches = 0
    for pattern in regexes:
        try:
            regex_matches += bool(re.search(pattern, output, flags=re.MULTILINE))
        except re.error:
            # Invalid user-supplied patterns are deliberately ignored.
            pass

    score = 100.0 - missed_expected * 15 - forbidden_hits * 25 + regex_matches * 8

    # Heuristic penalty for unbounded verbosity or near-empty output.
    if len(output) > 4000:
        score -= 10
    if len(output.strip()) < 10:
        score -= 10

    score = min(100.0, max(0.0, score))

    return CaseScore(
        case_id=str(case.get("id", "case")),
        prompt_variant=prompt_variant,
        score=score,
        matched_expected=matched_expected,
        missed_expected=missed_expected,
        forbidden_hits=forbidden_hits,
        regex_matches=regex_matches,
        output_length=len(output),
    )
def aggregate(scores: List[CaseScore]) -> Dict[str, Any]:
    """Summarize one variant's case scores as average/min/max/count."""
    if not scores:
        # Empty suite: neutral zeros rather than a statistics error.
        return {"average": 0.0, "min": 0.0, "max": 0.0, "cases": 0}
    vals = [item.score for item in scores]
    avg, lo, hi = mean(vals), min(vals), max(vals)
    return {
        "average": round(avg, 2),
        "min": round(lo, 2),
        "max": round(hi, 2),
        "cases": len(vals),
    }
def main() -> int:
    """CLI entry point: load the payload, score both prompts, report the winner.

    Returns 0 on success; raises CLIError for user-facing problems (missing
    prompts, empty case list), which the top-level guard turns into exit 2.
    """
    args = parse_args()
    payload = load_payload(args)

    prompt_a = str(payload.get("prompt_a", "")).strip()
    prompt_b = str(payload.get("prompt_b", "")).strip()
    cases = payload.get("cases", [])
    runner_cmd = payload.get("runner_cmd")

    if not prompt_a or not prompt_b:
        raise CLIError("Both prompt_a and prompt_b are required (flags or JSON payload).")
    if not isinstance(cases, list) or not cases:
        raise CLIError("cases must be a non-empty array.")

    variants = {"A": prompt_a, "B": prompt_b}
    per_variant = {"A": [], "B": []}

    for case in cases:
        if not isinstance(case, dict):
            # Skip malformed (non-object) entries rather than failing the run.
            continue
        case_input = str(case.get("input", "")).strip()
        for label, prompt in variants.items():
            raw = run_runner(runner_cmd, prompt, case_input) if runner_cmd else static_output(prompt, case_input)
            per_variant[label].append(score_output(case, raw, label))

    scores_a, scores_b = per_variant["A"], per_variant["B"]
    agg_a, agg_b = aggregate(scores_a), aggregate(scores_b)
    # Ties go to A by design (>=).
    winner = "A" if agg_a["average"] >= agg_b["average"] else "B"

    result = {
        "summary": {
            "winner": winner,
            "prompt_a": agg_a,
            "prompt_b": agg_b,
            "mode": "runner" if runner_cmd else "static",
        },
        "case_scores": {
            "prompt_a": [asdict(item) for item in scores_a],
            "prompt_b": [asdict(item) for item in scores_b],
        },
    }

    if args.format == "json":
        print(json.dumps(result, indent=2))
        return 0

    lines = [
        "Prompt A/B test result",
        f"- mode: {result['summary']['mode']}",
        f"- winner: {winner}",
        f"- prompt A avg: {agg_a['average']}",
        f"- prompt B avg: {agg_b['average']}",
        "Case details:",
    ]
    for entry in scores_a + scores_b:
        lines.append(
            f"- case={entry.case_id} variant={entry.prompt_variant} score={entry.score} "
            f"expected+={entry.matched_expected} forbidden={entry.forbidden_hits} regex={entry.regex_matches}"
        )
    print("\n".join(lines))
    return 0
if __name__ == "__main__":
    # Expected CLI failures print a short error and exit 2; success exits with
    # main()'s return code (0).
    try:
        exit_code = main()
    except CLIError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        raise SystemExit(2)
    raise SystemExit(exit_code)