Files
claude-skills-reference/engineering/statistical-analyst/scripts/sample_size_calculator.py
Reza Rezvani 7c2564845a refactor(engineering): move statistical-analyst to engineering/, fix cross-refs
- Move from data-analysis/ to engineering/
- Fix 5 cross-references to use correct domain paths
- Fix Python 3.9 compat in sample_size_calculator.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 11:18:12 +02:00

263 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
from __future__ import annotations
"""
sample_size_calculator.py — Required sample size per variant for A/B experiments.
Supports proportion tests (conversion rates) and mean tests (continuous metrics).
All math uses Python stdlib only.
Usage:
python3 sample_size_calculator.py --test proportion \
--baseline 0.05 --mde 0.20 --alpha 0.05 --power 0.80
python3 sample_size_calculator.py --test mean \
--baseline-mean 42.3 --baseline-std 18.1 --mde 0.10 \
--alpha 0.05 --power 0.80
python3 sample_size_calculator.py --test proportion \
--baseline 0.05 --mde 0.20 --table
python3 sample_size_calculator.py --test proportion \
--baseline 0.05 --mde 0.20 --format json
"""
import argparse
import json
import math
import sys
def normal_cdf(z: float) -> float:
return 0.5 * math.erfc(-z / math.sqrt(2))
def normal_ppf(p: float) -> float:
"""Inverse normal CDF via bisection."""
lo, hi = -10.0, 10.0
for _ in range(100):
mid = (lo + hi) / 2
if normal_cdf(mid) < p:
lo = mid
else:
hi = mid
return (lo + hi) / 2
def sample_size_proportion(baseline: float, mde: float, alpha: float, power: float) -> int:
"""
Required n per variant for a two-proportion Z-test.
Uses the standard formula:
n = (z_α/2 + z_β)² × (p1(1p1) + p2(1p2)) / (p1 p2)²
Args:
baseline: Control conversion rate (e.g. 0.05 for 5%)
mde: Minimum detectable effect as relative change (e.g. 0.20 for +20% relative)
alpha: Significance level (e.g. 0.05)
power: Statistical power (e.g. 0.80)
"""
p1 = baseline
p2 = baseline * (1 + mde)
if not (0 < p1 < 1) or not (0 < p2 < 1):
raise ValueError(f"Rates must be between 0 and 1. Got baseline={p1}, treatment={p2:.4f}")
z_alpha = normal_ppf(1 - alpha / 2)
z_beta = normal_ppf(power)
numerator = (z_alpha + z_beta) ** 2 * (p1 * (1 - p1) + p2 * (1 - p2))
denominator = (p2 - p1) ** 2
return math.ceil(numerator / denominator)
def sample_size_mean(baseline_mean: float, baseline_std: float, mde: float, alpha: float, power: float) -> int:
"""
Required n per variant for a two-sample t-test.
Uses:
n = 2 × σ² × (z_α/2 + z_β)² / δ²
where δ = mde × baseline_mean (absolute effect).
Args:
baseline_mean: Control group mean
baseline_std: Control group standard deviation
mde: Minimum detectable effect as relative change (e.g. 0.10 for +10%)
alpha: Significance level
power: Statistical power
"""
delta = abs(mde * baseline_mean)
if delta == 0:
raise ValueError("MDE × baseline_mean = 0. Cannot size experiment with zero effect.")
z_alpha = normal_ppf(1 - alpha / 2)
z_beta = normal_ppf(power)
n = 2 * baseline_std ** 2 * (z_alpha + z_beta) ** 2 / delta ** 2
return math.ceil(n)
def duration_estimate(n_per_variant: int, daily_traffic: int | None, variants: int = 2) -> str:
if daily_traffic and daily_traffic > 0:
traffic_per_variant = daily_traffic / variants
days = math.ceil(n_per_variant / traffic_per_variant)
weeks = days / 7
return f"{days} days ({weeks:.1f} weeks) at {daily_traffic:,} daily users split {variants} ways"
return "Provide --daily-traffic to estimate duration"
def print_report(
test: str, n: int, baseline: float, mde: float, alpha: float, power: float,
daily_traffic: int | None, variants: int,
baseline_mean: float | None = None, baseline_std: float | None = None
):
total = n * variants
treatment_rate = baseline * (1 + mde) if test == "proportion" else None
absolute_mde = baseline * mde if test == "proportion" else (baseline_mean or 0) * mde
print("=" * 60)
print(" SAMPLE SIZE REPORT")
print("=" * 60)
if test == "proportion":
print(f" Baseline conversion rate: {baseline:.2%}")
print(f" Target conversion rate: {treatment_rate:.2%}")
print(f" MDE: {mde:+.1%} relative ({absolute_mde:+.4f} absolute)")
else:
print(f" Baseline mean: {baseline_mean} (std: {baseline_std})")
print(f" MDE: {mde:+.1%} relative (absolute: {absolute_mde:+.4f})")
print(f" Significance level (α): {alpha}")
print(f" Statistical power (1β): {power:.0%}")
print(f" Variants: {variants}")
print()
print(f" Required per variant: {n:>10,}")
print(f" Required total: {total:>10,}")
print()
print(f" Duration: {duration_estimate(n, daily_traffic, variants)}")
print()
# Risk interpretation
if n < 100:
print(" ⚠️ Very small sample — results may be sensitive to outliers.")
elif n > 1_000_000:
print(" ⚠️ Very large sample required — consider increasing MDE or accepting lower power.")
else:
print(" ✅ Sample size is achievable for most web/app products.")
print("=" * 60)
def print_table(test: str, baseline: float, mde: float, alpha: float,
baseline_mean: float | None, baseline_std: float | None):
"""Print tradeoff table across power levels and MDE values."""
powers = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]
mdes = [mde * 0.5, mde * 0.75, mde, mde * 1.5, mde * 2.0]
print("=" * 70)
print(f" SAMPLE SIZE TRADEOFF TABLE (α={alpha}, baseline={'proportion' if test == 'proportion' else 'mean'})")
print("=" * 70)
header = f" {'MDE':>8} | " + " | ".join(f"power={p:.0%}" for p in powers)
print(header)
print(" " + "-" * (len(header) - 2))
for m in mdes:
row = f" {m:>+7.1%} | "
cells = []
for p in powers:
try:
if test == "proportion":
n = sample_size_proportion(baseline, m, alpha, p)
else:
n = sample_size_mean(baseline_mean, baseline_std, m, alpha, p)
cells.append(f"{n:>9,}")
except ValueError:
cells.append(f"{'N/A':>9}")
row += " | ".join(cells)
print(row)
print("=" * 70)
print(" (Values = required n per variant)")
print()
def main():
parser = argparse.ArgumentParser(description="Calculate required sample size for A/B experiments.")
parser.add_argument("--test", choices=["proportion", "mean"], required=True,
help="Type of metric: proportion (conversion rate) or mean (continuous)")
parser.add_argument("--alpha", type=float, default=0.05, help="Significance level (default: 0.05)")
parser.add_argument("--power", type=float, default=0.80, help="Statistical power (default: 0.80)")
parser.add_argument("--mde", type=float, required=True,
help="Minimum detectable effect as relative change (e.g. 0.20 = +20%%)")
parser.add_argument("--variants", type=int, default=2, help="Number of variants including control (default: 2)")
parser.add_argument("--daily-traffic", type=int, help="Daily unique users (for duration estimate)")
parser.add_argument("--table", action="store_true", help="Print tradeoff table across power and MDE")
parser.add_argument("--format", choices=["text", "json"], default="text")
# Proportion-specific
parser.add_argument("--baseline", type=float, help="Baseline conversion rate (e.g. 0.05 for 5%%)")
# Mean-specific
parser.add_argument("--baseline-mean", type=float, help="Control group mean")
parser.add_argument("--baseline-std", type=float, help="Control group standard deviation")
args = parser.parse_args()
try:
if args.test == "proportion":
if args.baseline is None:
print("Error: --baseline is required for proportion test", file=sys.stderr)
sys.exit(1)
n = sample_size_proportion(args.baseline, args.mde, args.alpha, args.power)
else:
if args.baseline_mean is None or args.baseline_std is None:
print("Error: --baseline-mean and --baseline-std are required for mean test", file=sys.stderr)
sys.exit(1)
n = sample_size_mean(args.baseline_mean, args.baseline_std, args.mde, args.alpha, args.power)
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
if args.format == "json":
output = {
"test": args.test,
"n_per_variant": n,
"n_total": n * args.variants,
"alpha": args.alpha,
"power": args.power,
"mde": args.mde,
"variants": args.variants,
}
if args.test == "proportion":
output["baseline_rate"] = args.baseline
output["treatment_rate"] = round(args.baseline * (1 + args.mde), 6)
else:
output["baseline_mean"] = args.baseline_mean
output["baseline_std"] = args.baseline_std
if args.daily_traffic:
days = math.ceil(n / (args.daily_traffic / args.variants))
output["estimated_days"] = days
print(json.dumps(output, indent=2))
return
if args.table:
print_table(args.test, args.baseline if args.test == "proportion" else None,
args.mde, args.alpha, args.baseline_mean, args.baseline_std)
print_report(
args.test, n,
baseline=args.baseline or 0,
mde=args.mde,
alpha=args.alpha,
power=args.power,
daily_traffic=args.daily_traffic,
variants=args.variants,
baseline_mean=args.baseline_mean,
baseline_std=args.baseline_std,
)
if __name__ == "__main__":
main()