initial: Therapeutic Distance orig script

This commit is contained in:
2026-05-05 09:53:51 +02:00
commit 88c02e8e81

453
paper2_festung_teil3.py Normal file
View File

@@ -0,0 +1,453 @@
"""
FESTUNG TEIL 3 — the last three building blocks

1. CALIBRATION (10 deciles: predicted vs observed)
2. FAIR COMPARISON (SAPS-II + LogReg on exactly the same population as TD)
3. Positioning (summary for the paper)

Builds on the full-population predictions of the previous run; note that
main() recomputes those predictions via run_loo() rather than loading a
stored result.

Basilakis 2026 · chicxulub.ai
"""
import json, sys, math, time, random
from collections import defaultdict
# Project handed to bigquery.Client(project=...) in run_bq.
BQ_PROJECT = "goddard-gap"
# Project prefix of every table referenced by the SQL in this file
# (datasets mimiciv_3_1_hosp / mimiciv_3_1_icu / mimiciv_3_1_derived).
DATA_PROJECT = "physionet-data"
# inputevents itemid behind the "ne_dose" feature and the "ne_high" cohort
# (presumably norepinephrine -- confirm against d_items).
NE_ITEMID = 221906
# Half-width of the SAPS-II band (patient score +/- this value) that run_loo
# uses when it picks a severity-matched reference cohort.
SAPS_WINDOW = 10
# Feature keys consumed by compute_centroid / compute_sigma / td; they match
# the column aliases selected in load_all_icu.
PARAM_KEYS = ["lactate","creatinine","ph","troponin","hemoglobin",
"heart_rate","map_bp","spo2","temperature","ne_dose"]
# Therapy definitions keyed by an internal therapy id:
#   drugs_input -> substrings matched (SQL LIKE) against d_items.label of inputevents
#   drugs_rx    -> substrings matched (SQL LIKE) against prescriptions.drug
#   ne_min      -> intended minimum rate for the "ne_high" pseudo-therapy
#   label       -> display name used in console output
THERAPIES = {
"vasopressin":{"drugs_input":["Vasopressin"],"drugs_rx":["Vasopressin"],"label":"Vasopressin"},
"crrt":{"drugs_input":["CRRT","CVVH","CVVHD","CVVHDF"],"label":"CRRT"},
"esmolol":{"drugs_rx":["Esmolol"],"drugs_input":["Esmolol"],"label":"Esmolol"},
"furosemide":{"drugs_rx":["Furosemide","Lasix"],"label":"Furosemide"},
"ne_high":{"ne_min":0.5,"label":"NE >0.5"},
"bicarbonate":{"drugs_rx":["Bicarbonate","Sodium Bicarbonate"],"drugs_input":["Sodium Bicarbonate"],"label":"Bicarbonate"},
"epinephrine":{"drugs_input":["Epinephrine"],"label":"Epinephrine"},
"amiodarone":{"drugs_rx":["Amiodarone"],"drugs_input":["Amiodarone"],"label":"Amiodarone"},
"transfusion":{"drugs_input":["Packed Red Blood Cells"],"label":"pRBC"},
}
# Syndrome ("galaxy") definitions: per ICD version, the code prefixes that
# assign_galaxies tests with str.startswith against diagnoses_icd.icd_code.
SYNDROME_ICDS = {
"sepsis":{"icd_10":["A410","A411","A412","A413","A414","A418","A419","R6520","R6521"],"icd_9":["99591","99592","78552"]},
"aki":{"icd_10":["N170","N171","N172","N178","N179"],"icd_9":["5849","5845","5846","5847","5848"]},
"cardiogenic_shock":{"icd_10":["R570"],"icd_9":["78551"]},
"post_cardiac_arrest":{"icd_10":["I462","I469"],"icd_9":["4275"]},
"ards":{"icd_10":["J80"],"icd_9":["51882"]},
"acute_mi":{"icd_10":["I210","I211","I212","I213","I214","I219"],"icd_9":["41000","41001","41010","41011","41090","41091"]},
"liver_failure":{"icd_10":["K7200","K7201"],"icd_9":["5724"]},
"gi_bleeding":{"icd_10":["K920","K921","K922","K2501","K2521","K2541","K2561"],"icd_9":["5780","5781","5789","53121","53221","53321"]},
"stroke":{"icd_10":["I630","I631","I632","I633","I634","I635","I638","I639","I610","I611","I612","I619"],"icd_9":["43301","43311","43321","43331","43381","43391","431"]},
"pe":{"icd_10":["I2699","I2609","I2692","I2602"],"icd_9":["41519","41511","41512","41513"]},
"dka":{"icd_10":["E1010","E1011","E1110","E1111"],"icd_9":["25010","25011","25012","25013"]},
"heart_failure":{"icd_10":["I500","I501","I502","I503","I504","I509","I110"],"icd_9":["4280","4281","42820","42821","42830","42831","4289"]},
"pneumonia":{"icd_10":["J189","J180","J181","J188","J13","J14","J150","J151","J159"],"icd_9":["481","482","485","486"]},
"copd":{"icd_10":["J440","J441","J449"],"icd_9":["4910","4911","49120","49121","496"]},
"afib":{"icd_10":["I480","I481","I482","I4891"],"icd_9":["42731"]},
"post_cardiac_surgery":{"icd_10":["Z951","Z952","Z953","Z954"],"icd_9":["V4581","V4582","V4583"]},
}
# Tie-break order when an admission matches several syndromes: assign_galaxies
# keeps the first entry of this list that the admission matched.
GALAXY_PRIORITY = ["sepsis","cardiogenic_shock","post_cardiac_arrest","ards",
"acute_mi","aki","liver_failure","gi_bleeding","stroke","pe","dka",
"heart_failure","pneumonia","copd","afib","post_cardiac_surgery"]
def run_bq(sql):
    """Execute *sql* on BigQuery and return every result row as a plain dict.

    The google-cloud-bigquery import is deferred to call time, and a fresh
    client bound to BQ_PROJECT is created for each call.
    """
    from google.cloud import bigquery

    query_job = bigquery.Client(project=BQ_PROJECT).query(sql)
    records = []
    for row in query_job.result():
        records.append(dict(row.items()))
    return records
def auc_fast(preds):
    """Return the ROC AUC of *preds*, counting tied scores as half-concordant.

    Args:
        preds: iterable of dicts with keys "p" (predicted score) and
            "a" (actual outcome, 1 = positive, 0 = negative). Entries whose
            "a" is neither 1 nor 0 are ignored.

    Returns:
        (concordant pairs + 0.5 * tied pairs) / (n_pos * n_neg), or 0.5 when
        there are no predictions or only one outcome class is present.
    """
    if not preds:
        return 0.5
    pos = sorted(p["p"] for p in preds if p["a"] == 1)
    neg = sorted(p["p"] for p in preds if p["a"] == 0)
    if not pos or not neg:
        return 0.5

    n_neg = len(neg)
    below = 0   # number of negatives strictly below the current positive
    upto = 0    # number of negatives less than or equal to the current positive
    conc = 0
    ties = 0
    # Both cursors only ever move forward because `pos` is ascending, so the
    # whole scan is linear even when many predictions share the same score
    # (restarting the tie scan for every positive would be quadratic there).
    for pv in pos:
        while below < n_neg and neg[below] < pv:
            below += 1
        while upto < n_neg and neg[upto] <= pv:
            upto += 1
        conc += below
        ties += upto - below
    return (conc + 0.5 * ties) / (len(pos) * n_neg)
def compute_centroid(pts):
    """Return the per-parameter mean over *pts*, skipping missing values.

    Each key of PARAM_KEYS maps to the mean of the non-None values found
    under that key, or to None when no point carries a value for it.
    """
    totals = dict.fromkeys(PARAM_KEYS, 0.0)
    counts = dict.fromkeys(PARAM_KEYS, 0)
    for point in pts:
        for key in PARAM_KEYS:
            value = point.get(key)
            if value is None:
                continue
            totals[key] += value
            counts[key] += 1
    centroid = {}
    for key in PARAM_KEYS:
        centroid[key] = totals[key] / counts[key] if counts[key] > 0 else None
    return centroid
def compute_sigma(pts):
    """Return the per-parameter sample standard deviation over *pts*.

    Missing (None) values are ignored. The deviation uses the n-1
    denominator and is floored at 1e-6 so it can safely be inverted; a
    parameter with fewer than two observations gets 1.0.
    """
    columns = {key: [] for key in PARAM_KEYS}
    for point in pts:
        for key, column in columns.items():
            value = point.get(key)
            if value is not None:
                column.append(value)

    sigma = {}
    for key, column in columns.items():
        count = len(column)
        if count <= 1:
            sigma[key] = 1.0
            continue
        mean = sum(column) / count
        variance = sum((x - mean) ** 2 for x in column) / (count - 1)
        sigma[key] = max(math.sqrt(variance), 1e-6)
    return sigma
def td(pv, centroid, weights):
    """Return the mean weighted absolute distance between *pv* and *centroid*.

    Only parameters that are present in both *pv* and *centroid* and have a
    strictly positive weight contribute; the sum is divided by the number of
    contributing parameters. Returns None when no parameter contributes.
    """
    total = 0.0
    used = 0
    for key in PARAM_KEYS:
        value = pv.get(key)
        center = centroid.get(key)
        weight = weights.get(key, 0)
        if value is None or center is None or not weight > 0:
            continue
        total += weight * abs(value - center)
        used += 1
    if used == 0:
        return None
    return total / used
def load_all_icu():
    """Query the ICU cohort and return one feature dict per usable row.

    The SQL restricts to stays with a SAPS-II score between 20 and 90 and
    joins, per admission, the aggregated first-24-hour values behind
    PARAM_KEYS. A row is kept only if its outcome ("died") is known and at
    least three of the PARAM_KEYS features are non-null.

    Returns:
        list of dicts with keys hadm_id, died, sapsii, saps_prob plus every
        key of PARAM_KEYS (missing features are None).
    """
    print(" Loading ALL ICU patients...")
    # NOTE(review): icu_pts is DISTINCT over (hadm_id, ..., intime) and joined
    # per stay_id, so an admission with several ICU stays appears to yield
    # several rows with the same hadm_id -- confirm whether that is intended.
    sql = f"""WITH icu_pts AS (
SELECT DISTINCT a.hadm_id,a.hospital_expire_flag AS died,s.sapsii,icu.intime,
s.sapsii_prob AS saps_prob
FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.admissions` a
JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON a.hadm_id=icu.hadm_id
JOIN `{DATA_PROJECT}.mimiciv_3_1_derived.sapsii` s ON icu.stay_id=s.stay_id
WHERE s.sapsii BETWEEN 20 AND 90),
l_lac AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.labevents` le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=50813 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND TIMESTAMP_ADD(ip.intime,INTERVAL 24 HOUR) GROUP BY le.hadm_id),
l_krea AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.labevents` le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=50912 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND TIMESTAMP_ADD(ip.intime,INTERVAL 24 HOUR) GROUP BY le.hadm_id),
l_ph AS (SELECT le.hadm_id,MIN(le.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.labevents` le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid IN (50820,50831) AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND TIMESTAMP_ADD(ip.intime,INTERVAL 24 HOUR) GROUP BY le.hadm_id),
l_trop AS (SELECT le.hadm_id,MAX(le.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.labevents` le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid IN (51002,51003) AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND TIMESTAMP_ADD(ip.intime,INTERVAL 24 HOUR) GROUP BY le.hadm_id),
l_hb AS (SELECT le.hadm_id,MIN(le.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.labevents` le JOIN icu_pts ip ON le.hadm_id=ip.hadm_id WHERE le.itemid=51222 AND le.valuenum IS NOT NULL AND le.charttime BETWEEN ip.intime AND TIMESTAMP_ADD(ip.intime,INTERVAL 24 HOUR) GROUP BY le.hadm_id),
c_hr AS (SELECT ce.hadm_id,MAX(ce.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_icu.chartevents` ce JOIN icu_pts ip ON ce.hadm_id=ip.hadm_id JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ce.stay_id=icu.stay_id WHERE ce.itemid=220045 AND ce.valuenum BETWEEN 20 AND 250 AND ce.charttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR) GROUP BY ce.hadm_id),
c_map AS (SELECT ce.hadm_id,MIN(ce.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_icu.chartevents` ce JOIN icu_pts ip ON ce.hadm_id=ip.hadm_id JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ce.stay_id=icu.stay_id WHERE ce.itemid IN (220052,220181,225312) AND ce.valuenum BETWEEN 20 AND 200 AND ce.charttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR) GROUP BY ce.hadm_id),
c_spo2 AS (SELECT ce.hadm_id,MIN(ce.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_icu.chartevents` ce JOIN icu_pts ip ON ce.hadm_id=ip.hadm_id JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ce.stay_id=icu.stay_id WHERE ce.itemid=220277 AND ce.valuenum BETWEEN 50 AND 100 AND ce.charttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR) GROUP BY ce.hadm_id),
c_temp AS (SELECT ce.hadm_id,MIN(ce.valuenum) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_icu.chartevents` ce JOIN icu_pts ip ON ce.hadm_id=ip.hadm_id JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ce.stay_id=icu.stay_id WHERE ce.itemid=223762 AND ce.valuenum BETWEEN 28 AND 43 AND ce.charttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR) GROUP BY ce.hadm_id),
ne AS (SELECT ie.hadm_id,MAX(ie.rate) AS val FROM `{DATA_PROJECT}.mimiciv_3_1_icu.inputevents` ie JOIN icu_pts ip ON ie.hadm_id=ip.hadm_id JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ie.stay_id=icu.stay_id WHERE ie.itemid={NE_ITEMID} AND ie.rate>0 AND ie.starttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR) GROUP BY ie.hadm_id)
SELECT ip.hadm_id,ip.died,ip.sapsii,ip.saps_prob,
ll.val AS lactate,lk.val AS creatinine,lp.val AS ph,lt.val AS troponin,lh.val AS hemoglobin,
chr.val AS heart_rate,cma.val AS map_bp,csp.val AS spo2,cte.val AS temperature,ne.val AS ne_dose
FROM icu_pts ip
LEFT JOIN l_lac ll ON ip.hadm_id=ll.hadm_id LEFT JOIN l_krea lk ON ip.hadm_id=lk.hadm_id
LEFT JOIN l_ph lp ON ip.hadm_id=lp.hadm_id LEFT JOIN l_trop lt ON ip.hadm_id=lt.hadm_id
LEFT JOIN l_hb lh ON ip.hadm_id=lh.hadm_id LEFT JOIN c_hr chr ON ip.hadm_id=chr.hadm_id
LEFT JOIN c_map cma ON ip.hadm_id=cma.hadm_id LEFT JOIN c_spo2 csp ON ip.hadm_id=csp.hadm_id
LEFT JOIN c_temp cte ON ip.hadm_id=cte.hadm_id LEFT JOIN ne ON ip.hadm_id=ne.hadm_id"""
    wanted = ["hadm_id", "died", "sapsii", "saps_prob"] + PARAM_KEYS
    pts = []
    for record in run_bq(sql):
        if record.get("died") is None:
            continue
        n_present = sum(1 for key in PARAM_KEYS if record.get(key) is not None)
        if n_present < 3:
            # Too sparse to place the patient in parameter space.
            continue
        pts.append({key: record.get(key) for key in wanted})
    print(f" -> {len(pts)} patients")
    return pts
def assign_galaxies(pts):
    """Set p["galaxy"] on every point in *pts* from its ICD diagnoses.

    Diagnoses are fetched in batches of 10000 admissions. A syndrome from
    SYNDROME_ICDS matches an admission when any of its codes starts with one
    of the syndrome's prefixes for that ICD version. Each point receives the
    first matching syndrome in GALAXY_PRIORITY order, or None. The points
    are modified in place; nothing is returned.
    """
    print(" Assigning syndromes...")
    hadm_ids = [pt["hadm_id"] for pt in pts]
    matched_syndromes = defaultdict(set)
    batch = 10000
    for offset in range(0, len(hadm_ids), batch):
        id_list = ",".join(str(h) for h in hadm_ids[offset:offset + batch])
        diagnoses = run_bq(f"SELECT hadm_id,icd_code,icd_version FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.diagnoses_icd` WHERE hadm_id IN ({id_list})")
        for row in diagnoses:
            version_key = f"icd_{row['icd_version']}"
            for syndrome, spec in SYNDROME_ICDS.items():
                prefixes = spec.get(version_key, [])
                if any(row["icd_code"].startswith(prefix) for prefix in prefixes):
                    matched_syndromes[row["hadm_id"]].add(syndrome)
    for pt in pts:
        found = matched_syndromes.get(pt["hadm_id"], set())
        pt["galaxy"] = next((g for g in GALAXY_PRIORITY if g in found), None)
def load_therapy_hadmids(tkey):
    """Return the set of hadm_ids that received therapy *tkey* within 24 h of ICU admission.

    Args:
        tkey: a key of THERAPIES.

    Returns:
        set of hadm_id values. For "ne_high" these are admissions with an
        NE_ITEMID input event whose rate is at least the configured
        ``ne_min`` (0.5 if the entry does not define one). For all other
        therapies it is the union of admissions whose input-event item label
        or prescription drug name contains one of the configured substrings.
        Empty when the therapy defines no drug lists.

    Raises:
        KeyError: if *tkey* is not in THERAPIES.
    """
    t = THERAPIES[tkey]
    if tkey == "ne_high":
        # Read the threshold from the therapy config instead of repeating the
        # literal here, so THERAPIES stays the single source of truth.
        ne_min = t.get("ne_min", 0.5)
        return {r["hadm_id"] for r in run_bq(f"SELECT DISTINCT ie.hadm_id FROM `{DATA_PROJECT}.mimiciv_3_1_icu.inputevents` ie JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ie.stay_id=icu.stay_id WHERE ie.itemid={NE_ITEMID} AND ie.rate>={ne_min} AND ie.starttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR)")}
    clauses = []
    # The drug names come from the THERAPIES constant above, not from user
    # input, so they are interpolated into the LIKE patterns directly.
    for d in t.get("drugs_input", []):
        clauses.append(f"SELECT DISTINCT ie.hadm_id FROM `{DATA_PROJECT}.mimiciv_3_1_icu.inputevents` ie JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.d_items` di ON ie.itemid=di.itemid JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON ie.stay_id=icu.stay_id WHERE di.label LIKE '%{d}%' AND ie.starttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR)")
    for d in t.get("drugs_rx", []):
        clauses.append(f"SELECT DISTINCT p.hadm_id FROM `{DATA_PROJECT}.mimiciv_3_1_hosp.prescriptions` p JOIN `{DATA_PROJECT}.mimiciv_3_1_icu.icustays` icu ON p.hadm_id=icu.hadm_id WHERE p.drug LIKE '%{d}%' AND p.starttime BETWEEN icu.intime AND TIMESTAMP_ADD(icu.intime,INTERVAL 24 HOUR)")
    if not clauses:
        return set()
    return {r["hadm_id"] for r in run_bq(" UNION DISTINCT ".join(clauses))}
def run_loo(test_pts, ref_pts, therapy_hids, by_gal, label):
    """Leave-one-out mortality prediction from therapy-specific reference cohorts.

    For every test patient and each of (at most) the first three therapies
    the patient received, a reference cohort of other patients with the same
    therapy is selected, preferring, in this order:
      1. same galaxy and SAPS-II within +/- SAPS_WINDOW,
      2. same galaxy,
      3. all of *ref_pts*,
    falling through whenever fewer than 10 patients remain (a final cohort
    below 5 is skipped). The patient's distance to the cohort centroid is
    compared with the cohort members' own distances; members whose distance
    differs by less than max(eps, 0.01) form the "orbit", where eps is the
    cohort distance at the 20% position. The prediction is the orbit's
    mortality rate; orbits with fewer than 3 members are skipped.

    Args:
        test_pts: patients to predict (dicts as built by load_all_icu, with
            "galaxy" set).
        ref_pts: reference population; the test patient itself is always
            excluded by hadm_id.
        therapy_hids: mapping therapy key -> set of hadm_ids.
        by_gal: mapping galaxy -> list of reference patients in that galaxy.
        label: prefix for the progress output on stdout.

    Returns:
        list of {a, p, g, hadm_id}: actual outcome, predicted probability,
        galaxy and admission id. hadm_id is included so callers can build a
        matched population for fair comparison.
    """
    global_w = {k: 1.0 / s for k, s in compute_sigma(ref_pts).items()}

    # Everything cached below depends only on the inputs, not on the patient
    # being predicted, so it is computed once instead of per iteration.
    galaxy_w = {}         # galaxy -> inverse-sigma weights
    ref_treated = {}      # therapy -> treated members of ref_pts (order kept)
    galaxy_treated = {}   # (galaxy, therapy) -> treated members of by_gal[galaxy]

    def treated_in_ref(tk):
        cohort = ref_treated.get(tk)
        if cohort is None:
            hids = therapy_hids[tk]
            cohort = ref_treated[tk] = [p for p in ref_pts if p["hadm_id"] in hids]
        return cohort

    def treated_in_galaxy(g, tk):
        cohort = galaxy_treated.get((g, tk))
        if cohort is None:
            hids = therapy_hids[tk]
            cohort = galaxy_treated[(g, tk)] = [p for p in by_gal[g] if p["hadm_id"] in hids]
        return cohort

    preds = []
    t0 = time.time()
    n_test = len(test_pts)
    for i, pat in enumerate(test_pts):
        if i % 5000 == 0 and i > 0:
            elapsed = time.time() - t0
            eta = (n_test - i) / (i / elapsed) / 60 if elapsed > 0 else 0
            sys.stdout.write(f"\r {label}: {i:,}/{n_test:,} preds={len(preds):,} ~{eta:.0f}m")
            sys.stdout.flush()

        hid = pat["hadm_id"]
        therapies = [tk for tk, hids in therapy_hids.items() if hid in hids]
        if not therapies:
            continue
        g = pat.get("galaxy")
        saps = pat.get("sapsii", 50)
        in_galaxy = bool(g) and g in by_gal

        # Galaxy-specific weights only when the galaxy is large enough to
        # give a stable sigma; otherwise fall back to the global weights.
        if in_galaxy and len(by_gal[g]) > 20:
            w = galaxy_w.get(g)
            if w is None:
                w = galaxy_w[g] = {k: 1.0 / s for k, s in compute_sigma(by_gal[g]).items()}
        else:
            w = global_w

        for tk in therapies[:3]:
            if in_galaxy:
                slo, shi = saps - SAPS_WINDOW, saps + SAPS_WINDOW
                gal = treated_in_galaxy(g, tk)
                co = [p for p in gal
                      if p["hadm_id"] != hid and p.get("sapsii") and slo <= p["sapsii"] <= shi]
                if len(co) < 10:
                    co = [p for p in gal if p["hadm_id"] != hid]
                if len(co) < 10:
                    co = [p for p in treated_in_ref(tk) if p["hadm_id"] != hid]
            else:
                co = [p for p in treated_in_ref(tk) if p["hadm_id"] != hid]
            if len(co) < 5:
                continue

            centroid = compute_centroid(co)
            d = td(pat, centroid, w)
            if d is None:
                continue
            # Evaluate each cohort member's distance exactly once.
            scored = [(p, td(p, centroid, w)) for p in co]
            ds = sorted(v for _, v in scored if v is not None)
            if not ds:
                continue
            eps = ds[min(len(ds) - 1, len(ds) // 5)]
            tol = max(eps, 0.01)
            orbit = [p for p, v in scored if v is not None and abs(v - d) < tol]
            if len(orbit) < 3:
                continue
            preds.append({"a": pat["died"],
                          "p": sum(1 for p in orbit if p["died"]) / len(orbit),
                          "g": g, "hadm_id": hid})

    elapsed = time.time() - t0
    sys.stdout.write(f"\r {label}: DONE {len(preds):,} preds {elapsed/60:.1f}min{' '*20}\n")
    return preds
def _banner(*lines):
    """Print *lines* framed by two 76-character '=' rules."""
    print(f"\n{'='*76}")
    for line in lines:
        print(line)
    print(f"{'='*76}\n")


def _bootstrap_auc_deltas(td_matched, saps_matched, n_boot=1000, seed=42):
    """Return the sorted paired-bootstrap distribution of AUC(TD) - AUC(SAPS)."""
    random.seed(seed)
    n = len(td_matched)
    deltas = []
    for _ in range(n_boot):
        idx = [random.randint(0, n - 1) for _ in range(n)]
        deltas.append(auc_fast([td_matched[i] for i in idx])
                      - auc_fast([saps_matched[i] for i in idx]))
    deltas.sort()
    return deltas


def _print_positioning():
    """Print the fixed positioning text of section 3."""
    _banner(" 3. POSITIONIERUNG — Der eine Satz")
    print(" SAPS-II sagt: 'Dieser Patient hat 35% Sterberisiko.'")
    print(" LogReg sagt: 'Dieser Patient hat 28% Sterberisiko.'")
    print(" TD sagt: 'Dieser Patient ist 0.7 Einheiten vom Vasopressin-Zentroid")
    print(" und 1.2 Einheiten vom CRRT-Zentroid entfernt.")
    print(" Patienten in seinem Orbit mit Vasopressin: 25% Mortalitaet.")
    print(" Patienten in seinem Orbit mit CRRT: 55% Mortalitaet.'")
    print("")
    print(" SAPS und LogReg liefern globales Risiko.")
    print(" TD liefert therapiespezifische Risikostruktur.")
    print("")
    print(" Das ist der Unterschied. Das ist die Positionierung.")
    print(" Nicht: 'besserer Score'. Sondern: 'andere Information'.")


def main():
    """Run the full analysis and write the results to festung_teil3.json.

    Steps:
      1. Load the ICU cohort, syndromes and therapy cohorts, then compute the
         full-population leave-one-out TD predictions.
      2. Calibration: decile table, Brier score, comparison with SAPS-II.
      3. Fair comparison: AUC of TD, SAPS-II and (if scikit-learn is
         installed) logistic regression on the identical matched patients,
         plus a paired bootstrap CI for TD vs SAPS-II.
      4. Print the positioning text and the final summary.

    Returns early, without writing the JSON file, if no TD prediction could
    be produced.
    """
    T0 = time.time()
    _banner(" FESTUNG TEIL 3 — Letzte drei Steine",
            " 1: Calibration 2: Fair Comparison 3: Summary")
    all_pts = load_all_icu()
    assign_galaxies(all_pts)
    by_gal = defaultdict(list)
    for p in all_pts:
        if p["galaxy"]:
            by_gal[p["galaxy"]].append(p)
    pt_idx = {p["hadm_id"]: p for p in all_pts}
    print("\n Loading therapies...")
    therapy_hids = {}
    for tk in THERAPIES:
        therapy_hids[tk] = load_therapy_hadmids(tk)
        print(f" {THERAPIES[tk]['label']:20s}: {len(therapy_hids[tk]):6d}")
    results = {}

    # -- Full-population LOO --------------------------------------------
    print("\n Running full-population LOO...")
    p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
    a_td = auc_fast(p_full)
    print(f" * TD full pop: AUC {a_td:.4f} n={len(p_full):,}")
    if not p_full:
        # Every statistic below divides by the number of predictions.
        print(" No TD predictions were produced; nothing to calibrate or compare.")
        return

    # -- 1. Calibration (10 deciles) ------------------------------------
    _banner(" 1. CALIBRATION — Predicted vs Observed (10 Dezile)")
    sorted_preds = sorted(p_full, key=lambda x: x["p"])
    n = len(sorted_preds)
    bin_size = n // 10
    print(f" {'Decile':>7s} {'Pred Mean':>10s} {'Obs Rate':>10s} {'n':>6s} {'Diff':>8s} Visual")
    print(f" {'-'*65}")
    cal_bins = []
    for i in range(10):
        start = i * bin_size
        end = (i + 1) * bin_size if i < 9 else n  # last decile takes the remainder
        bin_preds = sorted_preds[start:end]
        if not bin_preds:
            # With fewer than 10 predictions the leading deciles are empty.
            continue
        pred_mean = sum(p["p"] for p in bin_preds) / len(bin_preds)
        obs_rate = sum(p["a"] for p in bin_preds) / len(bin_preds)
        diff = pred_mean - obs_rate
        bar_pred = "#" * int(pred_mean * 40)
        bar_obs = "." * int(obs_rate * 40)
        print(f" {i+1:>7d} {pred_mean:10.4f} {obs_rate:10.4f} {len(bin_preds):6d} {diff:+8.4f} P:{bar_pred}")
        print(f" {'':>7s} {'':>10s} {'':>10s} {'':>6s} {'':>8s} O:{bar_obs}")
        cal_bins.append({"decile": i + 1, "pred": round(pred_mean, 4), "obs": round(obs_rate, 4),
                         "n": len(bin_preds), "diff": round(diff, 4)})
    brier = sum((p["p"] - p["a"]) ** 2 for p in sorted_preds) / n
    max_diff = max(abs(b["diff"]) for b in cal_bins)
    mean_diff = sum(abs(b["diff"]) for b in cal_bins) / len(cal_bins)
    print(f"\n Brier Score: {brier:.4f}")
    print(f" Max |pred-obs|: {max_diff:.4f}")
    print(f" Mean |pred-obs|: {mean_diff:.4f}")
    if max_diff < 0.10:
        print(" GOOD CALIBRATION. Max deviation < 0.10.")
    elif max_diff < 0.15:
        print(" ACCEPTABLE CALIBRATION. Some deviation in extreme deciles.")
    else:
        print(" POOR CALIBRATION. Needs recalibration (Platt scaling or isotonic).")

    # SAPS-II calibration for comparison.
    # NOTE(review): this Brier score covers every patient with a saps_prob,
    # while the TD Brier score above covers only the TD predictions, so the
    # two are not computed on the same population -- confirm this is intended.
    saps_preds_all = [{"a": p["died"], "p": p["saps_prob"]}
                      for p in all_pts if p.get("saps_prob") is not None]
    saps_sorted = sorted(saps_preds_all, key=lambda x: x["p"])
    saps_brier = sum((p["p"] - p["a"]) ** 2 for p in saps_sorted) / len(saps_sorted)
    print(f"\n SAPS-II Brier Score: {saps_brier:.4f}")
    print(f" TD Brier Score: {brier:.4f}")
    if brier < saps_brier:
        print(f" TD better calibrated than SAPS-II ({brier:.4f} < {saps_brier:.4f})")
    else:
        print(f" SAPS-II better calibrated ({saps_brier:.4f} < {brier:.4f})")
    results["calibration"] = {
        "bins": cal_bins, "brier_td": round(brier, 4), "brier_saps": round(saps_brier, 4),
        "max_diff": round(max_diff, 4), "mean_diff": round(mean_diff, 4)
    }

    # -- 2. Fair comparison (identical population) ----------------------
    _banner(" 2. FAIR COMPARISON — alle Modelle auf exakt gleicher Population")
    # Unique admissions that received at least one TD prediction.
    td_hadmids = set(p["hadm_id"] for p in p_full)
    print(f" TD predictions from {len(td_hadmids):,} unique patients")
    # A patient can have one prediction per therapy; they are averaged below.
    td_by_patient = defaultdict(list)
    for p in p_full:
        td_by_patient[p["hadm_id"]].append(p["p"])
    # Matched population: patients with a TD prediction AND a saps_prob.
    matched = []
    for hid in td_hadmids:
        pt = pt_idx.get(hid)
        if pt and pt.get("saps_prob") is not None:
            matched.append({
                "hadm_id": hid,
                "died": pt["died"],
                "saps_prob": pt["saps_prob"],
                "td_prob": sum(td_by_patient[hid]) / len(td_by_patient[hid]),
                "sapsii": pt.get("sapsii"),
            })
    print(f" Matched population: {len(matched):,} patients (have both TD and SAPS)")

    td_matched = [{"a": p["died"], "p": p["td_prob"]} for p in matched]
    saps_matched = [{"a": p["died"], "p": p["saps_prob"]} for p in matched]
    a_td_m = auc_fast(td_matched)
    a_saps_m = auc_fast(saps_matched)
    print(f"\n ON MATCHED POPULATION ({len(matched):,} patients):")
    print(f" TD: AUC {a_td_m:.4f}")
    print(f" SAPS: AUC {a_saps_m:.4f}")
    print(f" Delta: {a_td_m - a_saps_m:+.4f}")

    # Logistic regression on the matched population (optional dependency).
    # Only the imports are guarded, so a failure inside the model fit is not
    # mistaken for a missing package.
    a_lr_m = None
    try:
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_predict
        from sklearn.preprocessing import StandardScaler
        from sklearn.pipeline import Pipeline
        import numpy as np
    except ImportError:
        print(" sklearn not installed — LogReg skipped")
    else:
        # Missing features are imputed with the (upper) median over all patients.
        medians = {}
        for k in PARAM_KEYS:
            vals = sorted(p[k] for p in all_pts if p.get(k) is not None)
            medians[k] = vals[len(vals) // 2] if vals else 0
        X_m = []
        y_m = []
        for p in matched:
            pt = pt_idx[p["hadm_id"]]
            X_m.append([pt[k] if pt.get(k) is not None else medians[k] for k in PARAM_KEYS])
            y_m.append(p["died"])
        X_m = np.array(X_m, dtype=float)
        y_m = np.array(y_m, dtype=int)
        pipe = Pipeline([("scaler", StandardScaler()),
                         ("lr", LogisticRegression(max_iter=1000, random_state=42))])
        y_prob_m = cross_val_predict(pipe, X_m, y_m, cv=5, method="predict_proba")[:, 1]
        lr_matched = [{"a": int(a), "p": float(pr)} for a, pr in zip(y_m, y_prob_m)]
        a_lr_m = auc_fast(lr_matched)
        print(f" LogReg: AUC {a_lr_m:.4f}")
        print(f"\n FAIR THREE-WAY COMPARISON (n={len(matched):,}, identical patients):")
        print(f" {'Method':25s} {'AUC':>7s} {'vs TD':>8s}")
        print(f" {'-'*45}")
        print(f" {'Therapeutic Distance':25s} {a_td_m:7.4f} {'ref':>8s}")
        print(f" {'SAPS-II':25s} {a_saps_m:7.4f} {a_saps_m-a_td_m:+8.4f}")
        print(f" {'Logistic Regression':25s} {a_lr_m:7.4f} {a_lr_m-a_td_m:+8.4f}")

    # Paired bootstrap of the TD-vs-SAPS AUC difference. It needs only the
    # standard library, so it runs whether or not scikit-learn is available.
    print("\n Bootstrap CI on matched population (1000 resamples)...")
    deltas = _bootstrap_auc_deltas(td_matched, saps_matched, n_boot=1000, seed=42)
    ci_lo = deltas[25]
    ci_hi = deltas[975]
    p_val = sum(1 for d in deltas if d <= 0) / 1000
    # The sign comes from the format spec so a negative delta is not shown as "+-".
    print(f" TD vs SAPS (matched): {a_td_m-a_saps_m:+.4f} (95% CI {ci_lo:+.4f} to {ci_hi:+.4f})")
    print(f" p-value: {'<0.001' if p_val < 0.001 else f'{p_val:.3f}'}")

    fair = {"n_matched": len(matched),
            "td_auc": round(a_td_m, 4), "saps_auc": round(a_saps_m, 4)}
    if a_lr_m is not None:
        fair["lr_auc"] = round(a_lr_m, 4)
        fair["delta_td_saps"] = round(a_td_m - a_saps_m, 4)
        fair["delta_td_lr"] = round(a_td_m - a_lr_m, 4)
    else:
        fair["delta"] = round(a_td_m - a_saps_m, 4)
    fair["ci"] = [round(ci_lo, 4), round(ci_hi, 4)]
    fair["p_value"] = round(p_val, 4) if p_val >= 0.001 else 0.001
    results["fair_comparison"] = fair

    # -- 3. Summary: positioning ----------------------------------------
    _print_positioning()

    # -- Final summary ---------------------------------------------------
    tm = (time.time() - T0) / 60
    _banner(f" FESTUNG TEIL 3 — ERGEBNIS ({tm:.0f} min)")
    print(" 1. Calibration:")
    print(f" Brier TD: {brier:.4f} SAPS: {saps_brier:.4f}")
    print(f" Max |pred-obs|: {max_diff:.4f} Mean: {mean_diff:.4f}")
    print("")
    print(f" 2. Fair Comparison ({len(matched):,} matched patients):")
    print(f" TD: {a_td_m:.4f}")
    print(f" SAPS: {a_saps_m:.4f} (Delta {a_td_m-a_saps_m:+.4f})")
    if "lr_auc" in results.get("fair_comparison", {}):
        print(f" LogReg: {results['fair_comparison']['lr_auc']:.4f} (Delta {a_td_m-results['fair_comparison']['lr_auc']:+.4f})")
    print("")
    print(" 3. Positionierung:")
    print(" TD != besserer Score")
    print(" TD = therapiespezifische Risikostruktur")
    with open("festung_teil3.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print("\n -> festung_teil3.json saved")
    print(f"\n{'='*76}\n")
# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()