# Files · eval-keplertwin/paper2_festung_teil3.py
# (file-browser header residue: 576 lines · 30 KiB · Python)
"""
FESTUNG TEIL 3 — Die letzten drei Steine
1. CALIBRATION (10 Dezile: predicted vs observed)
2. FAIRER VERGLEICH (SAPS + LogReg auf exakt gleicher Population wie TD)
3. Positionierung (Zusammenfassung für Paper)
Benutzt die Full-Pop-Prädiktionen aus dem vorherigen Lauf.
Basilakis 2026 · chicxulub.ai
"""
import json, os, sys, math, time, random
from collections import defaultdict
# ── Database connection ───────────────────────────────────────────────
# libpq DSN of the PostgreSQL server that hosts MIMIC-III. Set the
# MIMIC_PG_DSN environment variable to override, for example
#   "host=localhost port=5432 dbname=mimic user=postgres password=..."
PG_DSN = os.environ.get("MIMIC_PG_DSN", "dbname=mimic3")

# Schema containing the unmodified MIMIC-III v1.3 tables (admissions,
# icustays, labevents, chartevents, inputevents_mv, inputevents_cv,
# prescriptions, diagnoses_icd, d_items, ...).
MIMIC_SCHEMA = os.environ.get("MIMIC_SCHEMA", "mimiciii")

# Schema containing the locally built derived tables (sapsii, sepsis3, ...);
# see sql/schemas.sql. Falls back to MIMIC_SCHEMA when the variable is unset.
DERIVED_SCHEMA = os.environ.get("DERIVED_SCHEMA", MIMIC_SCHEMA)

# ── Norepinephrine itemids ────────────────────────────────────────────
# MetaVision charts norepinephrine in inputevents_mv, CareVue in
# inputevents_cv, each under its own itemids.
NE_ITEMIDS_MV = [221906]
NE_ITEMIDS_CV = [30047, 30120]

# Half-width (in SAPS-II points) of the severity window run_loo() uses when
# selecting comparison patients from the same galaxy.
SAPS_WINDOW = 10

# First-24 h features used for TD distances; this order is also the column
# order of the logistic-regression design matrix built in main().
PARAM_KEYS = [
    "lactate",
    "creatinine",
    "ph",
    "troponin",
    "hemoglobin",
    "heart_rate",
    "map_bp",
    "spo2",
    "temperature",
    "ne_dose",
]
# Therapies that define comparison cohorts, keyed by therapy id. Insertion
# order matters: main() loads therapy_hids in this order, and run_loo()
# keeps at most the first three therapies a patient received.
#   drugs_input: substrings matched (ILIKE) against d_items.label for
#                inputevents_mv / inputevents_cv rows
#   drugs_rx:    substrings matched (ILIKE) against prescriptions.drug
#   ne_min:      norepinephrine-rate threshold describing the "ne_high" cohort
#   label:       display name used in console output
THERAPIES = {
    "vasopressin": {
        "drugs_input": ["Vasopressin"],
        "drugs_rx": ["Vasopressin"],
        "label": "Vasopressin",
    },
    "crrt": {
        "drugs_input": ["CRRT", "CVVH", "CVVHD", "CVVHDF"],
        "label": "CRRT",
    },
    "esmolol": {
        "drugs_rx": ["Esmolol"],
        "drugs_input": ["Esmolol"],
        "label": "Esmolol",
    },
    "furosemide": {
        "drugs_rx": ["Furosemide", "Lasix"],
        "label": "Furosemide",
    },
    "ne_high": {
        "ne_min": 0.5,
        "label": "NE >0.5",
    },
    "bicarbonate": {
        "drugs_rx": ["Bicarbonate", "Sodium Bicarbonate"],
        "drugs_input": ["Sodium Bicarbonate"],
        "label": "Bicarbonate",
    },
    "epinephrine": {
        "drugs_input": ["Epinephrine"],
        "label": "Epinephrine",
    },
    "amiodarone": {
        "drugs_rx": ["Amiodarone"],
        "drugs_input": ["Amiodarone"],
        "label": "Amiodarone",
    },
    "transfusion": {
        "drugs_input": ["Packed Red Blood Cells"],
        "label": "pRBC",
    },
}
# Diagnosis-code prefixes per syndrome ("galaxy"). assign_galaxies() matches
# diagnoses_icd.icd9_code with str.startswith against the "icd_9" lists; the
# "icd_10" lists are not read by this script (MIMIC-III v1.3 is ICD-9 only)
# and are kept for reference / other datasets.
#
# ICD-9 fixes relative to the original lists, each aligning the ICD-9 codes
# with the ICD-10 codes already listed for the same syndrome:
#   acute_mi:             add 410.2x-410.8x (episode digit 0/1), incl. 410.7x
#                         subendocardial infarction (NSTEMI, cf. ICD-10 I21.4)
#   stroke:               add 434.01/434.11/434.91 (cerebral thrombosis /
#                         embolism / occlusion WITH infarction, cf. ICD-10 I63x)
#   liver_failure:        add 570 (acute/subacute necrosis of liver) and
#                         572.2 (hepatic coma), cf. ICD-10 K72.00/K72.01;
#                         572.4 (hepatorenal syndrome) is kept
#   post_cardiac_surgery: drop V45.83 (breast implant removal status); add
#                         V43.3 / V42.2 (heart valve replaced), cf. Z95.2-Z95.4
# NOTE(review): V45.82 is PTCA status (not surgery) and has no counterpart in
# the ICD-10 list here — confirm it belongs in this galaxy.
SYNDROME_ICDS = {
    "sepsis": {
        "icd_10": ["A410", "A411", "A412", "A413", "A414", "A418", "A419", "R6520", "R6521"],
        "icd_9": ["99591", "99592", "78552"],
    },
    "aki": {
        "icd_10": ["N170", "N171", "N172", "N178", "N179"],
        "icd_9": ["5849", "5845", "5846", "5847", "5848"],
    },
    "cardiogenic_shock": {
        "icd_10": ["R570"],
        "icd_9": ["78551"],
    },
    "post_cardiac_arrest": {
        "icd_10": ["I462", "I469"],
        "icd_9": ["4275"],
    },
    "ards": {
        "icd_10": ["J80"],
        "icd_9": ["51882"],
    },
    "acute_mi": {
        "icd_10": ["I210", "I211", "I212", "I213", "I214", "I219"],
        "icd_9": [
            "41000", "41001", "41010", "41011", "41020", "41021",
            "41030", "41031", "41040", "41041", "41050", "41051",
            "41060", "41061", "41070", "41071", "41080", "41081",
            "41090", "41091",
        ],
    },
    "liver_failure": {
        "icd_10": ["K7200", "K7201"],
        "icd_9": ["5724", "570", "5722"],
    },
    "gi_bleeding": {
        "icd_10": ["K920", "K921", "K922", "K2501", "K2521", "K2541", "K2561"],
        "icd_9": ["5780", "5781", "5789", "53121", "53221", "53321"],
    },
    "stroke": {
        "icd_10": ["I630", "I631", "I632", "I633", "I634", "I635", "I638", "I639",
                   "I610", "I611", "I612", "I619"],
        "icd_9": ["43301", "43311", "43321", "43331", "43381", "43391",
                  "43401", "43411", "43491", "431"],
    },
    "pe": {
        "icd_10": ["I2699", "I2609", "I2692", "I2602"],
        "icd_9": ["41519", "41511", "41512", "41513"],
    },
    "dka": {
        "icd_10": ["E1010", "E1011", "E1110", "E1111"],
        "icd_9": ["25010", "25011", "25012", "25013"],
    },
    "heart_failure": {
        "icd_10": ["I500", "I501", "I502", "I503", "I504", "I509", "I110"],
        "icd_9": ["4280", "4281", "42820", "42821", "42830", "42831", "4289"],
    },
    "pneumonia": {
        "icd_10": ["J189", "J180", "J181", "J188", "J13", "J14", "J150", "J151", "J159"],
        "icd_9": ["481", "482", "485", "486"],
    },
    "copd": {
        "icd_10": ["J440", "J441", "J449"],
        "icd_9": ["4910", "4911", "49120", "49121", "496"],
    },
    "afib": {
        "icd_10": ["I480", "I481", "I482", "I4891"],
        "icd_9": ["42731"],
    },
    "post_cardiac_surgery": {
        "icd_10": ["Z951", "Z952", "Z953", "Z954"],
        "icd_9": ["V4581", "V4582", "V433", "V422"],
    },
}
# Tie-break order when an admission matches several syndromes: the first
# entry of this list that the admission matches becomes its "galaxy".
GALAXY_PRIORITY = [
    "sepsis",
    "cardiogenic_shock",
    "post_cardiac_arrest",
    "ards",
    "acute_mi",
    "aki",
    "liver_failure",
    "gi_bleeding",
    "stroke",
    "pe",
    "dka",
    "heart_failure",
    "pneumonia",
    "copd",
    "afib",
    "post_cardiac_surgery",
]
# Lazily opened, process-wide PostgreSQL connection reused by run_pg().
_PG_CONN = None


def _pg_conn():
    """Return the shared read-only connection, reconnecting when needed.

    A new psycopg2 connection to PG_DSN is opened when none has been cached
    yet or the cached one reports itself closed; it is stored in _PG_CONN and
    switched to a read-only, autocommit session before being returned.
    """
    global _PG_CONN
    # Fast path: a cached connection that is still open.
    if _PG_CONN is not None and not getattr(_PG_CONN, "closed", 0):
        return _PG_CONN
    import psycopg2
    _PG_CONN = psycopg2.connect(PG_DSN)
    _PG_CONN.set_session(readonly=True, autocommit=True)
    return _PG_CONN
def run_pg(sql, params=None):
    """Execute a read-only SQL query and return rows as list[dict].

    Args:
        sql: SQL text. When ``params`` is given, it may contain psycopg2
            placeholders (``%s`` / ``%(name)s``) and literal ``%`` must be
            written as ``%%``.
        params: optional sequence or mapping of query parameters, passed to
            ``cursor.execute`` so values are escaped by the driver. When None
            (the default) the SQL is sent verbatim, exactly as before, and
            literal ``%`` characters (e.g. in ILIKE patterns) need no escaping.

    Returns:
        One dict per result row (column name -> value); an empty list for
        statements that produce no result set.
    """
    import psycopg2.extras
    conn = _pg_conn()
    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        if params is None:
            cur.execute(sql)
        else:
            cur.execute(sql, params)
        if cur.description is None:
            return []
        return [dict(row) for row in cur.fetchall()]
def auc_fast(preds):
    """Area under the ROC curve via the Mann-Whitney concordance count.

    Args:
        preds: list of dicts with score "p" and outcome "a" (1 = event,
            0 = no event; other labels are ignored).

    Returns:
        P(score_pos > score_neg) + 0.5 * P(tie) over all positive/negative
        pairs, or 0.5 when the list is empty or one class is missing.
    """
    if not preds:
        return 0.5
    positives = sorted(item["p"] for item in preds if item["a"] == 1)
    negatives = sorted(item["p"] for item in preds if item["a"] == 0)
    n_pos, n_neg = len(positives), len(negatives)
    if n_pos == 0 or n_neg == 0:
        return 0.5
    below = 0  # number of negatives strictly below the current positive
    wins = 0
    draws = 0
    # Positives ascend, so the "below" pointer only ever moves forward.
    for score in positives:
        while below < n_neg and negatives[below] < score:
            below += 1
        upper = below
        while upper < n_neg and negatives[upper] == score:
            upper += 1
        wins += below
        draws += upper - below
    return (wins + 0.5 * draws) / (n_pos * n_neg)
def auc_fast_gal(gal_preds):
    """Within-stratum (per-galaxy) AUC pooled over all strata.

    Only positive/negative pairs from the same stratum are compared, so the
    result equals (sum_g concordant_g) / (sum_g n_pos_g * n_neg_g), i.e. a
    pair-weighted average of the per-galaxy AUCs. Strata lacking either class
    contribute nothing.

    Args:
        gal_preds: mapping stratum -> list of {"p": score, "a": 0/1}.

    Returns:
        Pooled AUC, or 0.5 when there are no strata or no usable pairs.
    """
    if not gal_preds:
        return 0.5
    total_wins = 0
    total_draws = 0
    total_pairs = 0
    for stratum in gal_preds.values():
        pos_scores = sorted(r["p"] for r in stratum if r["a"] == 1)
        neg_scores = sorted(r["p"] for r in stratum if r["a"] == 0)
        if not (pos_scores and neg_scores):
            continue
        n_neg = len(neg_scores)
        total_pairs += len(pos_scores) * n_neg
        lo = 0
        for score in pos_scores:
            while lo < n_neg and neg_scores[lo] < score:
                lo += 1
            hi = lo
            while hi < n_neg and neg_scores[hi] == score:
                hi += 1
            total_wins += lo
            total_draws += hi - lo
    if total_pairs == 0:
        return 0.5
    return (total_wins + 0.5 * total_draws) / total_pairs
def compute_centroid(pts, keys=None):
    """Per-feature mean over the points that have a value for that feature.

    Args:
        pts: iterable of feature dicts; missing keys and None values are
            skipped, so each feature is averaged over its own non-null count.
        keys: feature names to average. Defaults to PARAM_KEYS; pass a subset
            for e.g. feature ablations.

    Returns:
        {key: mean, or None when no point has a value for that key}.
    """
    keys = list(PARAM_KEYS if keys is None else keys)
    sums = defaultdict(float)
    counts = defaultdict(int)
    for p in pts:
        for k in keys:
            v = p.get(k)
            if v is not None:
                sums[k] += v
                counts[k] += 1
    return {k: (sums[k] / counts[k] if counts[k] > 0 else None) for k in keys}
def compute_sigma(pts, keys=None):
    """Per-feature sample standard deviation (n - 1 denominator).

    Args:
        pts: iterable of feature dicts; missing keys and None values are
            skipped.
        keys: feature names to evaluate. Defaults to PARAM_KEYS.

    Returns:
        {key: sigma}. Sigma is floored at 1e-6 so callers can safely use
        1/sigma as a weight; features with fewer than two values get 1.0.
    """
    keys = list(PARAM_KEYS if keys is None else keys)
    vals = defaultdict(list)
    for p in pts:
        for k in keys:
            v = p.get(k)
            if v is not None:
                vals[k].append(v)
    sig = {}
    for k in keys:
        xs = vals[k]
        if len(xs) > 1:
            m = sum(xs) / len(xs)
            var = sum((x - m) ** 2 for x in xs) / (len(xs) - 1)
            sig[k] = max(math.sqrt(var), 1e-6)
        else:
            # Not enough data for a spread estimate: neutral unit scale.
            sig[k] = 1.0
    return sig
def td(pv, centroid, weights, keys=None):
    """Therapeutic Distance: mean weighted absolute deviation from a centroid.

    Only features present (non-None) in both ``pv`` and ``centroid`` and with
    a positive weight contribute; the sum is divided by the number of
    contributing features, not by len(keys).

    Args:
        pv: patient feature dict.
        centroid: feature dict, e.g. from compute_centroid().
        weights: {key: weight}; missing keys count as weight 0 (ignored).
        keys: feature names to consider. Defaults to PARAM_KEYS.

    Returns:
        The distance, or None when no feature contributes.
    """
    keys = PARAM_KEYS if keys is None else keys
    total = 0.0
    used = 0
    for k in keys:
        p, c, w = pv.get(k), centroid.get(k), weights.get(k, 0)
        if p is not None and c is not None and w > 0:
            total += w * abs(p - c)
            used += 1
    return total / used if used > 0 else None
def load_all_icu():
    """Load the ICU cohort: one row per admission with first-24 h features.

    Each admission is represented by its earliest ICU stay whose SAPS-II lies
    in [20, 90]. All lab, chart and norepinephrine features are aggregated
    over the first 24 h of that stay. Previously, every qualifying ICU stay
    produced its own row, while the feature CTEs were aggregated per hadm_id.
    Admissions with several stays therefore appeared several times, with
    identical features but different SAPS-II values.

    Returns:
        list of dicts with keys hadm_id, died, sapsii, saps_prob and every
        PARAM_KEYS feature (None when not measured). Only rows with at least
        three non-null features and a known outcome are kept.
    """
    print(" Loading ALL ICU patients...")
    ne_mv = ",".join(str(i) for i in NE_ITEMIDS_MV)
    ne_cv = ",".join(str(i) for i in NE_ITEMIDS_CV)
    # NOTE(review): troponin mixes itemids 51002/51003 (presumably troponin I
    # and T assays with different scales) in one MAX, and ne_dose mixes the
    # MetaVision and CareVue rate columns; CareVue itemid 30047 may be charted
    # in mcg/min rather than mcg/kg/min — confirm units before interpreting.
    sql = f"""WITH icu_pts AS (
    -- One ICU stay per admission: the earliest stay with SAPS-II in [20, 90].
    SELECT DISTINCT ON (a.hadm_id)
           a.hadm_id, icu.icustay_id, a.hospital_expire_flag AS died,
           s.sapsii, s.sapsii_prob AS saps_prob, icu.intime
    FROM {MIMIC_SCHEMA}.admissions a
    JOIN {MIMIC_SCHEMA}.icustays icu ON a.hadm_id=icu.hadm_id
    JOIN {DERIVED_SCHEMA}.sapsii s ON icu.icustay_id=s.icustay_id
    WHERE s.sapsii BETWEEN 20 AND 90
    ORDER BY a.hadm_id, icu.intime, icu.icustay_id),
l_lac AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=50813 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY le.hadm_id),
l_krea AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=50912 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY le.hadm_id),
l_ph AS (SELECT le.hadm_id,MIN(le.valuenum) AS val FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid IN (50820,50831) AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY le.hadm_id),
l_trop AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid IN (51002,51003) AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY le.hadm_id),
l_hb AS (SELECT le.hadm_id,MIN(le.valuenum) AS val FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=51222 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY le.hadm_id),
-- Chart events are restricted to the selected ICU stay (joined on icustay_id).
c_hr AS (SELECT ip.hadm_id,MAX(ce.valuenum) AS val FROM {MIMIC_SCHEMA}.chartevents ce JOIN icu_pts ip ON ce.icustay_id=ip.icustay_id WHERE ce.itemid IN (211,220045) AND ce.valuenum BETWEEN 20 AND 250 AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY ip.hadm_id),
c_map AS (SELECT ip.hadm_id,MIN(ce.valuenum) AS val FROM {MIMIC_SCHEMA}.chartevents ce JOIN icu_pts ip ON ce.icustay_id=ip.icustay_id WHERE ce.itemid IN (52,456,6702,220052,220181,225312) AND ce.valuenum BETWEEN 20 AND 200 AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY ip.hadm_id),
c_spo2 AS (SELECT ip.hadm_id,MIN(ce.valuenum) AS val FROM {MIMIC_SCHEMA}.chartevents ce JOIN icu_pts ip ON ce.icustay_id=ip.icustay_id WHERE ce.itemid IN (646,220277) AND ce.valuenum BETWEEN 50 AND 100 AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY ip.hadm_id),
-- Temperature: pull all four MIMIC-III itemids (676/223762 nominally
-- Celsius, 678/223761 nominally Fahrenheit) and decide the unit from
-- the value itself. Plausible body temperature in C is ~28..43 and
-- in F is ~82..110; the two ranges don't overlap, so a value in the
-- F band can be safely converted to C even if it was charted under a
-- "Celsius" itemid (and vice versa). Anything outside both bands is
-- treated as sensor noise and dropped.
c_temp AS (
    SELECT ip.hadm_id,
           MIN(CASE
               WHEN ce.valuenum BETWEEN 28 AND 43 THEN ce.valuenum
               WHEN ce.valuenum BETWEEN 82 AND 110 THEN (ce.valuenum - 32.0) / 1.8
           END) AS val
    FROM {MIMIC_SCHEMA}.chartevents ce
    JOIN icu_pts ip ON ce.icustay_id=ip.icustay_id
    WHERE ce.itemid IN (676, 223762, 678, 223761)
      AND ce.valuenum IS NOT NULL
      AND (ce.valuenum BETWEEN 28 AND 43 OR ce.valuenum BETWEEN 82 AND 110)
      AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours'
    GROUP BY ip.hadm_id),
ne_all AS (
    SELECT ie.hadm_id, ie.icustay_id, ie.rate, ie.starttime AS evttime
    FROM {MIMIC_SCHEMA}.inputevents_mv ie
    WHERE ie.itemid IN ({ne_mv}) AND ie.rate>0
    UNION ALL
    SELECT ie.hadm_id, ie.icustay_id, ie.rate, ie.charttime AS evttime
    FROM {MIMIC_SCHEMA}.inputevents_cv ie
    WHERE ie.itemid IN ({ne_cv}) AND ie.rate>0),
ne AS (SELECT ip.hadm_id,MAX(ie.rate) AS val FROM ne_all ie JOIN icu_pts ip ON ie.icustay_id=ip.icustay_id WHERE ie.evttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' GROUP BY ip.hadm_id)
SELECT ip.hadm_id,ip.died,ip.sapsii,ip.saps_prob,
ll.val AS lactate,lk.val AS creatinine,lp.val AS ph,lt.val AS troponin,lh.val AS hemoglobin,
chr_.val AS heart_rate,cma.val AS map_bp,csp.val AS spo2,cte.val AS temperature,ne.val AS ne_dose
FROM icu_pts ip
LEFT JOIN l_lac ll ON ip.hadm_id=ll.hadm_id LEFT JOIN l_krea lk ON ip.hadm_id=lk.hadm_id
LEFT JOIN l_ph lp ON ip.hadm_id=lp.hadm_id LEFT JOIN l_trop lt ON ip.hadm_id=lt.hadm_id
LEFT JOIN l_hb lh ON ip.hadm_id=lh.hadm_id LEFT JOIN c_hr chr_ ON ip.hadm_id=chr_.hadm_id
LEFT JOIN c_map cma ON ip.hadm_id=cma.hadm_id LEFT JOIN c_spo2 csp ON ip.hadm_id=csp.hadm_id
LEFT JOIN c_temp cte ON ip.hadm_id=cte.hadm_id LEFT JOIN ne ON ip.hadm_id=ne.hadm_id"""
    rows = run_pg(sql)
    keep = ["hadm_id", "died", "sapsii", "saps_prob"] + PARAM_KEYS
    pts = [
        {k: r.get(k) for k in keep}
        for r in rows
        # Require at least 3 measured features and a known outcome.
        if sum(1 for k in PARAM_KEYS if r.get(k) is not None) >= 3
        and r.get("died") is not None
    ]
    print(f" -> {len(pts)} patients")
    return pts
def assign_galaxies(pts):
    """Tag each patient dict in place with its primary syndrome ("galaxy").

    ICD-9 diagnosis codes are fetched from diagnoses_icd in batches of 10,000
    hadm_ids. A code matches a syndrome when it starts with one of that
    syndrome's "icd_9" prefixes in SYNDROME_ICDS. Each patient's "galaxy" is
    the first GALAXY_PRIORITY entry among its matched syndromes, or None if
    it matches none.
    """
    print(" Assigning syndromes...")
    admission_ids = [pt["hadm_id"] for pt in pts]
    syndromes_by_hadm = defaultdict(set)
    batch_size = 10000
    for offset in range(0, len(admission_ids), batch_size):
        batch = admission_ids[offset:offset + batch_size]
        id_list = ",".join(str(h) for h in batch)
        # MIMIC-III v1.3 carries only ICD-9 codes (column `icd9_code`).
        query = f"SELECT hadm_id,icd9_code FROM {MIMIC_SCHEMA}.diagnoses_icd WHERE hadm_id IN ({id_list})"
        for row in run_pg(query):
            icd9 = row.get("icd9_code")
            if icd9 is None:
                continue
            for syndrome, code_map in SYNDROME_ICDS.items():
                if any(icd9.startswith(prefix) for prefix in code_map.get("icd_9", [])):
                    syndromes_by_hadm[row["hadm_id"]].add(syndrome)
    for pt in pts:
        tags = syndromes_by_hadm.get(pt["hadm_id"], set())
        pt["galaxy"] = next((g for g in GALAXY_PRIORITY if g in tags), None)
def load_therapy_hadmids(tkey):
    """Return the hadm_ids that received therapy ``tkey`` early in an ICU stay.

    "Early" means within the first 24 h of any ICU stay of the admission.

    - Therapies with an "ne_min" entry (the "ne_high" cohort) select
      norepinephrine infusions (NE_ITEMIDS_MV / NE_ITEMIDS_CV) with
      rate >= ne_min. The threshold is read from THERAPIES; it was previously
      hard-coded to 0.5, which is the configured value.
    - "drugs_input" substrings are matched (ILIKE) against d_items.label for
      both inputevents_mv (starttime) and inputevents_cv (charttime).
    - "drugs_rx" substrings are matched against prescriptions.drug.
      prescriptions.startdate has DATE precision (midnight), so the window
      starts at the calendar day of ICU intime. Comparing against intime
      itself silently dropped drugs started on the ICU admission day.

    Returns:
        set of hadm_ids (empty when the therapy defines no criteria).
    """
    t = THERAPIES[tkey]
    if "ne_min" in t:
        ne_mv = ",".join(str(i) for i in NE_ITEMIDS_MV)
        ne_cv = ",".join(str(i) for i in NE_ITEMIDS_CV)
        ne_min = float(t["ne_min"])
        # NOTE(review): CareVue itemid 30047 may be charted in mcg/min rather
        # than mcg/kg/min, in which case this threshold is not comparable
        # across the two systems — confirm units.
        sql = f"""
SELECT DISTINCT ie.hadm_id
FROM {MIMIC_SCHEMA}.inputevents_mv ie
JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id=icu.icustay_id
WHERE ie.itemid IN ({ne_mv}) AND ie.rate>={ne_min}
AND ie.starttime BETWEEN icu.intime AND icu.intime + INTERVAL '24 hours'
UNION
SELECT DISTINCT ie.hadm_id
FROM {MIMIC_SCHEMA}.inputevents_cv ie
JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id=icu.icustay_id
WHERE ie.itemid IN ({ne_cv}) AND ie.rate>={ne_min}
AND ie.charttime BETWEEN icu.intime AND icu.intime + INTERVAL '24 hours'
"""
        return {r["hadm_id"] for r in run_pg(sql)}

    clauses = []
    # MIMIC-III splits inputevents across MetaVision (starttime) and CareVue
    # (charttime); query both and UNION the hadm_ids.
    for d in t.get("drugs_input", []):
        like = d.replace("'", "''")  # escape for a SQL string literal
        for table, time_col in (("inputevents_mv", "starttime"), ("inputevents_cv", "charttime")):
            clauses.append(
                f"SELECT DISTINCT ie.hadm_id FROM {MIMIC_SCHEMA}.{table} ie "
                f"JOIN {MIMIC_SCHEMA}.d_items di ON ie.itemid=di.itemid "
                f"JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id=icu.icustay_id "
                f"WHERE di.label ILIKE '%{like}%' "
                f"AND ie.{time_col} BETWEEN icu.intime AND icu.intime + INTERVAL '24 hours'"
            )
    for d in t.get("drugs_rx", []):
        like = d.replace("'", "''")
        clauses.append(
            f"SELECT DISTINCT p.hadm_id FROM {MIMIC_SCHEMA}.prescriptions p "
            f"JOIN {MIMIC_SCHEMA}.icustays icu ON p.hadm_id=icu.hadm_id "
            f"WHERE p.drug ILIKE '%{like}%' "
            f"AND p.startdate BETWEEN DATE_TRUNC('day', icu.intime) AND icu.intime + INTERVAL '24 hours'"
        )
    if not clauses:
        return set()
    return {r["hadm_id"] for r in run_pg(" UNION ".join(clauses))}
def run_loo(test_pts, ref_pts, therapy_hids, by_gal, label):
    """Leave-one-out Therapeutic-Distance (TD) mortality predictions.

    For each test patient, take at most the first three therapies it received
    (in ``therapy_hids`` order). For each therapy, build a comparison cohort
    ``co`` of *other* admissions that received the same therapy:
      1. same galaxy and SAPS-II within +/- SAPS_WINDOW;
      2. if fewer than 10, the whole galaxy;
      3. if still fewer than 10 (or no galaxy), all of ``ref_pts``.
    Cohorts smaller than 5 are skipped. The prediction is the observed
    mortality of the cohort "orbit": members whose TD to the cohort centroid
    is within max(eps, 0.01) of the patient's own TD. Here eps is the
    20th-percentile cohort TD. Orbits smaller than 3 are skipped.

    Loop-invariant work (per-galaxy sigma weights, therapy-filtered pools,
    per-member distances) is computed once and reused; the results are
    identical to evaluating it inline for every patient.

    Args:
        test_pts: patients to predict (dicts with hadm_id, died, sapsii,
            galaxy and PARAM_KEYS features).
        ref_pts: fallback reference pool; also the basis of the global sigma.
        therapy_hids: {therapy_key: set of hadm_ids that received it}.
        by_gal: {galaxy: [patients]}; only read, never extended.
        label: prefix for progress output on stdout.

    Returns:
        list of {"a": died, "p": predicted risk, "g": galaxy, "hadm_id": id}
        with one entry per usable (patient, therapy) pair.
    """
    global_sigma = compute_sigma(ref_pts)
    global_w = {k: 1.0 / s for k, s in global_sigma.items()}
    gal_w = {}     # galaxy -> inverse-sigma weights over by_gal[galaxy]
    gal_pool = {}  # (galaxy, therapy) -> galaxy members who got the therapy
    ref_pool = {}  # therapy -> ref_pts members who got the therapy
    preds = []
    t0 = time.time()
    for i, pat in enumerate(test_pts):
        if i % 5000 == 0 and i > 0:
            elapsed = time.time() - t0
            eta = (len(test_pts) - i) / (i / elapsed) / 60 if elapsed > 0 else 0
            sys.stdout.write(f"\r {label}: {i:,}/{len(test_pts):,} preds={len(preds):,} ~{eta:.0f}m")
            sys.stdout.flush()
        hid = pat["hadm_id"]
        therapies = [tk for tk, hids in therapy_hids.items() if hid in hids]
        if not therapies:
            continue
        g = pat.get("galaxy")
        saps = pat.get("sapsii", 50)
        # Membership test only: by_gal may be a defaultdict, so never index a
        # missing galaxy (that would insert an empty list).
        in_galaxy = bool(g) and g in by_gal
        for tk in therapies[:3]:
            t_hids = therapy_hids[tk]
            if in_galaxy:
                key = (g, tk)
                if key not in gal_pool:
                    gal_pool[key] = [p for p in by_gal[g] if p["hadm_id"] in t_hids]
                pool = gal_pool[key]
                slo, shi = saps - SAPS_WINDOW, saps + SAPS_WINDOW
                co = [p for p in pool
                      if p["hadm_id"] != hid and p.get("sapsii") and slo <= p["sapsii"] <= shi]
                if len(co) < 10:
                    co = [p for p in pool if p["hadm_id"] != hid]
            if not in_galaxy or len(co) < 10:
                if tk not in ref_pool:
                    ref_pool[tk] = [p for p in ref_pts if p["hadm_id"] in t_hids]
                co = [p for p in ref_pool[tk] if p["hadm_id"] != hid]
            if len(co) < 5:
                continue
            centroid = compute_centroid(co)
            if in_galaxy and len(by_gal[g]) > 20:
                if g not in gal_w:
                    gal_w[g] = {k: 1.0 / s for k, s in compute_sigma(by_gal[g]).items()}
                w = gal_w[g]
            else:
                w = global_w
            d = td(pat, centroid, w)
            if d is None:
                continue
            # Distance of every cohort member, computed once and reused for
            # both the eps quantile and the orbit selection.
            member_dist = [(p, td(p, centroid, w)) for p in co]
            ds = sorted(v for _, v in member_dist if v is not None)
            if not ds:
                continue
            eps = ds[min(len(ds) - 1, len(ds) // 5)]
            radius = max(eps, 0.01)
            orbit = [p for p, v in member_dist if v is not None and abs(v - d) < radius]
            if len(orbit) < 3:
                continue
            preds.append({"a": pat["died"],
                          "p": sum(1 for p in orbit if p["died"]) / len(orbit),
                          "g": g, "hadm_id": hid})
    elapsed = time.time() - t0
    sys.stdout.write(f"\r {label}: DONE {len(preds):,} preds {elapsed/60:.1f}min{' '*20}\n")
    return preds
def _round4(x):
    """Round to 4 decimals, passing None through (for the JSON summary)."""
    return None if x is None else round(x, 4)


def _fmt4(x):
    """Format a number with 4 decimals, or "n/a" when it is None."""
    return "n/a" if x is None else f"{x:.4f}"


def _brier(preds):
    """Brier score: mean of (p - a)**2 over prediction dicts; None if empty."""
    if not preds:
        return None
    return sum((q["p"] - q["a"]) ** 2 for q in preds) / len(preds)


def _calibration_bins(sorted_preds, n_bins=10):
    """Split predictions (sorted by "p") into equal-count calibration bins.

    Each bin holds len // n_bins predictions; the last bin absorbs the
    remainder. n_bins is capped at the number of predictions so no bin is
    empty (a fixed 10-way split divided by zero below 10 predictions).

    Returns:
        list of (bin_preds, mean predicted p, observed event rate).
    """
    n = len(sorted_preds)
    n_bins = min(n_bins, n)
    if n_bins <= 0:
        return []
    size = n // n_bins
    bins = []
    for i in range(n_bins):
        start = i * size
        end = (i + 1) * size if i < n_bins - 1 else n
        chunk = sorted_preds[start:end]
        pred_mean = sum(q["p"] for q in chunk) / len(chunk)
        obs_rate = sum(q["a"] for q in chunk) / len(chunk)
        bins.append((chunk, pred_mean, obs_rate))
    return bins


def _bootstrap_auc_delta(preds_a, preds_b, n_boot=1000, seed=42):
    """Paired bootstrap of AUC(preds_a) - AUC(preds_b).

    preds_a[i] and preds_b[i] must refer to the same patient. A private
    random.Random(seed) is used; it yields the same draws as seeding the
    global generator with random.seed(seed), without touching global state.

    Returns:
        (ci_lo, ci_hi, p) with the 2.5 % / 97.5 % bootstrap quantiles and
        p = share of resampled deltas <= 0 (one-sided); None for empty input.
    """
    n = len(preds_a)
    if n == 0:
        return None
    rng = random.Random(seed)
    deltas = []
    for _ in range(n_boot):
        idx = [rng.randint(0, n - 1) for _ in range(n)]
        deltas.append(auc_fast([preds_a[i] for i in idx]) - auc_fast([preds_b[i] for i in idx]))
    deltas.sort()
    ci_lo = deltas[n_boot * 25 // 1000]
    ci_hi = deltas[n_boot * 975 // 1000]
    p_val = sum(1 for d in deltas if d <= 0) / n_boot
    return ci_lo, ci_hi, p_val


def _logreg_cv_auc(matched, all_pts, pt_idx):
    """5-fold cross-validated logistic-regression AUC on the matched patients.

    Features are PARAM_KEYS; missing values are imputed with the (upper)
    median over ``all_pts``. Returns None when scikit-learn or numpy is not
    installed.
    """
    try:
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_predict
        from sklearn.preprocessing import StandardScaler
        from sklearn.pipeline import Pipeline
        import numpy as np
    except ImportError:
        return None
    medians = {}
    for k in PARAM_KEYS:
        vals = [p[k] for p in all_pts if p.get(k) is not None]
        medians[k] = sorted(vals)[len(vals) // 2] if vals else 0
    rows, labels = [], []
    for m in matched:
        pt = pt_idx[m["hadm_id"]]
        rows.append([pt[k] if pt.get(k) is not None else medians[k] for k in PARAM_KEYS])
        labels.append(m["died"])
    X = np.array(rows, dtype=float)
    y = np.array(labels, dtype=int)
    pipe = Pipeline([("scaler", StandardScaler()),
                     ("lr", LogisticRegression(max_iter=1000, random_state=42))])
    y_prob = cross_val_predict(pipe, X, y, cv=5, method="predict_proba")[:, 1]
    return auc_fast([{"a": int(a), "p": float(p)} for a, p in zip(y, y_prob)])


def main():
    """Run the three FESTUNG part-3 analyses and write festung_teil3.json.

    1. Calibration of the full-population LOO TD predictions (deciles and
       Brier score), with SAPS-II for reference.
    2. Fair comparison of TD vs SAPS-II (and logistic regression when
       scikit-learn is available) on the identical matched population,
       including a paired bootstrap CI for the TD - SAPS AUC difference and
       matched-population Brier scores.
    3. A qualitative positioning statement (illustrative, not computed).
    """
    T0 = time.time()
    print(f"\n{'='*76}")
    print(f" FESTUNG TEIL 3 — Letzte drei Steine")
    print(f" 1: Calibration 2: Fair Comparison 3: Summary")
    print(f"{'='*76}\n")
    # Every analysis below needs the SAPS-II probability; filter once.
    all_pts = [p for p in load_all_icu() if p.get("saps_prob") is not None]
    assign_galaxies(all_pts)
    by_gal = defaultdict(list)
    for p in all_pts:
        if p["galaxy"]:
            by_gal[p["galaxy"]].append(p)
    pt_idx = {p["hadm_id"]: p for p in all_pts}

    print(f"\n Loading therapies...")
    therapy_hids = {}
    for tk in THERAPIES:
        therapy_hids[tk] = load_therapy_hadmids(tk)
        print(f" {THERAPIES[tk]['label']:20s}: {len(therapy_hids[tk]):6d}")
    results = {}

    # ── Full-population LOO ───────────────────────────────────────
    print(f"\n Running s-population LOO...")
    p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
    a_td = auc_fast(p_full)
    print(f" * TD s-pop: AUC {a_td:.4f} n={len(p_full):,}")
    # SAPS-II over every patient (all have saps_prob after the filter above).
    saps_preds_all = [{"a": p["died"], "p": p["saps_prob"]} for p in all_pts]
    a_saps = auc_fast(saps_preds_all)
    print(f" * SAPS-II s-pop: AUC {a_saps:.4f} n={len(saps_preds_all):,}")
    saps_preds_by_gal = {
        gal: [{"a": p["died"], "p": p["saps_prob"]} for p in gal_pts]
        for gal, gal_pts in by_gal.items()
    }
    a_saps_g = auc_fast_gal(saps_preds_by_gal)
    ns = str([len(v) for v in saps_preds_by_gal.values()])
    print(f" * SAPS-II by gal: AUC {a_saps_g:.4f} n={ns}")
    td_preds_by_gal = defaultdict(list)
    for pred in p_full:
        if pred.get("g"):
            td_preds_by_gal[pred["g"]].append({"a": pred["a"], "p": pred["p"]})
    a_td_g = auc_fast_gal(dict(td_preds_by_gal))
    ns = str([len(v) for v in td_preds_by_gal.values()])
    print(f" * TD by gal: AUC {a_td_g:.4f} n={ns}")

    # ══════════════════════════════════════════════════════════════
    # 1. CALIBRATION (10 deciles)
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*76}")
    print(f" 1. CALIBRATION — Predicted vs Observed (10 Dezile)")
    print(f"{'='*76}\n")
    sorted_preds = sorted(p_full, key=lambda x: x["p"])
    print(f" {'Decile':>7s} {'Pred Mean':>10s} {'Obs Rate':>10s} {'n':>6s} {'Diff':>8s} Visual")
    print(f" {'-'*65}")
    cal_bins = []
    for dec, (bin_preds, pred_mean, obs_rate) in enumerate(_calibration_bins(sorted_preds, 10), start=1):
        diff = pred_mean - obs_rate
        bar_pred = "#" * int(pred_mean * 40)
        bar_obs = "." * int(obs_rate * 40)
        print(f" {dec:>7d} {pred_mean:10.4f} {obs_rate:10.4f} {len(bin_preds):6d} {diff:+8.4f} P:{bar_pred}")
        print(f" {'':>7s} {'':>10s} {'':>10s} {'':>6s} {'':>8s} O:{bar_obs}")
        cal_bins.append({"decile": dec, "pred": round(pred_mean, 4), "obs": round(obs_rate, 4),
                         "n": len(bin_preds), "diff": round(diff, 4)})
    # Summed in sorted order, exactly as the per-bin accumulation did.
    brier = _brier(sorted_preds)
    if cal_bins:
        max_diff = max(abs(b["diff"]) for b in cal_bins)
        mean_diff = sum(abs(b["diff"]) for b in cal_bins) / len(cal_bins)
    else:
        max_diff = mean_diff = None
    if brier is None:
        print(f"\n No TD predictions — calibration skipped.")
    else:
        print(f"\n Brier Score: {brier:.4f}")
        print(f" Max |pred-obs|: {max_diff:.4f}")
        print(f" Mean |pred-obs|: {mean_diff:.4f}")
        if max_diff < 0.10:
            print(f" GOOD CALIBRATION. Max deviation < 0.10.")
        elif max_diff < 0.15:
            print(f" ACCEPTABLE CALIBRATION. Some deviation in extreme deciles.")
        else:
            print(f" POOR CALIBRATION. Needs recalibration (Platt scaling or isotonic).")
    # SAPS-II Brier for reference (all patients).
    saps_brier = _brier(sorted(saps_preds_all, key=lambda x: x["p"]))
    print(f"\n SAPS-II Brier Score: {_fmt4(saps_brier)}")
    print(f" TD Brier Score: {_fmt4(brier)}")
    if brier is not None and saps_brier is not None:
        if brier < saps_brier:
            print(f" TD better calibrated than SAPS-II ({brier:.4f} < {saps_brier:.4f})")
        else:
            print(f" SAPS-II better calibrated ({saps_brier:.4f} < {brier:.4f})")
        print(" Note: populations differ (TD: one entry per therapy prediction; "
              "SAPS-II: every patient) — see the matched Brier in section 2.")
    results["calibration"] = {
        "bins": cal_bins, "brier_td": _round4(brier), "brier_saps": _round4(saps_brier),
        "max_diff": _round4(max_diff), "mean_diff": _round4(mean_diff),
    }

    # ══════════════════════════════════════════════════════════════
    # 2. FAIR COMPARISON (identical population)
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*76}")
    print(f" 2. FAIR COMPARISON — alle Modelle auf exakt gleicher Population")
    print(f"{'='*76}\n")
    td_hadmids = set(p["hadm_id"] for p in p_full)
    print(f" TD predictions from {len(td_hadmids):,} unique patients")
    # A patient can have up to three TD predictions (one per therapy): average.
    td_by_patient = defaultdict(list)
    for p in p_full:
        td_by_patient[p["hadm_id"]].append(p["p"])
    # Matched population: patients with a TD prediction AND a SAPS probability.
    matched = []
    for hid in td_hadmids:
        pt = pt_idx.get(hid)
        if pt and pt.get("saps_prob") is not None:
            matched.append({
                "hadm_id": hid,
                "died": pt["died"],
                "saps_prob": pt["saps_prob"],
                "td_prob": sum(td_by_patient[hid]) / len(td_by_patient[hid]),
                "sapsii": pt.get("sapsii"),
            })
    print(f" Matched population: {len(matched):,} patients (have both TD and SAPS)")
    td_matched = [{"a": p["died"], "p": p["td_prob"]} for p in matched]
    saps_matched = [{"a": p["died"], "p": p["saps_prob"]} for p in matched]
    a_td_m = auc_fast(td_matched)
    a_saps_m = auc_fast(saps_matched)
    print(f"\n ON MATCHED POPULATION ({len(matched):,} patients):")
    print(f" TD: AUC {a_td_m:.4f}")
    print(f" SAPS: AUC {a_saps_m:.4f}")
    print(f" Delta: {a_td_m - a_saps_m:+.4f}")
    brier_td_m = _brier(td_matched)
    brier_saps_m = _brier(saps_matched)
    print(f" Brier (matched): TD {_fmt4(brier_td_m)} SAPS {_fmt4(brier_saps_m)}")

    a_lr_m = None
    if not matched:
        print(f" No matched patients — LogReg and bootstrap skipped")
    else:
        a_lr_m = _logreg_cv_auc(matched, all_pts, pt_idx)
        if a_lr_m is None:
            print(f" sklearn not installed — LogReg skipped")
        else:
            print(f" LogReg: AUC {a_lr_m:.4f}")
            print(f"\n FAIR THREE-WAY COMPARISON (n={len(matched):,}, identical patients):")
            print(f" {'Method':25s} {'AUC':>7s} {'vs TD':>8s}")
            print(f" {'-'*45}")
            print(f" {'Therapeutic Distance':25s} {a_td_m:7.4f} {'ref':>8s}")
            print(f" {'SAPS-II':25s} {a_saps_m:7.4f} {a_saps_m-a_td_m:+8.4f}")
            print(f" {'Logistic Regression':25s} {a_lr_m:7.4f} {a_lr_m-a_td_m:+8.4f}")

    # The bootstrap needs only auc_fast, so it no longer depends on sklearn.
    boot = None
    if matched:
        print(f"\n Bootstrap CI on matched population (1000 resamples)...")
        boot = _bootstrap_auc_delta(td_matched, saps_matched, n_boot=1000, seed=42)
        ci_lo, ci_hi, p_val = boot
        print(f" TD vs SAPS (matched): {a_td_m-a_saps_m:+.4f} (95% CI {ci_lo:+.4f} to {ci_hi:+.4f})")
        print(f" p-value (one-sided bootstrap): {'<0.001' if p_val < 0.001 else f'{p_val:.3f}'}")

    fair = {
        "n_matched": len(matched),
        "td_auc": round(a_td_m, 4), "saps_auc": round(a_saps_m, 4),
        "delta_td_saps": round(a_td_m - a_saps_m, 4),
        "brier_td": _round4(brier_td_m), "brier_saps": _round4(brier_saps_m),
    }
    if a_lr_m is not None:
        fair["lr_auc"] = round(a_lr_m, 4)
        fair["delta_td_lr"] = round(a_td_m - a_lr_m, 4)
    else:
        # Key used by earlier versions when LogReg was unavailable.
        fair["delta"] = round(a_td_m - a_saps_m, 4)
    if boot is not None:
        fair["ci"] = [round(ci_lo, 4), round(ci_hi, 4)]
        # 0.001 stands for "< 0.001".
        fair["p_value"] = round(p_val, 4) if p_val >= 0.001 else 0.001
    results["fair_comparison"] = fair

    # ══════════════════════════════════════════════════════════════
    # 3. SUMMARY — positioning
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*76}")
    print(f" 3. POSITIONIERUNG — Der eine Satz")
    print(f"{'='*76}\n")
    print(" (Illustrative example — the numbers below are not computed by this script.)")
    print(f" SAPS-II sagt: 'Dieser Patient hat 35% Sterberisiko.'")
    print(f" LogReg sagt: 'Dieser Patient hat 28% Sterberisiko.'")
    print(f" TD sagt: 'Dieser Patient ist 0.7 Einheiten vom Vasopressin-Zentroid")
    print(f" und 1.2 Einheiten vom CRRT-Zentroid entfernt.")
    print(f" Patienten in seinem Orbit mit Vasopressin: 25% Mortalitaet.")
    print(f" Patienten in seinem Orbit mit CRRT: 55% Mortalitaet.'")
    print(f"")
    print(f" SAPS und LogReg liefern globales Risiko.")
    print(f" TD liefert therapiespezifische Risikostruktur.")
    print(f"")
    print(f" Das ist der Unterschied. Das ist die Positionierung.")
    print(f" Nicht: 'besserer Score'. Sondern: 'andere Information'.")

    # ══════════════════════════════════════════════════════════════
    # FINAL SUMMARY
    # ══════════════════════════════════════════════════════════════
    tm = (time.time() - T0) / 60
    print(f"\n{'='*76}")
    print(f" FESTUNG TEIL 3 — ERGEBNIS ({tm:.0f} min)")
    print(f"{'='*76}\n")
    print(f" 1. Calibration:")
    print(f" Brier TD: {_fmt4(brier)} SAPS: {_fmt4(saps_brier)}")
    print(f" Max |pred-obs|: {_fmt4(max_diff)} Mean: {_fmt4(mean_diff)}")
    print(f"")
    print(f" 2. Fair Comparison ({len(matched):,} matched patients):")
    print(f" TD: {a_td_m:.4f}")
    print(f" SAPS: {a_saps_m:.4f} (Delta {a_td_m-a_saps_m:+.4f})")
    if "lr_auc" in fair:
        print(f" LogReg: {fair['lr_auc']:.4f} (Delta {a_td_m-fair['lr_auc']:+.4f})")
    print(f"")
    print(f" 3. Positionierung:")
    print(f" TD != besserer Score")
    print(f" TD = therapiespezifische Risikostruktur")
    with open("festung_teil3.json", "w") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\n -> festung_teil3.json saved")
    print(f"\n{'='*76}\n")


if __name__ == "__main__":
    main()