- Move from data-analysis/ to engineering/ - Fix 5 cross-references to use correct domain paths - Fix Python 3.9 compat in sample_size_calculator.py Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
263 lines
9.5 KiB
Python
263 lines
9.5 KiB
Python
#!/usr/bin/env python3
r"""
sample_size_calculator.py — Required sample size per variant for A/B experiments.

Supports proportion tests (conversion rates) and mean tests (continuous metrics).
All math uses Python stdlib only.

Usage:
    python3 sample_size_calculator.py --test proportion \
        --baseline 0.05 --mde 0.20 --alpha 0.05 --power 0.80

    python3 sample_size_calculator.py --test mean \
        --baseline-mean 42.3 --baseline-std 18.1 --mde 0.10 \
        --alpha 0.05 --power 0.80

    python3 sample_size_calculator.py --test proportion \
        --baseline 0.05 --mde 0.20 --table

    python3 sample_size_calculator.py --test proportion \
        --baseline 0.05 --mde 0.20 --format json
"""
# The docstring must come first: a string placed after any statement (even a
# __future__ import) is not the module docstring, and __doc__ would be None.
# A raw string keeps the shell line-continuation backslashes in the usage text.
# The __future__ import keeps `X | None` annotations lazy so the script runs on
# Python 3.9.
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import math
|
||
import sys
|
||
|
||
|
||
def normal_cdf(z: float) -> float:
    """Return P(Z <= z) for a standard normal Z, via the complementary error function."""
    # erfc keeps precision in the lower tail, where 1 + erf(...) would cancel.
    scaled = -(z / math.sqrt(2.0))
    return math.erfc(scaled) / 2.0
|
||
|
||
|
||
def normal_ppf(p: float) -> float:
    """
    Inverse standard-normal CDF (quantile function) via bisection.

    Bisects normal_cdf on the bracket [-10, 10] for 100 iterations, which
    shrinks the bracket far below double precision.

    Args:
        p: Probability, strictly between 0 and 1.

    Returns:
        z such that normal_cdf(z) is approximately p. Probabilities below
        normal_cdf(-10) clamp to about -10.

    Raises:
        ValueError: if p is not strictly between 0 and 1 (including NaN). The
            quantile is infinite or undefined there, and bisection would
            otherwise silently return a value near the bracket edge.
    """
    # `not 0 < p < 1` is also True for NaN, so NaN is rejected here too.
    if not 0 < p < 1:
        raise ValueError(f"Probability must be strictly between 0 and 1. Got {p}")

    lo, hi = -10.0, 10.0
    for _ in range(100):
        mid = (lo + hi) / 2
        if normal_cdf(mid) < p:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2
|
||
|
||
|
||
def sample_size_proportion(baseline: float, mde: float, alpha: float, power: float) -> int:
    """
    Required n per variant for a two-proportion Z-test.

    Uses the standard (unpooled) formula:
        n = (z_α/2 + z_β)² × (p1(1−p1) + p2(1−p2)) / (p1 − p2)²
    with p1 = baseline and p2 = baseline × (1 + mde).

    Args:
        baseline: Control conversion rate (e.g. 0.05 for 5%)
        mde: Minimum detectable effect as relative change (e.g. 0.20 for +20% relative).
            May be negative (detecting a drop) but must not be 0.
        alpha: Two-sided significance level, strictly between 0 and 1 (e.g. 0.05)
        power: Statistical power, strictly between 0 and 1 (e.g. 0.80)

    Returns:
        Required sample size per variant, rounded up.

    Raises:
        ValueError: if alpha or power lie outside (0, 1), if either rate lies
            outside (0, 1), or if the effect is zero (p2 == p1).
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be between 0 and 1 (exclusive). Got {alpha}")
    if not 0 < power < 1:
        raise ValueError(f"power must be between 0 and 1 (exclusive). Got {power}")

    p1 = baseline
    p2 = baseline * (1 + mde)

    if not (0 < p1 < 1) or not (0 < p2 < 1):
        raise ValueError(f"Rates must be between 0 and 1. Got baseline={p1}, treatment={p2:.4f}")
    if p2 == p1:
        # Zero effect: the denominator below would be 0. Report it as a
        # ValueError (like sample_size_mean) instead of a ZeroDivisionError
        # that callers do not catch.
        raise ValueError("MDE must be non-zero. Cannot size experiment with zero effect.")

    z_alpha = normal_ppf(1 - alpha / 2)
    z_beta = normal_ppf(power)

    numerator = (z_alpha + z_beta) ** 2 * (p1 * (1 - p1) + p2 * (1 - p2))
    denominator = (p2 - p1) ** 2

    return math.ceil(numerator / denominator)
|
||
|
||
|
||
def sample_size_mean(baseline_mean: float, baseline_std: float, mde: float, alpha: float, power: float) -> int:
    """
    Required n per variant for a two-sample comparison of means.

    Uses the normal (z) approximation to the two-sample t-test:
        n = 2 × σ² × (z_α/2 + z_β)² / δ²

    where δ = |mde × baseline_mean| (absolute effect) and σ = baseline_std,
    taken as the standard deviation of both groups. Because z rather than t
    quantiles are used, n is slightly underestimated for very small samples.

    Args:
        baseline_mean: Control group mean
        baseline_std: Control group standard deviation (must be > 0)
        mde: Minimum detectable effect as relative change (e.g. 0.10 for +10%)
        alpha: Two-sided significance level, strictly between 0 and 1
        power: Statistical power, strictly between 0 and 1

    Returns:
        Required sample size per variant, rounded up.

    Raises:
        ValueError: if alpha or power lie outside (0, 1), if the absolute
            effect is 0, or if baseline_std is not positive.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be between 0 and 1 (exclusive). Got {alpha}")
    if not 0 < power < 1:
        raise ValueError(f"power must be between 0 and 1 (exclusive). Got {power}")

    delta = abs(mde * baseline_mean)
    if delta == 0:
        raise ValueError("MDE × baseline_mean = 0. Cannot size experiment with zero effect.")
    # A zero std would yield n == 0, and a negative one would be squared away
    # silently. Both are invalid inputs rather than sizable experiments.
    if not baseline_std > 0:
        raise ValueError(f"baseline_std must be positive. Got {baseline_std}")

    z_alpha = normal_ppf(1 - alpha / 2)
    z_beta = normal_ppf(power)

    n = 2 * baseline_std ** 2 * (z_alpha + z_beta) ** 2 / delta ** 2
    return math.ceil(n)
|
||
|
||
|
||
def duration_estimate(n_per_variant: int, daily_traffic: int | None, variants: int = 2) -> str:
    """
    Describe how long the experiment must run to collect the required sample.

    Daily traffic is split evenly across ``variants`` and the day count is
    rounded up. When ``daily_traffic`` is missing or not positive, a hint to
    pass --daily-traffic is returned instead.
    """
    if not daily_traffic or not (daily_traffic > 0):
        return "Provide --daily-traffic to estimate duration"

    users_per_arm = daily_traffic / variants
    day_count = math.ceil(n_per_variant / users_per_arm)
    return (
        f"{day_count} days ({day_count / 7:.1f} weeks) "
        f"at {daily_traffic:,} daily users split {variants} ways"
    )
|
||
|
||
|
||
def print_report(
    test: str, n: int, baseline: float, mde: float, alpha: float, power: float,
    daily_traffic: int | None, variants: int,
    baseline_mean: float | None = None, baseline_std: float | None = None
):
    """
    Print a human-readable sample-size report to stdout.

    Args:
        test: "proportion" or "mean"; selects which effect-size lines are shown.
        n: Required sample size per variant.
        baseline: Control conversion rate (only read for proportion tests).
        mde: Relative minimum detectable effect.
        alpha: Significance level.
        power: Statistical power.
        daily_traffic: Daily users, forwarded to duration_estimate (may be None).
        variants: Number of variants, including control.
        baseline_mean: Control mean (shown for mean tests).
        baseline_std: Control standard deviation (shown for mean tests).
    """
    is_proportion = test == "proportion"
    rule = "=" * 60

    # Derived figures are computed up front, before anything is printed.
    grand_total = n * variants
    target_rate = baseline * (1 + mde) if is_proportion else None
    mde_abs = (baseline if is_proportion else (baseline_mean or 0)) * mde

    print(rule)
    print(" SAMPLE SIZE REPORT")
    print(rule)

    if not is_proportion:
        print(f" Baseline mean: {baseline_mean} (std: {baseline_std})")
        print(f" MDE: {mde:+.1%} relative (absolute: {mde_abs:+.4f})")
    else:
        print(f" Baseline conversion rate: {baseline:.2%}")
        print(f" Target conversion rate: {target_rate:.2%}")
        print(f" MDE: {mde:+.1%} relative ({mde_abs:+.4f} absolute)")

    print(f" Significance level (α): {alpha}")
    print(f" Statistical power (1−β): {power:.0%}")
    print(f" Variants: {variants}")
    print()
    print(f" Required per variant: {n:>10,}")
    print(f" Required total: {grand_total:>10,}")
    print()
    timeline = duration_estimate(n, daily_traffic, variants)
    print(f" Duration: {timeline}")
    print()

    # Flag sample sizes at either extreme; anything in between is routine.
    if n < 100:
        verdict = " ⚠️ Very small sample — results may be sensitive to outliers."
    elif n > 1_000_000:
        verdict = " ⚠️ Very large sample required — consider increasing MDE or accepting lower power."
    else:
        verdict = " ✅ Sample size is achievable for most web/app products."
    print(verdict)

    print(rule)
|
||
|
||
|
||
def print_table(test: str, baseline: float, mde: float, alpha: float,
                baseline_mean: float | None, baseline_std: float | None):
    """
    Print a tradeoff table of required n per variant across power levels and MDE values.

    Rows scale the requested MDE by 0.5×, 0.75×, 1×, 1.5× and 2×. Columns are
    power levels from 70% to 95%. Cells whose inputs are invalid (e.g. a
    treatment rate outside (0, 1), or a zero effect) show "N/A".

    Args:
        test: "proportion" or "mean".
        baseline: Control conversion rate (proportion tests only; None otherwise).
        mde: Relative minimum detectable effect that the rows are scaled from.
        alpha: Significance level used for every cell.
        baseline_mean: Control mean (mean tests only).
        baseline_std: Control standard deviation (mean tests only).
    """
    powers = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]
    mdes = [mde * 0.5, mde * 0.75, mde, mde * 1.5, mde * 2.0]

    # Show the actual baseline the table was computed for. The previous header
    # printed the literal text "baseline=proportion" / "baseline=mean".
    if test == "proportion":
        baseline_label = f"baseline={baseline:.2%}"
    else:
        baseline_label = f"mean={baseline_mean}, std={baseline_std}"

    print("=" * 70)
    print(f" SAMPLE SIZE TRADEOFF TABLE ({test}, α={alpha}, {baseline_label})")
    print("=" * 70)
    # The MDE column is 8 characters wide in both the header and the rows so
    # the "|" separators line up.
    header = f" {'MDE':>8} | " + " | ".join(f"power={p:.0%}" for p in powers)
    print(header)
    print(" " + "-" * (len(header) - 2))

    for m in mdes:
        cells = []
        for p in powers:
            try:
                if test == "proportion":
                    n = sample_size_proportion(baseline, m, alpha, p)
                else:
                    n = sample_size_mean(baseline_mean, baseline_std, m, alpha, p)
            except (ValueError, ZeroDivisionError):
                # Invalid combination (rate out of range, zero effect, ...).
                cells.append(f"{'N/A':>9}")
            else:
                cells.append(f"{n:>9,}")
        print(f" {m:>+8.1%} | " + " | ".join(cells))

    print("=" * 70)
    print(" (Values = required n per variant)")
    print()
|
||
|
||
|
||
def _validate_common_args(args: argparse.Namespace) -> str | None:
    """Return an error message if a shared CLI option is out of range, else None."""
    if not 0 < args.alpha < 1:
        return f"--alpha must be between 0 and 1 (exclusive). Got {args.alpha}"
    if not 0 < args.power < 1:
        # A common mistake is passing a percentage (80) instead of a fraction (0.80).
        return f"--power must be between 0 and 1 (exclusive), e.g. 0.80. Got {args.power}"
    if args.mde == 0:
        return "--mde must be non-zero. Cannot size experiment with zero effect."
    if args.variants < 2:
        return f"--variants must be at least 2 (control plus one treatment). Got {args.variants}"
    if args.daily_traffic is not None and args.daily_traffic <= 0:
        return f"--daily-traffic must be positive. Got {args.daily_traffic}"
    return None


def main():
    """
    Command-line entry point.

    Parses arguments, validates them, computes the required sample size per
    variant, and prints either JSON (--format json) or a text report, optionally
    preceded by a power/MDE tradeoff table (--table).

    Exits with status 1 and a message on stderr when required arguments are
    missing, values are out of range, or the calculation rejects its inputs.
    """
    parser = argparse.ArgumentParser(description="Calculate required sample size for A/B experiments.")
    parser.add_argument("--test", choices=["proportion", "mean"], required=True,
                        help="Type of metric: proportion (conversion rate) or mean (continuous)")
    parser.add_argument("--alpha", type=float, default=0.05, help="Significance level (default: 0.05)")
    parser.add_argument("--power", type=float, default=0.80, help="Statistical power (default: 0.80)")
    parser.add_argument("--mde", type=float, required=True,
                        help="Minimum detectable effect as relative change (e.g. 0.20 = +20%%)")
    parser.add_argument("--variants", type=int, default=2, help="Number of variants including control (default: 2)")
    parser.add_argument("--daily-traffic", type=int, help="Daily unique users (for duration estimate)")
    parser.add_argument("--table", action="store_true", help="Print tradeoff table across power and MDE")
    parser.add_argument("--format", choices=["text", "json"], default="text")

    # Proportion-specific
    parser.add_argument("--baseline", type=float, help="Baseline conversion rate (e.g. 0.05 for 5%%)")

    # Mean-specific
    parser.add_argument("--baseline-mean", type=float, help="Control group mean")
    parser.add_argument("--baseline-std", type=float, help="Control group standard deviation")

    args = parser.parse_args()

    # Reject out-of-range values up front. Otherwise they yield silent nonsense
    # (e.g. --power 80), a ZeroDivisionError (--mde 0, --variants 0) or
    # negative durations (--daily-traffic < 0).
    problem = _validate_common_args(args)
    if problem is not None:
        print(f"Error: {problem}", file=sys.stderr)
        sys.exit(1)

    try:
        if args.test == "proportion":
            if args.baseline is None:
                print("Error: --baseline is required for proportion test", file=sys.stderr)
                sys.exit(1)
            n = sample_size_proportion(args.baseline, args.mde, args.alpha, args.power)
        else:
            if args.baseline_mean is None or args.baseline_std is None:
                print("Error: --baseline-mean and --baseline-std are required for mean test", file=sys.stderr)
                sys.exit(1)
            if args.baseline_std <= 0:
                print("Error: --baseline-std must be positive", file=sys.stderr)
                sys.exit(1)
            n = sample_size_mean(args.baseline_mean, args.baseline_std, args.mde, args.alpha, args.power)
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    if args.format == "json":
        output = {
            "test": args.test,
            "n_per_variant": n,
            "n_total": n * args.variants,
            "alpha": args.alpha,
            "power": args.power,
            "mde": args.mde,
            "variants": args.variants,
        }
        if args.test == "proportion":
            output["baseline_rate"] = args.baseline
            output["treatment_rate"] = round(args.baseline * (1 + args.mde), 6)
        else:
            output["baseline_mean"] = args.baseline_mean
            output["baseline_std"] = args.baseline_std
        if args.daily_traffic:
            # Same even split across variants as the text report's duration estimate.
            days = math.ceil(n / (args.daily_traffic / args.variants))
            output["estimated_days"] = days
        print(json.dumps(output, indent=2))
        return

    if args.table:
        print_table(args.test, args.baseline if args.test == "proportion" else None,
                    args.mde, args.alpha, args.baseline_mean, args.baseline_std)

    print_report(
        args.test, n,
        baseline=args.baseline or 0,
        mde=args.mde,
        alpha=args.alpha,
        power=args.power,
        daily_traffic=args.daily_traffic,
        variants=args.variants,
        baseline_mean=args.baseline_mean,
        baseline_std=args.baseline_std,
    )


if __name__ == "__main__":
    main()
|