feat: evaluate stratified by syndrome
This commit is contained in:
@@ -97,6 +97,25 @@ def auc_fast(preds):
|
|||||||
ties+=(k-j)
|
ties+=(k-j)
|
||||||
return (conc+0.5*ties)/(len(pos)*len(neg))
|
return (conc+0.5*ties)/(len(pos)*len(neg))
|
||||||
|
|
||||||
|
def auc_fast_gal(gal_preds):
|
||||||
|
"""Pooled within-stratum concordance (Σ_g conc_g) / (Σ_g n_pos_g·n_neg_g).
|
||||||
|
Equivalent to a pair-weighted average of per-galaxy AUCs."""
|
||||||
|
if not gal_preds: return 0.5
|
||||||
|
conc = 0; ties = 0; pairs = 0
|
||||||
|
for _, preds in gal_preds.items():
|
||||||
|
pos = sorted(p["p"] for p in preds if p["a"] == 1)
|
||||||
|
neg = sorted(p["p"] for p in preds if p["a"] == 0)
|
||||||
|
if not pos or not neg: continue
|
||||||
|
pairs += len(pos) * len(neg)
|
||||||
|
j = 0
|
||||||
|
for pv in pos:
|
||||||
|
while j < len(neg) and neg[j] < pv: j += 1
|
||||||
|
conc += j; k = j
|
||||||
|
while k < len(neg) and neg[k] == pv: k += 1
|
||||||
|
ties += (k - j)
|
||||||
|
if pairs == 0: return 0.5
|
||||||
|
return (conc + 0.5 * ties) / pairs
|
||||||
|
|
||||||
def compute_centroid(pts):
|
def compute_centroid(pts):
|
||||||
s=defaultdict(float);c=defaultdict(int)
|
s=defaultdict(float);c=defaultdict(int)
|
||||||
for p in pts:
|
for p in pts:
|
||||||
@@ -278,7 +297,9 @@ def main():
|
|||||||
print(f" 1: Calibration 2: Fair Comparison 3: Summary")
|
print(f" 1: Calibration 2: Fair Comparison 3: Summary")
|
||||||
print(f"{'='*76}\n")
|
print(f"{'='*76}\n")
|
||||||
|
|
||||||
all_pts=load_all_icu();assign_galaxies(all_pts)
|
all_pts=load_all_icu();
|
||||||
|
all_pts = [p for p in all_pts if p.get("saps_prob") is not None]
|
||||||
|
assign_galaxies(all_pts)
|
||||||
by_gal=defaultdict(list)
|
by_gal=defaultdict(list)
|
||||||
for p in all_pts:
|
for p in all_pts:
|
||||||
if p["galaxy"]: by_gal[p["galaxy"]].append(p)
|
if p["galaxy"]: by_gal[p["galaxy"]].append(p)
|
||||||
@@ -293,10 +314,33 @@ def main():
|
|||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# ── Full pop LOO ───────────────────────────────────────────────
|
# ── Full pop LOO ───────────────────────────────────────────────
|
||||||
print(f"\n Running full-population LOO...")
|
print(f"\n Running s-population LOO...")
|
||||||
p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
|
p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
|
||||||
a_td = auc_fast(p_full)
|
a_td = auc_fast(p_full)
|
||||||
print(f" * TD full pop: AUC {a_td:.4f} n={len(p_full):,}")
|
print(f" * TD s-pop: AUC {a_td:.4f} n={len(p_full):,}")
|
||||||
|
|
||||||
|
# SAPS score for comparison
|
||||||
|
saps_preds_all = [{"a": p["died"], "p": p["saps_prob"]}
|
||||||
|
for p in all_pts if p.get("saps_prob") is not None]
|
||||||
|
a_td_s = auc_fast(saps_preds_all)
|
||||||
|
print(f" * SAPS-II s-pop: AUC {a_td_s:.4f} n={len(saps_preds_all):,}")
|
||||||
|
|
||||||
|
saps_preds_by_gal = {
|
||||||
|
gal: [{"a": p["died"], "p": p["saps_prob"]}
|
||||||
|
for p in gal_pts if p.get("saps_prob") is not None]
|
||||||
|
for gal, gal_pts in by_gal.items()
|
||||||
|
}
|
||||||
|
a_td_sg = auc_fast_gal(saps_preds_by_gal)
|
||||||
|
ns = str([len(gal_pts) for (gal, gal_pts) in saps_preds_by_gal.items()])
|
||||||
|
print(f" * SAPS-II by gal: AUC {a_td_sg:.4f} n={ns}")
|
||||||
|
|
||||||
|
td_preds_by_gal = defaultdict(list)
|
||||||
|
for pred in p_full:
|
||||||
|
if pred.get("g"):
|
||||||
|
td_preds_by_gal[pred["g"]].append({"a": pred["a"], "p": pred["p"]})
|
||||||
|
a_td_g = auc_fast_gal(dict(td_preds_by_gal))
|
||||||
|
ns = str([len(gal_pts) for (gal, gal_pts) in td_preds_by_gal.items()])
|
||||||
|
print(f" * TD by gal: AUC {a_td_g:.4f} n={ns}")
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════════
|
# ══════════════════════════════════════════════════════════════
|
||||||
# 1. CALIBRATION (10 Dezile)
|
# 1. CALIBRATION (10 Dezile)
|
||||||
|
|||||||
Reference in New Issue
Block a user