From 1e904713bde78cbf869e6400fc6bd74ccc261749 Mon Sep 17 00:00:00 2001
From: David Madl
Date: Tue, 5 May 2026 13:55:32 +0200
Subject: [PATCH] feat: evaluate stratified by syndrome

---
 paper2_festung_teil3.py | 67 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 3 deletions(-)

diff --git a/paper2_festung_teil3.py b/paper2_festung_teil3.py
index d32b8b1..b5e4bf9 100644
--- a/paper2_festung_teil3.py
+++ b/paper2_festung_teil3.py
@@ -97,6 +97,39 @@ def auc_fast(preds):
         ties+=(k-j)
     return (conc+0.5*ties)/(len(pos)*len(neg))
 
+def auc_fast_gal(gal_preds):
+    """Pooled within-stratum AUC across galaxies.
+
+    Positives and negatives are compared only within the same galaxy:
+    (Σ_g conc_g + 0.5·Σ_g ties_g) / (Σ_g n_pos_g·n_neg_g), which equals a
+    pair-weighted average of the per-galaxy AUCs.
+
+    gal_preds maps galaxy -> list of {"a": label (1 or 0), "p": score}.
+    Returns 0.5 if no galaxy contains both labels.
+    """
+    if not gal_preds:
+        return 0.5
+    conc = ties = pairs = 0
+    for preds in gal_preds.values():
+        pos = sorted(p["p"] for p in preds if p["a"] == 1)
+        neg = sorted(p["p"] for p in preds if p["a"] == 0)
+        if not pos or not neg:
+            continue  # single-label galaxy: no comparable pairs
+        pairs += len(pos) * len(neg)
+        j = 0
+        for pv in pos:
+            # pos is ascending, so j never has to move backwards.
+            while j < len(neg) and neg[j] < pv:
+                j += 1
+            conc += j  # negatives scored strictly below pv
+            k = j
+            while k < len(neg) and neg[k] == pv:
+                k += 1
+            ties += k - j  # negatives tied with pv count half
+    if pairs == 0:
+        return 0.5
+    return (conc + 0.5 * ties) / pairs
+
 def compute_centroid(pts):
     s=defaultdict(float);c=defaultdict(int)
     for p in pts:
@@ -278,7 +311,10 @@ def main():
     print(f" 1: Calibration 2: Fair Comparison 3: Summary")
     print(f"{'='*76}\n")
 
-    all_pts=load_all_icu();assign_galaxies(all_pts)
+    # Restrict to patients with a SAPS probability; the TD and SAPS-II AUCs
+    # below are both computed on this set.
+    all_pts = [p for p in load_all_icu() if p.get("saps_prob") is not None]
+    assign_galaxies(all_pts)
     by_gal=defaultdict(list)
     for p in all_pts:
         if p["galaxy"]: by_gal[p["galaxy"]].append(p)
@@ -293,10 +329,35 @@ def main():
     results = {}
 
     # ── Full pop LOO ───────────────────────────────────────────────
-    print(f"\n Running full-population LOO...")
+    print(f"\n Running s-population LOO...")
     p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
     a_td = auc_fast(p_full)
-    print(f" * TD full pop: AUC {a_td:.4f} n={len(p_full):,}")
+    print(f" * TD s-pop: AUC {a_td:.4f} n={len(p_full):,}")
+
+    # SAPS-II on the same points; all_pts was already filtered to rows that
+    # have saps_prob, so no further None check is needed here.
+    saps_preds_all = [{"a": p["died"], "p": p["saps_prob"]} for p in all_pts]
+    a_saps = auc_fast(saps_preds_all)
+    print(f" * SAPS-II s-pop: AUC {a_saps:.4f} n={len(saps_preds_all):,}")
+
+    # Same scores, but concordance is pooled within each galaxy.
+    saps_preds_by_gal = {
+        gal: [{"a": p["died"], "p": p["saps_prob"]} for p in gal_pts]
+        for gal, gal_pts in by_gal.items()
+    }
+    a_saps_gal = auc_fast_gal(saps_preds_by_gal)
+    ns = str([len(preds) for preds in saps_preds_by_gal.values()])
+    print(f" * SAPS-II by gal: AUC {a_saps_gal:.4f} n={ns}")
+
+    # TD predictions regrouped by their "g" field; predictions without one
+    # are left out of the stratified AUC.
+    td_preds_by_gal = defaultdict(list)
+    for pred in p_full:
+        if pred.get("g"):
+            td_preds_by_gal[pred["g"]].append({"a": pred["a"], "p": pred["p"]})
+    a_td_gal = auc_fast_gal(td_preds_by_gal)
+    ns = str([len(preds) for preds in td_preds_by_gal.values()])
+    print(f" * TD by gal: AUC {a_td_gal:.4f} n={ns}")
 
     # ══════════════════════════════════════════════════════════════
     # 1. CALIBRATION (10 Dezile)