feat: evaluate stratified by syndrome

This commit is contained in:
2026-05-05 13:55:32 +02:00
parent 2d03ff0a42
commit 1e904713bd

View File

@@ -97,6 +97,25 @@ def auc_fast(preds):
ties+=(k-j)
return (conc+0.5*ties)/(len(pos)*len(neg))
def auc_fast_gal(gal_preds):
    """Return the stratified (within-group) AUC pooled over all groups.

    Only positive/negative pairs that belong to the same group are
    compared. The result is
    (sum_g concordant_g + 0.5 * sum_g tied_g) / sum_g (n_pos_g * n_neg_g),
    i.e. the per-group AUCs averaged with weights proportional to each
    group's number of positive/negative pairs.

    Args:
        gal_preds: Mapping from a group key to an iterable of records.
            Each record is indexed with "a" (label; 1 counts as positive,
            0 as negative, anything else is ignored) and "p" (score).

    Returns:
        The pooled concordance as a float, or 0.5 when the mapping is
        empty or no group contains both a positive and a negative.
    """
    if not gal_preds:
        return 0.5

    concordant = 0
    tied = 0
    total_pairs = 0

    for group in gal_preds.values():
        pos_scores = sorted(rec["p"] for rec in group if rec["a"] == 1)
        neg_scores = sorted(rec["p"] for rec in group if rec["a"] == 0)
        # A group missing one of the classes contributes no comparable pairs.
        if not (pos_scores and neg_scores):
            continue

        n_neg = len(neg_scores)
        total_pairs += len(pos_scores) * n_neg

        # Both lists are sorted, so `below` only ever moves forward while
        # walking through the positives (two-pointer sweep).
        below = 0
        for score in pos_scores:
            while below < n_neg and neg_scores[below] < score:
                below += 1
            # `below` negatives are strictly smaller than this positive.
            concordant += below
            # Count the run of negatives that tie exactly with this positive.
            tie_end = below
            while tie_end < n_neg and neg_scores[tie_end] == score:
                tie_end += 1
            tied += tie_end - below

    if total_pairs == 0:
        return 0.5
    return (concordant + 0.5 * tied) / total_pairs
def compute_centroid(pts):
s=defaultdict(float);c=defaultdict(int)
for p in pts:
@@ -278,7 +297,9 @@ def main():
print(f" 1: Calibration 2: Fair Comparison 3: Summary")
print(f"{'='*76}\n")
all_pts=load_all_icu();assign_galaxies(all_pts)
all_pts=load_all_icu();
all_pts = [p for p in all_pts if p.get("saps_prob") is not None]
assign_galaxies(all_pts)
by_gal=defaultdict(list)
for p in all_pts:
if p["galaxy"]: by_gal[p["galaxy"]].append(p)
@@ -293,10 +314,33 @@ def main():
results = {}
# ── Full pop LOO ───────────────────────────────────────────────
print(f"\n Running full-population LOO...")
print(f"\n Running s-population LOO...")
p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
a_td = auc_fast(p_full)
print(f" * TD full pop: AUC {a_td:.4f} n={len(p_full):,}")
print(f" * TD s-pop: AUC {a_td:.4f} n={len(p_full):,}")
# SAPS score for comparison
saps_preds_all = [{"a": p["died"], "p": p["saps_prob"]}
for p in all_pts if p.get("saps_prob") is not None]
a_td_s = auc_fast(saps_preds_all)
print(f" * SAPS-II s-pop: AUC {a_td_s:.4f} n={len(saps_preds_all):,}")
saps_preds_by_gal = {
gal: [{"a": p["died"], "p": p["saps_prob"]}
for p in gal_pts if p.get("saps_prob") is not None]
for gal, gal_pts in by_gal.items()
}
a_td_sg = auc_fast_gal(saps_preds_by_gal)
ns = str([len(gal_pts) for (gal, gal_pts) in saps_preds_by_gal.items()])
print(f" * SAPS-II by gal: AUC {a_td_sg:.4f} n={ns}")
td_preds_by_gal = defaultdict(list)
for pred in p_full:
if pred.get("g"):
td_preds_by_gal[pred["g"]].append({"a": pred["a"], "p": pred["p"]})
a_td_g = auc_fast_gal(dict(td_preds_by_gal))
ns = str([len(gal_pts) for (gal, gal_pts) in td_preds_by_gal.items()])
print(f" * TD by gal: AUC {a_td_g:.4f} n={ns}")
# ══════════════════════════════════════════════════════════════
# 1. CALIBRATION (10 Dezile)