"""
FESTUNG TEIL 3 — Die letzten drei Steine

1. CALIBRATION (10 Dezile: predicted vs observed)
2. FAIRER VERGLEICH (SAPS + LogReg auf exakt gleicher Population wie TD)
3. Positionierung (Zusammenfassung für Paper)

Berechnet die Full-Pop-LOO-Prädiktionen neu (gleiches Verfahren wie im
vorherigen Lauf) und wertet sie anschließend aus.

Basilakis 2026 · chicxulub.ai
"""

import json, os, sys, math, time, random
|
|
from collections import defaultdict
|
|
|
|
# --- Database connection -----------------------------------------------------
# libpq DSN of the PostgreSQL server holding MIMIC-III. Override it through the
# MIMIC_PG_DSN environment variable, for example
#   "host=localhost port=5432 dbname=mimic user=postgres password=..."
PG_DSN = os.environ.get("MIMIC_PG_DSN", "dbname=mimic3")

# Schema containing the stock MIMIC-III v1.3 tables (admissions, icustays,
# labevents, chartevents, inputevents_mv, inputevents_cv, prescriptions,
# diagnoses_icd, d_items, ...). Override through MIMIC_SCHEMA.
MIMIC_SCHEMA = os.environ.get("MIMIC_SCHEMA", "mimiciii")

# Schema containing the locally built derived tables (sapsii, sepsis3, ...),
# see sql/schemas.sql. Falls back to MIMIC_SCHEMA when DERIVED_SCHEMA is unset.
DERIVED_SCHEMA = os.environ.get("DERIVED_SCHEMA", MIMIC_SCHEMA)

# --- Clinical constants ------------------------------------------------------
# Norepinephrine has different itemids in the two MIMIC-III charting systems:
# MetaVision (inputevents_mv) and CareVue (inputevents_cv).
# NOTE(review): per the MIMIC code repository, CareVue 30047 is charted in
# mcg/min whereas 30120 / MetaVision 221906 are mcg/kg/min — confirm before
# comparing raw rates across these itemids.
NE_ITEMIDS_MV = [221906]
NE_ITEMIDS_CV = [30047, 30120]

# Half-width, in SAPS-II points, of the severity window used to pick
# comparable reference patients.
SAPS_WINDOW = 10

# Per-patient first-day parameters, in the order used everywhere else.
PARAM_KEYS = [
    "lactate",
    "creatinine",
    "ph",
    "troponin",
    "hemoglobin",
    "heart_rate",
    "map_bp",
    "spo2",
    "temperature",
    "ne_dose",
]

# Therapies of interest, keyed by a short id. Optional fields per entry:
#   drugs_input - d_items label fragments looked up in inputevents_mv / _cv
#   drugs_rx    - drug-name fragments looked up in prescriptions
#   ne_min      - norepinephrine rate cut-off for the "ne_high" group; keep it
#                 in sync with the threshold used in load_therapy_hadmids
#   label       - display name used in console output
# Entry order matters: it is the order in which therapies are loaded/iterated.
THERAPIES = {
    "vasopressin": {
        "drugs_input": ["Vasopressin"],
        "drugs_rx": ["Vasopressin"],
        "label": "Vasopressin",
    },
    "crrt": {
        "drugs_input": ["CRRT", "CVVH", "CVVHD", "CVVHDF"],
        "label": "CRRT",
    },
    "esmolol": {
        "drugs_rx": ["Esmolol"],
        "drugs_input": ["Esmolol"],
        "label": "Esmolol",
    },
    "furosemide": {
        "drugs_rx": ["Furosemide", "Lasix"],
        "label": "Furosemide",
    },
    "ne_high": {
        "ne_min": 0.5,
        "label": "NE >0.5",
    },
    "bicarbonate": {
        "drugs_rx": ["Bicarbonate", "Sodium Bicarbonate"],
        "drugs_input": ["Sodium Bicarbonate"],
        "label": "Bicarbonate",
    },
    "epinephrine": {
        "drugs_input": ["Epinephrine"],
        "label": "Epinephrine",
    },
    "amiodarone": {
        "drugs_rx": ["Amiodarone"],
        "drugs_input": ["Amiodarone"],
        "label": "Amiodarone",
    },
    "transfusion": {
        "drugs_input": ["Packed Red Blood Cells"],
        "label": "pRBC",
    },
}

# ICD code prefixes that define each syndrome ("galaxy"). Codes are compared
# with str.startswith, so a short entry such as "482" also covers "4820", ...
# assign_galaxies only reads the "icd_9" lists (MIMIC-III v1.3 carries ICD-9
# codes only); "icd_10" is presumably kept for ICD-10 data sources.
SYNDROME_ICDS = {
    "sepsis": {
        "icd_10": ["A410", "A411", "A412", "A413", "A414", "A418", "A419", "R6520", "R6521"],
        "icd_9": ["99591", "99592", "78552"],
    },
    "aki": {
        "icd_10": ["N170", "N171", "N172", "N178", "N179"],
        "icd_9": ["5849", "5845", "5846", "5847", "5848"],
    },
    "cardiogenic_shock": {"icd_10": ["R570"], "icd_9": ["78551"]},
    "post_cardiac_arrest": {"icd_10": ["I462", "I469"], "icd_9": ["4275"]},
    "ards": {"icd_10": ["J80"], "icd_9": ["51882"]},
    "acute_mi": {
        "icd_10": ["I210", "I211", "I212", "I213", "I214", "I219"],
        "icd_9": ["41000", "41001", "41010", "41011", "41090", "41091"],
    },
    "liver_failure": {"icd_10": ["K7200", "K7201"], "icd_9": ["5724"]},
    "gi_bleeding": {
        "icd_10": ["K920", "K921", "K922", "K2501", "K2521", "K2541", "K2561"],
        "icd_9": ["5780", "5781", "5789", "53121", "53221", "53321"],
    },
    "stroke": {
        "icd_10": ["I630", "I631", "I632", "I633", "I634", "I635", "I638", "I639",
                   "I610", "I611", "I612", "I619"],
        "icd_9": ["43301", "43311", "43321", "43331", "43381", "43391", "431"],
    },
    "pe": {
        "icd_10": ["I2699", "I2609", "I2692", "I2602"],
        "icd_9": ["41519", "41511", "41512", "41513"],
    },
    "dka": {
        "icd_10": ["E1010", "E1011", "E1110", "E1111"],
        "icd_9": ["25010", "25011", "25012", "25013"],
    },
    "heart_failure": {
        "icd_10": ["I500", "I501", "I502", "I503", "I504", "I509", "I110"],
        "icd_9": ["4280", "4281", "42820", "42821", "42830", "42831", "4289"],
    },
    "pneumonia": {
        "icd_10": ["J189", "J180", "J181", "J188", "J13", "J14", "J150", "J151", "J159"],
        "icd_9": ["481", "482", "485", "486"],
    },
    "copd": {
        "icd_10": ["J440", "J441", "J449"],
        "icd_9": ["4910", "4911", "49120", "49121", "496"],
    },
    "afib": {"icd_10": ["I480", "I481", "I482", "I4891"], "icd_9": ["42731"]},
    "post_cardiac_surgery": {
        "icd_10": ["Z951", "Z952", "Z953", "Z954"],
        "icd_9": ["V4581", "V4582", "V4583"],
    },
}

# Tie-break order when an admission matches several syndromes: the first
# syndrome in this list that matches becomes the patient's galaxy.
GALAXY_PRIORITY = [
    "sepsis", "cardiogenic_shock", "post_cardiac_arrest", "ards",
    "acute_mi", "aki", "liver_failure", "gi_bleeding", "stroke", "pe", "dka",
    "heart_failure", "pneumonia", "copd", "afib", "post_cardiac_surgery",
]

# Lazily created, module-wide PostgreSQL connection (see _pg_conn).
_PG_CONN = None


def _pg_conn():
    """Return the shared PostgreSQL connection, (re)opening it when needed.

    A fresh connection to PG_DSN is opened on first use and whenever the
    cached one reports a truthy ``closed`` attribute. New sessions are put
    into read-only, autocommit mode.
    """
    global _PG_CONN
    cached = _PG_CONN
    if cached is not None and not getattr(cached, "closed", 0):
        return cached

    import psycopg2  # imported lazily so the module loads without psycopg2

    _PG_CONN = psycopg2.connect(PG_DSN)
    _PG_CONN.set_session(readonly=True, autocommit=True)
    return _PG_CONN

def run_pg(sql):
    """Run ``sql`` on the shared read-only connection and return its rows.

    Every row is converted to a plain ``dict`` keyed by column name.
    Statements without a result set (``cursor.description is None``)
    return an empty list.
    """
    from psycopg2.extras import RealDictCursor

    connection = _pg_conn()
    with connection.cursor(cursor_factory=RealDictCursor) as cursor:
        cursor.execute(sql)
        fetched = [] if cursor.description is None else cursor.fetchall()
    return [dict(row) for row in fetched]

def auc_fast(preds):
    """Return the ROC AUC of ``preds`` via the Mann-Whitney U statistic.

    Each item is a mapping with a score ``"p"`` and an outcome ``"a"``.
    Items with ``a == 1`` are positives and items with ``a == 0`` are
    negatives; anything else is ignored. A positive/negative pair with equal
    scores counts as half a concordant pair. Returns 0.5 when ``preds`` is
    empty or one of the two classes is missing.

    Both classes are sorted once, then a single forward-moving pointer over
    the negatives counts, for every positive, how many negatives score lower
    and how many tie. Total cost is O(n log n).
    """
    if not preds:
        return 0.5
    positives = sorted(item["p"] for item in preds if item["a"] == 1)
    negatives = sorted(item["p"] for item in preds if item["a"] == 0)
    n_neg = len(negatives)
    if not positives or not n_neg:
        return 0.5

    concordant = 0
    tied = 0
    lower = 0  # number of negatives strictly below the current positive score
    for score in positives:
        while lower < n_neg and negatives[lower] < score:
            lower += 1
        upper = lower
        while upper < n_neg and negatives[upper] == score:
            upper += 1
        concordant += lower
        tied += upper - lower
    return (concordant + 0.5 * tied) / (len(positives) * n_neg)

def compute_centroid(pts):
    """Return the per-parameter mean over ``pts``, skipping missing values.

    For every key in PARAM_KEYS the mean is taken over the patients whose
    value for that key is not None. A key with no observed values maps to
    None.
    """
    centroid = {}
    for key in PARAM_KEYS:
        observed = [pt[key] for pt in pts if pt.get(key) is not None]
        centroid[key] = sum(observed, 0.0) / len(observed) if observed else None
    return centroid

def compute_sigma(pts):
    """Return the per-parameter sample standard deviation over ``pts``.

    Missing (None) values are skipped. A key with fewer than two observed
    values gets 1.0. Otherwise the standard deviation (n - 1 denominator) is
    floored at 1e-6, so callers can safely divide by it.
    """
    sigma = {}
    for key in PARAM_KEYS:
        observed = [pt[key] for pt in pts if pt.get(key) is not None]
        count = len(observed)
        if count <= 1:
            sigma[key] = 1.0
            continue
        mean = sum(observed) / count
        variance = sum((x - mean) ** 2 for x in observed) / (count - 1)
        sigma[key] = max(math.sqrt(variance), 1e-6)
    return sigma

def td(pv, centroid, weights):
    """Return the Therapeutic Distance between ``pv`` and ``centroid``.

    This is the mean of ``weights[k] * |pv[k] - centroid[k]|`` over the keys
    in PARAM_KEYS where both values are present and the weight (default 0)
    is strictly positive. Returns None when no key qualifies.
    """
    total = 0.0
    used = 0
    for key in PARAM_KEYS:
        value = pv.get(key)
        center = centroid.get(key)
        weight = weights.get(key, 0)
        if value is not None and center is not None and weight > 0:
            total += weight * abs(value - center)
            used += 1
    if not used:
        return None
    return total / used

def _first_day_lab_cte(name, agg, itemids):
    """Return SQL for a CTE ``name`` that aggregates one lab per admission.

    ``agg`` (e.g. MAX/MIN) is applied to ``labevents.valuenum`` for the given
    comma-separated ``itemids``. Only rows within 24 h after the selected ICU
    stay's ``intime`` are used. Requires a preceding ``icu_pts`` CTE.
    """
    return (
        f"{name} AS (SELECT le.hadm_id, {agg}(le.valuenum) AS val "
        f"FROM {MIMIC_SCHEMA}.labevents le JOIN icu_pts ip ON le.hadm_id = ip.hadm_id "
        f"WHERE le.itemid IN ({itemids}) AND le.valuenum IS NOT NULL "
        f"AND le.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' "
        f"GROUP BY le.hadm_id)"
    )


def _first_day_chart_cte(name, agg, itemids, lo, hi):
    """Return SQL for a CTE ``name`` that aggregates one charted vital per admission.

    Only chartevents rows charted on the selected ICU stay
    (``icu_pts.icustay_id``), within 24 h of its intime and with
    ``lo <= valuenum <= hi`` (plausibility filter) are used.
    """
    return (
        f"{name} AS (SELECT ip.hadm_id, {agg}(ce.valuenum) AS val "
        f"FROM {MIMIC_SCHEMA}.chartevents ce JOIN icu_pts ip ON ce.icustay_id = ip.icustay_id "
        f"WHERE ce.itemid IN ({itemids}) AND ce.valuenum BETWEEN {lo} AND {hi} "
        f"AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours' "
        f"GROUP BY ip.hadm_id)"
    )


def load_all_icu():
    """Load one row per hospital admission with first-ICU-day features.

    An admission qualifies when one of its ICU stays has a SAPS-II score in
    20..90. Only the earliest such stay is used; all windows are the first
    24 h after that stay's ``intime``. Without this, admissions with several
    ICU stays were returned once per stay and counted as separate patients.

    Returns a list of dicts with keys ``hadm_id``, ``died``, ``sapsii``,
    ``saps_prob`` and every PARAM_KEYS entry (None when not measured).
    Admissions with fewer than three measured parameters, or without an
    outcome, are dropped.

    Norepinephrine: CareVue itemid 30047 is charted in mcg/min (per the MIMIC
    code repository), unlike 30120 / MetaVision 221906 (mcg/kg/min). 30047
    rates are therefore divided by a charted body weight; rows without a
    plausible weight are ignored.
    """
    print(" Loading ALL ICU patients...")
    ne_mv = ",".join(str(i) for i in NE_ITEMIDS_MV)
    ne_cv = ",".join(str(i) for i in NE_ITEMIDS_CV)

    ctes = [
        # Earliest qualifying ICU stay per admission -> exactly one row per hadm_id.
        f"""icu_pts AS (
    SELECT DISTINCT ON (a.hadm_id)
           a.hadm_id, icu.icustay_id, icu.intime,
           a.hospital_expire_flag AS died, s.sapsii, s.sapsii_prob AS saps_prob
    FROM {MIMIC_SCHEMA}.admissions a
    JOIN {MIMIC_SCHEMA}.icustays icu ON a.hadm_id = icu.hadm_id
    JOIN {DERIVED_SCHEMA}.sapsii s ON icu.icustay_id = s.icustay_id
    WHERE s.sapsii BETWEEN 20 AND 90
    ORDER BY a.hadm_id, icu.intime)""",
        # Mean charted body weight (kg) per selected stay, used only to convert
        # mcg/min NE rates. Itemids follow the MIMIC code repository weight
        # concepts (762/763/3580 kg, 3581 lb, 224639/226512 kg) — TODO confirm.
        f"""wt AS (
    SELECT x.icustay_id, AVG(x.kg) AS weight
    FROM (SELECT ce.icustay_id,
                 CASE WHEN ce.itemid = 3581 THEN ce.valuenum * 0.45359237
                      ELSE ce.valuenum END AS kg
          FROM {MIMIC_SCHEMA}.chartevents ce
          JOIN icu_pts ip ON ce.icustay_id = ip.icustay_id
          WHERE ce.itemid IN (762, 763, 3580, 3581, 224639, 226512)
            AND ce.valuenum IS NOT NULL) x
    WHERE x.kg BETWEEN 20 AND 300
    GROUP BY x.icustay_id)""",
        _first_day_lab_cte("l_lac", "MAX", "50813"),
        _first_day_lab_cte("l_krea", "MAX", "50912"),
        _first_day_lab_cte("l_ph", "MIN", "50820,50831"),
        _first_day_lab_cte("l_trop", "MAX", "51002,51003"),
        _first_day_lab_cte("l_hb", "MIN", "51222"),
        _first_day_chart_cte("c_hr", "MAX", "211,220045", 20, 250),
        _first_day_chart_cte("c_map", "MIN", "52,456,6702,220052,220181,225312", 20, 200),
        _first_day_chart_cte("c_spo2", "MIN", "646,220277", 50, 100),
        # Temperature: take all four itemids (676/223762 nominally Celsius,
        # 678/223761 nominally Fahrenheit) and decide the unit from the value.
        # Plausible body temperature is ~28..43 in C and ~82..110 in F. The
        # bands do not overlap, so F-band values are converted to C whatever
        # itemid they were charted under; values outside both are dropped.
        f"""c_temp AS (
    SELECT ip.hadm_id,
           MIN(CASE WHEN ce.valuenum BETWEEN 28 AND 43 THEN ce.valuenum
                    WHEN ce.valuenum BETWEEN 82 AND 110 THEN (ce.valuenum - 32.0) / 1.8
               END) AS val
    FROM {MIMIC_SCHEMA}.chartevents ce
    JOIN icu_pts ip ON ce.icustay_id = ip.icustay_id
    WHERE ce.itemid IN (676, 223762, 678, 223761)
      AND ce.valuenum IS NOT NULL
      AND (ce.valuenum BETWEEN 28 AND 43 OR ce.valuenum BETWEEN 82 AND 110)
      AND ce.charttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours'
    GROUP BY ip.hadm_id)""",
        # NE rates from both charting systems, normalised to mcg/kg/min.
        f"""ne_all AS (
    SELECT ie.icustay_id, ie.starttime AS evttime, ie.rate
    FROM {MIMIC_SCHEMA}.inputevents_mv ie
    JOIN icu_pts ip ON ie.icustay_id = ip.icustay_id
    WHERE ie.itemid IN ({ne_mv}) AND ie.rate > 0
    UNION ALL
    SELECT ie.icustay_id, ie.charttime AS evttime,
           CASE WHEN ie.itemid = 30047 THEN ie.rate / wt.weight ELSE ie.rate END AS rate
    FROM {MIMIC_SCHEMA}.inputevents_cv ie
    JOIN icu_pts ip ON ie.icustay_id = ip.icustay_id
    LEFT JOIN wt ON ie.icustay_id = wt.icustay_id
    WHERE ie.itemid IN ({ne_cv}) AND ie.rate > 0)""",
        """ne AS (
    SELECT ip.hadm_id, MAX(n.rate) AS val
    FROM ne_all n JOIN icu_pts ip ON n.icustay_id = ip.icustay_id
    WHERE n.evttime BETWEEN ip.intime AND ip.intime + INTERVAL '24 hours'
    GROUP BY ip.hadm_id)""",
    ]

    sql = "WITH " + ",\n".join(ctes) + """
SELECT ip.hadm_id, ip.died, ip.sapsii, ip.saps_prob,
       ll.val AS lactate, lk.val AS creatinine, lp.val AS ph, lt.val AS troponin, lh.val AS hemoglobin,
       chr_.val AS heart_rate, cma.val AS map_bp, csp.val AS spo2, cte.val AS temperature, ne.val AS ne_dose
FROM icu_pts ip
LEFT JOIN l_lac ll ON ip.hadm_id = ll.hadm_id
LEFT JOIN l_krea lk ON ip.hadm_id = lk.hadm_id
LEFT JOIN l_ph lp ON ip.hadm_id = lp.hadm_id
LEFT JOIN l_trop lt ON ip.hadm_id = lt.hadm_id
LEFT JOIN l_hb lh ON ip.hadm_id = lh.hadm_id
LEFT JOIN c_hr chr_ ON ip.hadm_id = chr_.hadm_id
LEFT JOIN c_map cma ON ip.hadm_id = cma.hadm_id
LEFT JOIN c_spo2 csp ON ip.hadm_id = csp.hadm_id
LEFT JOIN c_temp cte ON ip.hadm_id = cte.hadm_id
LEFT JOIN ne ON ip.hadm_id = ne.hadm_id"""

    rows = run_pg(sql)
    keep = ["hadm_id", "died", "sapsii", "saps_prob"] + PARAM_KEYS
    pts = [
        {k: r.get(k) for k in keep}
        for r in rows
        # Need an outcome and at least three measured parameters.
        if r.get("died") is not None
        and sum(1 for k in PARAM_KEYS if r.get(k) is not None) >= 3
    ]
    print(f" -> {len(pts)} patients")
    return pts

def assign_galaxies(pts):
    """Set ``pt["galaxy"]`` on every patient dict in ``pts`` (in place).

    ICD-9 diagnoses are fetched in chunks of 10,000 admissions. An admission
    belongs to every syndrome whose ``icd_9`` prefix list (SYNDROME_ICDS)
    matches one of its codes. The galaxy is the first of those syndromes in
    GALAXY_PRIORITY, or None when none matches.
    """
    print(" Assigning syndromes...")
    admission_ids = [pt["hadm_id"] for pt in pts]
    syndromes_by_hadm = defaultdict(set)
    chunk_size = 10000

    for offset in range(0, len(admission_ids), chunk_size):
        id_list = ",".join(str(h) for h in admission_ids[offset:offset + chunk_size])
        # MIMIC-III v1.3 only carries ICD-9 codes (column `icd9_code`).
        query = f"SELECT hadm_id,icd9_code FROM {MIMIC_SCHEMA}.diagnoses_icd WHERE hadm_id IN ({id_list})"
        for row in run_pg(query):
            code = row.get("icd9_code")
            if code is None:
                continue
            for syndrome, spec in SYNDROME_ICDS.items():
                if any(code.startswith(prefix) for prefix in spec.get("icd_9", [])):
                    syndromes_by_hadm[row["hadm_id"]].add(syndrome)

    for pt in pts:
        found = syndromes_by_hadm.get(pt["hadm_id"], set())
        pt["galaxy"] = next((g for g in GALAXY_PRIORITY if g in found), None)

def load_therapy_hadmids(tkey):
    """Return the set of hadm_ids that received therapy ``tkey`` on the first ICU day.

    "First ICU day" is the 24 h after ``icustays.intime`` of any ICU stay of
    the admission.

    * ``ne_high``: any norepinephrine rate >= ``THERAPIES["ne_high"]["ne_min"]``
      (default 0.5, mcg/kg/min). CareVue itemid 30047 is charted in mcg/min
      (per the MIMIC code repository), so its rate is divided by a charted
      body weight first. 30047 rows without a plausible weight are ignored.
    * Otherwise, admissions whose inputevents (MetaVision or CareVue) carry a
      d_items label matching ``drugs_input``, or whose prescriptions carry a
      drug name matching ``drugs_rx``. Fragments match case-insensitively at
      the start of a word, so "Epinephrine" does not match "Norepinephrine".
      Returns an empty set when the therapy defines neither list.

    Raises KeyError for an unknown ``tkey``.
    """
    import re

    t = THERAPIES[tkey]
    window = "BETWEEN icu.intime AND icu.intime + INTERVAL '24 hours'"

    if tkey == "ne_high":
        threshold = float(t.get("ne_min", 0.5))
        ne_mv = ",".join(str(i) for i in NE_ITEMIDS_MV)
        ne_cv = ",".join(str(i) for i in NE_ITEMIDS_CV)
        # Weight itemids follow the MIMIC code repository weight concepts
        # (762/763/3580 kg, 3581 lb, 224639/226512 kg) — TODO confirm.
        sql = f"""
WITH wt AS (
    SELECT x.icustay_id, AVG(x.kg) AS weight
    FROM (SELECT ce.icustay_id,
                 CASE WHEN ce.itemid = 3581 THEN ce.valuenum * 0.45359237
                      ELSE ce.valuenum END AS kg
          FROM {MIMIC_SCHEMA}.chartevents ce
          WHERE ce.itemid IN (762, 763, 3580, 3581, 224639, 226512)
            AND ce.valuenum IS NOT NULL
            AND ce.icustay_id IN (SELECT cv.icustay_id FROM {MIMIC_SCHEMA}.inputevents_cv cv
                                  WHERE cv.itemid = 30047)) x
    WHERE x.kg BETWEEN 20 AND 300
    GROUP BY x.icustay_id)
SELECT DISTINCT ie.hadm_id
FROM {MIMIC_SCHEMA}.inputevents_mv ie
JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id = icu.icustay_id
WHERE ie.itemid IN ({ne_mv}) AND ie.rate >= {threshold}
  AND ie.starttime {window}
UNION
SELECT DISTINCT ie.hadm_id
FROM {MIMIC_SCHEMA}.inputevents_cv ie
JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id = icu.icustay_id
LEFT JOIN wt ON ie.icustay_id = wt.icustay_id
WHERE ie.itemid IN ({ne_cv})
  AND (CASE WHEN ie.itemid = 30047 THEN ie.rate / wt.weight ELSE ie.rate END) >= {threshold}
  AND ie.charttime {window}
"""
        return {r["hadm_id"] for r in run_pg(sql)}

    def word_start_regex(fragment):
        # Case-insensitive PostgreSQL regex anchored at a word start (\m).
        # re.escape neutralises regex metacharacters; doubling ' keeps the
        # SQL string literal well-formed.
        escaped = re.escape(fragment).replace("'", "''")
        return f"'\\m{escaped}'"

    clauses = []
    # MIMIC-III splits inputevents across MetaVision (starttime) and CareVue
    # (charttime); both are queried and the hadm_ids UNIONed.
    for fragment in t.get("drugs_input", []):
        rx = word_start_regex(fragment)
        for table, time_col in (("inputevents_mv", "starttime"), ("inputevents_cv", "charttime")):
            clauses.append(
                f"SELECT DISTINCT ie.hadm_id FROM {MIMIC_SCHEMA}.{table} ie "
                f"JOIN {MIMIC_SCHEMA}.d_items di ON ie.itemid = di.itemid "
                f"JOIN {MIMIC_SCHEMA}.icustays icu ON ie.icustay_id = icu.icustay_id "
                f"WHERE di.label ~* {rx} AND ie.{time_col} {window}"
            )
    # prescriptions.startdate has DATE precision (midnight). Comparing it with
    # intime directly dropped every order started on the admission day before
    # intime's clock time, so the window starts at the calendar day of intime.
    for fragment in t.get("drugs_rx", []):
        rx = word_start_regex(fragment)
        clauses.append(
            f"SELECT DISTINCT p.hadm_id FROM {MIMIC_SCHEMA}.prescriptions p "
            f"JOIN {MIMIC_SCHEMA}.icustays icu ON p.hadm_id = icu.hadm_id "
            f"WHERE p.drug ~* {rx} "
            f"AND p.startdate BETWEEN DATE_TRUNC('day', icu.intime) AND icu.intime + INTERVAL '24 hours'"
        )

    if not clauses:
        return set()
    return {r["hadm_id"] for r in run_pg(" UNION ".join(clauses))}

def run_loo(test_pts, ref_pts, therapy_hids, by_gal, label):
    """Leave-one-out Therapeutic-Distance (TD) mortality predictions.

    For each patient in ``test_pts`` that received at least one therapy
    (``therapy_hids`` maps therapy key -> set of hadm_ids), up to three of its
    therapies (in ``therapy_hids`` order) each yield at most one prediction.

    1. Cohort: other patients with that therapy. Prefer the patient's galaxy
       (``by_gal``) within +/-SAPS_WINDOW SAPS-II points. While fewer than 10
       remain, widen to the whole galaxy, then to ``ref_pts``. Cohorts with
       fewer than 5 patients are skipped.
    2. Weights: 1/sigma per parameter, taken from the galaxy when it has more
       than 20 members, otherwise from ``ref_pts``.
    3. Orbit: cohort members whose TD to the cohort centroid lies within
       ``max(eps, 0.01)`` of the patient's own TD. ``eps`` is the ~20th
       percentile of cohort TDs. The prediction is the orbit's death rate,
       and at least 3 orbit members are required.

    A patient without SAPS-II is windowed around 50.

    Returns:
        A list of ``{"a": outcome, "p": predicted risk, "g": galaxy,
        "hadm_id": admission}``. ``hadm_id`` is kept for the matched
        comparison.

    Progress is written to stdout every 5000 test patients.
    """
    global_w = {k: 1.0 / s for k, s in compute_sigma(ref_pts).items()}
    # Galaxy sigmas depend only on the galaxy, not on the patient or therapy:
    # compute each once instead of once per patient x therapy.
    galaxy_w = {}
    preds = []
    t0 = time.time()
    n_test = len(test_pts)

    for i, pat in enumerate(test_pts):
        if i % 5000 == 0 and i > 0:
            elapsed = time.time() - t0
            eta = (n_test - i) / (i / elapsed) / 60 if elapsed > 0 else 0
            sys.stdout.write(f"\r {label}: {i:,}/{n_test:,} preds={len(preds):,} ~{eta:.0f}m")
            sys.stdout.flush()

        hid = pat["hadm_id"]
        therapies = [tk for tk, hids in therapy_hids.items() if hid in hids]
        if not therapies:
            continue

        g = pat.get("galaxy")
        saps = pat.get("sapsii")
        if saps is None:
            saps = 50  # a NULL SAPS-II previously raised TypeError below
        in_galaxy = bool(g) and g in by_gal

        if in_galaxy and len(by_gal[g]) > 20:
            w = galaxy_w.get(g)
            if w is None:
                w = galaxy_w[g] = {k: 1.0 / s for k, s in compute_sigma(by_gal[g]).items()}
        else:
            w = global_w

        for tk in therapies[:3]:
            t_hids = therapy_hids[tk]
            if in_galaxy:
                slo, shi = saps - SAPS_WINDOW, saps + SAPS_WINDOW
                same_gal = [p for p in by_gal[g] if p["hadm_id"] in t_hids and p["hadm_id"] != hid]
                co = [p for p in same_gal if p.get("sapsii") and slo <= p["sapsii"] <= shi]
                if len(co) < 10:
                    co = same_gal
                if len(co) < 10:
                    co = [p for p in ref_pts if p["hadm_id"] in t_hids and p["hadm_id"] != hid]
            else:
                co = [p for p in ref_pts if p["hadm_id"] in t_hids and p["hadm_id"] != hid]
            if len(co) < 5:
                continue

            centroid = compute_centroid(co)
            d = td(pat, centroid, w)
            if d is None:
                continue

            # TD of every cohort member, computed once and reused for eps and orbit.
            co_dists = [(p, td(p, centroid, w)) for p in co]
            ds = sorted(v for _, v in co_dists if v is not None)
            if not ds:
                continue
            eps = ds[min(len(ds) - 1, len(ds) // 5)]
            radius = max(eps, 0.01)
            orbit = [p for p, v in co_dists if v is not None and abs(v - d) < radius]
            if len(orbit) < 3:
                continue

            preds.append({"a": pat["died"],
                          "p": sum(1 for p in orbit if p["died"]) / len(orbit),
                          "g": g, "hadm_id": hid})

    elapsed = time.time() - t0
    sys.stdout.write(f"\r {label}: DONE {len(preds):,} preds {elapsed/60:.1f}min{' '*20}\n")
    return preds

def _brier(preds):
    """Return the Brier score (mean of (p - a)^2) of ``preds``, or None if empty."""
    if not preds:
        return None
    return sum((p["p"] - p["a"]) ** 2 for p in preds) / len(preds)


def _fmt4(value):
    """Format ``value`` with four decimals, or return "n/a" for None."""
    return "n/a" if value is None else f"{value:.4f}"


def _round4(value):
    """Round ``value`` to four decimals, passing None through."""
    return None if value is None else round(value, 4)


def _calibration_section(p_full, all_pts, pt_idx):
    """Section 1: decile calibration table and Brier scores for TD.

    ``p_full`` must be non-empty. Predictions are sorted by ``p`` and cut into
    ``min(10, n)`` equal-count bins; the last bin absorbs the remainder.

    Returns:
        A tuple ``(result, summary)``. ``result`` is JSON-ready and stored
        under ``results["calibration"]``. ``summary`` holds the unrounded
        numbers the final summary prints.
    """
    print(f"\n{'='*76}")
    print(f" 1. CALIBRATION — Predicted vs Observed (10 Dezile)")
    print(f"{'='*76}\n")

    sorted_preds = sorted(p_full, key=lambda x: x["p"])
    n = len(sorted_preds)
    # With fewer than 10 predictions, n // 10 == 0 gave empty bins and a
    # ZeroDivisionError; use one prediction per bin instead.
    n_bins = min(10, n)
    bin_size = n // n_bins

    print(f" {'Decile':>7s} {'Pred Mean':>10s} {'Obs Rate':>10s} {'n':>6s} {'Diff':>8s} Visual")
    print(f" {'-'*65}")

    cal_bins = []
    for i in range(n_bins):
        start = i * bin_size
        end = (i + 1) * bin_size if i < n_bins - 1 else n
        bin_preds = sorted_preds[start:end]
        pred_mean = sum(p["p"] for p in bin_preds) / len(bin_preds)
        obs_rate = sum(p["a"] for p in bin_preds) / len(bin_preds)
        diff = pred_mean - obs_rate

        bar_pred = "#" * int(pred_mean * 40)
        bar_obs = "." * int(obs_rate * 40)
        print(f" {i+1:>7d} {pred_mean:10.4f} {obs_rate:10.4f} {len(bin_preds):6d} {diff:+8.4f} P:{bar_pred}")
        print(f" {'':>7s} {'':>10s} {'':>10s} {'':>6s} {'':>8s} O:{bar_obs}")

        cal_bins.append({"decile": i + 1, "pred": round(pred_mean, 4), "obs": round(obs_rate, 4),
                         "n": len(bin_preds), "diff": round(diff, 4)})

    brier = _brier(sorted_preds)
    max_diff = max(abs(b["diff"]) for b in cal_bins)
    mean_diff = sum(abs(b["diff"]) for b in cal_bins) / len(cal_bins)

    print(f"\n Brier Score: {brier:.4f}")
    print(f" Max |pred-obs|: {max_diff:.4f}")
    print(f" Mean |pred-obs|: {mean_diff:.4f}")

    if max_diff < 0.10:
        print(f" GOOD CALIBRATION. Max deviation < 0.10.")
    elif max_diff < 0.15:
        print(f" ACCEPTABLE CALIBRATION. Some deviation in extreme deciles.")
    else:
        print(f" POOR CALIBRATION. Needs recalibration (Platt scaling or isotonic).")

    # Full-population SAPS-II Brier score (every loaded patient).
    saps_brier = _brier([{"a": p["died"], "p": p["saps_prob"]}
                         for p in all_pts if p.get("saps_prob") is not None])
    # TD only scores patients with >= 1 therapy and a usable orbit, so the
    # full-population SAPS-II score describes a different case mix. The
    # verdict therefore compares both scores on exactly the same predictions.
    paired = [(p, pt_idx[p["hadm_id"]]["saps_prob"]) for p in p_full
              if p["hadm_id"] in pt_idx and pt_idx[p["hadm_id"]].get("saps_prob") is not None]
    brier_td_m = _brier([p for p, _ in paired])
    brier_saps_m = _brier([{"a": p["a"], "p": s} for p, s in paired])

    print(f"\n SAPS-II Brier Score: {_fmt4(saps_brier)}")
    print(f" TD Brier Score: {brier:.4f}")
    print(f" Same predictions (n={len(paired):,}): TD {_fmt4(brier_td_m)} SAPS-II {_fmt4(brier_saps_m)}")
    if brier_td_m is None or brier_saps_m is None:
        print(f" No TD predictions with a SAPS-II probability — Brier comparison skipped.")
    elif brier_td_m < brier_saps_m:
        print(f" TD better calibrated than SAPS-II ({brier_td_m:.4f} < {brier_saps_m:.4f})")
    else:
        print(f" SAPS-II better calibrated ({brier_saps_m:.4f} <= {brier_td_m:.4f})")

    result = {
        "bins": cal_bins, "brier_td": round(brier, 4), "brier_saps": _round4(saps_brier),
        "max_diff": round(max_diff, 4), "mean_diff": round(mean_diff, 4),
        "n_brier_matched": len(paired),
        "brier_td_matched": _round4(brier_td_m), "brier_saps_matched": _round4(brier_saps_m),
    }
    summary = {"brier": brier, "saps_brier": saps_brier, "max_diff": max_diff, "mean_diff": mean_diff}
    return result, summary


def _logreg_matched_auc(matched, all_pts, pt_idx):
    """Return the 5-fold cross-validated logistic-regression AUC on ``matched``.

    Features are PARAM_KEYS; missing values are imputed with the upper median
    over ``all_pts``. Returns None (after printing why) when sklearn/numpy
    are not importable, or when an outcome class has fewer than 5 patients
    and so cannot be spread across 5 folds.
    """
    try:
        import numpy as np
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_predict
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import StandardScaler
    except ImportError:
        print(f" sklearn not installed — LogReg skipped")
        return None

    y = [int(p["died"]) for p in matched]
    n_pos = sum(1 for v in y if v == 1)
    if min(n_pos, len(y) - n_pos) < 5:
        print(f" LogReg skipped — fewer than 5 patients in one outcome class")
        return None

    medians = {}
    for k in PARAM_KEYS:
        vals = sorted(p[k] for p in all_pts if p.get(k) is not None)
        medians[k] = vals[len(vals) // 2] if vals else 0

    rows = []
    for p in matched:
        pt = pt_idx[p["hadm_id"]]
        rows.append([pt[k] if pt.get(k) is not None else medians[k] for k in PARAM_KEYS])
    X_m = np.array(rows, dtype=float)
    y_m = np.array(y, dtype=int)

    pipe = Pipeline([("scaler", StandardScaler()),
                     ("lr", LogisticRegression(max_iter=1000, random_state=42))])
    y_prob_m = cross_val_predict(pipe, X_m, y_m, cv=5, method="predict_proba")[:, 1]
    return auc_fast([{"a": int(a), "p": float(pr)} for a, pr in zip(y_m, y_prob_m)])


def _fair_comparison_section(p_full, all_pts, pt_idx):
    """Section 2: TD vs SAPS-II (and LogReg, if available) on identical patients.

    A patient's TD risk is the mean of its per-therapy predictions. Only
    patients with both a TD risk and a SAPS-II probability are compared. The
    bootstrap CI for the TD - SAPS AUC difference is pure Python and runs
    whether or not sklearn is installed.

    Returns:
        A tuple ``(result, (n_matched, td_auc, saps_auc))``.
    """
    print(f"\n{'='*76}")
    print(f" 2. FAIR COMPARISON — alle Modelle auf exakt gleicher Population")
    print(f"{'='*76}\n")

    td_by_patient = defaultdict(list)
    for p in p_full:
        td_by_patient[p["hadm_id"]].append(p["p"])
    print(f" TD predictions from {len(td_by_patient):,} unique patients")

    matched = []
    # Sorted iteration makes the matched order (and hence CV folds and
    # bootstrap resamples) reproducible.
    for hid in sorted(td_by_patient):
        pt = pt_idx.get(hid)
        if pt and pt.get("saps_prob") is not None:
            scores = td_by_patient[hid]
            matched.append({
                "hadm_id": hid,
                "died": pt["died"],
                "saps_prob": pt["saps_prob"],
                "td_prob": sum(scores) / len(scores),
                "sapsii": pt.get("sapsii"),
            })
    print(f" Matched population: {len(matched):,} patients (have both TD and SAPS)")

    td_matched = [{"a": p["died"], "p": p["td_prob"]} for p in matched]
    saps_matched = [{"a": p["died"], "p": p["saps_prob"]} for p in matched]
    a_td_m = auc_fast(td_matched)
    a_saps_m = auc_fast(saps_matched)

    print(f"\n ON MATCHED POPULATION ({len(matched):,} patients):")
    print(f" TD: AUC {a_td_m:.4f}")
    print(f" SAPS: AUC {a_saps_m:.4f}")
    print(f" Delta: {a_td_m - a_saps_m:+.4f}")

    a_lr_m = _logreg_matched_auc(matched, all_pts, pt_idx) if matched else None
    if a_lr_m is not None:
        print(f" LogReg: AUC {a_lr_m:.4f}")
        print(f"\n FAIR THREE-WAY COMPARISON (n={len(matched):,}, identical patients):")
        print(f" {'Method':25s} {'AUC':>7s} {'vs TD':>8s}")
        print(f" {'-'*45}")
        print(f" {'Therapeutic Distance':25s} {a_td_m:7.4f} {'ref':>8s}")
        print(f" {'SAPS-II':25s} {a_saps_m:7.4f} {a_saps_m-a_td_m:+8.4f}")
        print(f" {'Logistic Regression':25s} {a_lr_m:7.4f} {a_lr_m-a_td_m:+8.4f}")

    ci_lo = ci_hi = p_val = None
    if matched:
        print(f"\n Bootstrap CI on matched population (1000 resamples)...")
        # Local RNG: same stream as random.seed(42) without touching global state.
        rng = random.Random(42)
        n_m = len(matched)
        deltas = []
        for _ in range(1000):
            idx = [rng.randint(0, n_m - 1) for _ in range(n_m)]
            deltas.append(auc_fast([td_matched[i] for i in idx])
                          - auc_fast([saps_matched[i] for i in idx]))
        deltas.sort()
        ci_lo, ci_hi = deltas[25], deltas[975]
        # One-sided bootstrap p-value: share of resamples where TD does not beat SAPS-II.
        p_val = sum(1 for d in deltas if d <= 0) / 1000
        print(f" TD vs SAPS (matched): {a_td_m-a_saps_m:+.4f} (95% CI {ci_lo:+.4f} to {ci_hi:+.4f})")
        print(f" p-value: {'<0.001' if p_val < 0.001 else f'{p_val:.3f}'}")

    result = {
        "n_matched": len(matched),
        "td_auc": round(a_td_m, 4), "saps_auc": round(a_saps_m, 4),
        "delta_td_saps": round(a_td_m - a_saps_m, 4),
    }
    if a_lr_m is not None:
        result["lr_auc"] = round(a_lr_m, 4)
        result["delta_td_lr"] = round(a_td_m - a_lr_m, 4)
    else:
        result["delta"] = round(a_td_m - a_saps_m, 4)  # key used when LogReg is unavailable
    if p_val is not None:
        result["ci"] = [round(ci_lo, 4), round(ci_hi, 4)]
        result["p_value"] = round(p_val, 4) if p_val >= 0.001 else 0.001
    return result, (len(matched), a_td_m, a_saps_m)


def _positioning_section():
    """Section 3: print the positioning statement.

    The risk figures and distances are hard-coded illustrations, not values
    computed from the data.
    """
    print(f"\n{'='*76}")
    print(f" 3. POSITIONIERUNG — Der eine Satz")
    print(f"{'='*76}\n")

    print(f" SAPS-II sagt: 'Dieser Patient hat 35% Sterberisiko.'")
    print(f" LogReg sagt: 'Dieser Patient hat 28% Sterberisiko.'")
    print(f" TD sagt: 'Dieser Patient ist 0.7 Einheiten vom Vasopressin-Zentroid")
    print(f" und 1.2 Einheiten vom CRRT-Zentroid entfernt.")
    print(f" Patienten in seinem Orbit mit Vasopressin: 25% Mortalitaet.")
    print(f" Patienten in seinem Orbit mit CRRT: 55% Mortalitaet.'")
    print(f"")
    print(f" SAPS und LogReg liefern globales Risiko.")
    print(f" TD liefert therapiespezifische Risikostruktur.")
    print(f"")
    print(f" Das ist der Unterschied. Das ist die Positionierung.")
    print(f" Nicht: 'besserer Score'. Sondern: 'andere Information'.")


def main():
    """Run the FESTUNG part-3 evaluation and save it to festung_teil3.json.

    Loads the ICU cohort, assigns syndrome galaxies and therapy memberships,
    and recomputes the full-population leave-one-out TD predictions. It then
    runs calibration (section 1) and the matched comparison (section 2),
    prints the positioning statement (section 3), prints a final summary and
    writes all results as JSON. Nothing is written when no TD prediction
    could be made.
    """
    T0 = time.time()
    print(f"\n{'='*76}")
    print(f" FESTUNG TEIL 3 — Letzte drei Steine")
    print(f" 1: Calibration 2: Fair Comparison 3: Summary")
    print(f"{'='*76}\n")

    all_pts = load_all_icu()
    assign_galaxies(all_pts)
    by_gal = defaultdict(list)
    for p in all_pts:
        if p["galaxy"]:
            by_gal[p["galaxy"]].append(p)
    pt_idx = {p["hadm_id"]: p for p in all_pts}

    print(f"\n Loading therapies...")
    therapy_hids = {}
    for tk in THERAPIES:
        therapy_hids[tk] = load_therapy_hadmids(tk)
        print(f" {THERAPIES[tk]['label']:20s}: {len(therapy_hids[tk]):6d}")

    # ── Full pop LOO ───────────────────────────────────────────────
    print(f"\n Running full-population LOO...")
    p_full = run_loo(all_pts, all_pts, therapy_hids, by_gal, "FULL-POP")
    a_td = auc_fast(p_full)
    print(f" * TD full pop: AUC {a_td:.4f} n={len(p_full):,}")
    if not p_full:
        # Every later section divides by the number of predictions.
        print(f" No TD predictions — nothing to evaluate, no JSON written.")
        return

    results = {}
    calibration, cal = _calibration_section(p_full, all_pts, pt_idx)
    results["calibration"] = calibration
    fair, (n_matched, a_td_m, a_saps_m) = _fair_comparison_section(p_full, all_pts, pt_idx)
    results["fair_comparison"] = fair
    _positioning_section()

    # ══════════════════════════════════════════════════════════════
    # FINAL SUMMARY
    # ══════════════════════════════════════════════════════════════
    tm = (time.time() - T0) / 60
    print(f"\n{'='*76}")
    print(f" FESTUNG TEIL 3 — ERGEBNIS ({tm:.0f} min)")
    print(f"{'='*76}\n")

    print(f" 1. Calibration:")
    print(f" Brier TD: {cal['brier']:.4f} SAPS: {_fmt4(cal['saps_brier'])}")
    print(f" Max |pred-obs|: {cal['max_diff']:.4f} Mean: {cal['mean_diff']:.4f}")
    print(f"")
    print(f" 2. Fair Comparison ({n_matched:,} matched patients):")
    print(f" TD: {a_td_m:.4f}")
    print(f" SAPS: {a_saps_m:.4f} (Delta {a_td_m-a_saps_m:+.4f})")
    if "lr_auc" in fair:
        print(f" LogReg: {fair['lr_auc']:.4f} (Delta {a_td_m-fair['lr_auc']:+.4f})")
    print(f"")
    print(f" 3. Positionierung:")
    print(f" TD != besserer Score")
    print(f" TD = therapiespezifische Risikostruktur")

    with open("festung_teil3.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\n -> festung_teil3.json saved")
    print(f"\n{'='*76}\n")


if __name__ == "__main__":
    main()