Files
eval-keplertwin/paper3_phase5b_refined.py

619 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
PAPER 3 — PHASE 5b: 5-TERM FORMULA + ORDINAL BEDSIDE SCORE
═════════════════════════════════════════════════════════════════════════════
Two refinements from Phase 5:
1. DROP log(1 + ne_auc). Its bootstrap CI crossed zero, and it is
collinear with I(ne_at_h24 > 0.08). Simpler 5-term formula.
2. BEDSIDE SCORE via ORDINAL BINNING (not coefficient rounding).
Each of the 5 remaining terms gets mapped to 0-4 points based on
clinically meaningful thresholds. SOFA-style integer score:
Lactate h24       0 if <2.5   | 2 if 2.5-4         | 4 if >4
Oliguria (ml/kg)  0 if ≥20    | 1 if 10-20         | 2 if <10
NE at h24         0 if ≤0.08  | 3 if >0.08
HR deviation      0 if 70-100 | 1 if 60-70/100-120 | 2 if <60/>120
Pressor MAP/NE    0 if >3000  | 1 if 1000-3000     | 2 if <1000
Max 13 points. Bedside-ready.
Repeats the full Phase 5 validation: CV, multi-seed, bootstrap, calibration,
subgroups — for BOTH the 5-term continuous formula AND the ordinal score.
Usage:
python paper3_phase5b_refined.py
"""
import json, os, sys, math, time, random
from collections import defaultdict
# PostgreSQL connection string (libpq DSN). Override with env var.
# e.g. "host=localhost port=5432 dbname=mimic user=postgres password=..."
PG_DSN = os.environ.get("MIMIC_PG_DSN", "dbname=mimic3")
# Schema holding the stock MIMIC-III v1.3 tables (admissions, icustays,
# labevents, chartevents, inputevents_mv, inputevents_cv, outputevents,
# patients, d_items, ...).
MIMIC_SCHEMA = os.environ.get("MIMIC_SCHEMA", "mimiciii")
# Schema holding the locally built derived tables (sapsii, sepsis3,
# norepinephrine_dose, weight_durations, ...); see sql/schemas.sql.
# Defaults to the same schema as MIMIC-III itself.
DERIVED_SCHEMA = os.environ.get("DERIVED_SCHEMA", MIMIC_SCHEMA)
# Snapshot hour after ICU admission. Used as the minimum ICU length of stay
# (q_cohort), the urine-output window (q_fluid_out) and the minute at which
# the norepinephrine rate is read (build_primitives).
H_SNAPSHOT = 24
# NOTE(review): not referenced in the visible part of this file — confirm
# whether it is still needed.
H_PEAK_NE = 12
# Fraction of subjects placed on the training side of the calibration
# holdout split in main().
TRAIN_FRAC = 0.70
# NOTE(review): not referenced in the visible part of this file, although the
# module docstring mentions a multi-seed validation — confirm.
N_SEEDS = 10
# Number of subject-level cross-validation folds.
N_FOLDS = 5
# Number of bootstrap resamples used for the coefficient confidence intervals.
N_BOOTSTRAP = 1000
# Path of the JSON results file written at the end of main().
OUT_FILE = "paper3_phase5b_refined.json"
# labevents itemid selected by q_lactate().
LACTATE_ID = 50813
# MAP: 52, 456, 6702 = CareVue; 220052, 220181, 225312 = MetaVision.
MAP_ITEMIDS = [52, 456, 6702, 220052, 220181, 225312]
# HR: 211 = CareVue; 220045 = MetaVision.
HR_ITEMIDS = [211, 220045]
# Process-wide connection cache; populated lazily by _pg_conn().
_PG_CONN = None


def _pg_conn():
    """Return the cached PostgreSQL connection, (re)opening it when needed.

    A new connection is opened with ``PG_DSN`` when nothing is cached yet or
    when the cached object reports a truthy ``closed`` attribute. Fresh
    connections are switched to read-only, autocommit mode.
    """
    global _PG_CONN
    still_open = _PG_CONN is not None and not getattr(_PG_CONN, "closed", 0)
    if still_open:
        return _PG_CONN
    # Imported lazily so the module can be loaded without the driver installed.
    import psycopg2
    _PG_CONN = psycopg2.connect(PG_DSN)
    _PG_CONN.set_session(readonly=True, autocommit=True)
    return _PG_CONN
def run_pg(sql, label=""):
    """Execute *sql* and return the result rows as a list of plain dicts.

    Args:
        sql: SQL text to execute.
        label: Short name printed next to the row count and elapsed time.

    Returns:
        One dict per row (empty when the statement produced no result set).
        Any exception, including a missing driver, is reported on stdout as
        ``[PG ERROR]`` and yields an empty list instead of propagating.
    """
    try:
        import psycopg2.extras
        connection = _pg_conn()
        started = time.time()
        dict_cursor = psycopg2.extras.RealDictCursor
        with connection.cursor(cursor_factory=dict_cursor) as cur:
            cur.execute(sql)
            if cur.description:
                fetched = list(map(dict, cur.fetchall()))
            else:
                fetched = []
        print(f" {label:32s} {len(fetched):>8,d} rows ({time.time()-started:.1f}s)")
        return fetched
    except Exception as err:
        print(f"[PG ERROR] {label}: {err}")
        return []
# ── Queries (PostgreSQL / MIMIC-III v1.3, SAPS Q4 pre-filtered) ────────────
#
# Notes on the port from BigQuery / MIMIC-IV:
# * `stay_id` (MIMIC-IV) is `icustay_id` in MIMIC-III; we alias to
# `stay_id` so the downstream Python is unchanged.
# * `mimiciv_3_1_icu.inputevents` (single table, mcg/kg/min) is split
# across `inputevents_mv` and `inputevents_cv` in MIMIC-III with
# different itemids and units. The `norepinephrine_dose` table built
# by sql/build_sepsis3.sql already merges both eras and normalises
# rates to mcg/kg/min, so we use that instead of the raw inputs.
# * Weight in MIMIC-IV is read from chartevents itemids 226512/224639
# (MetaVision-only). In MIMIC-III those itemids cover only the MV
# half of the cohort, so we use the `weight_durations` table built by
# sql/build_sepsis3.sql (admit + daily + neonate + echo, both eras).
# * `pat.anchor_age` (MIMIC-IV) → computed from `pat.dob` against
# `icu.intime`. MIMIC-III shifts dob backwards by ~300 years for
# patients ≥89; we cap the result at 120.
def q_cohort():
    """Return the SQL text that selects the study cohort.

    The query keeps ICU stays flagged ``sepsis3 = TRUE`` whose ICU length of
    stay is at least ``H_SNAPSHOT`` hours and whose SAPS-II is >= 48.
    Selected columns: ``stay_id`` (alias of ``icustay_id``), ``subject_id``,
    ``intime``, ``age`` (years between ``dob`` and ``intime``, capped at 120),
    ``gender``, ``sapsii``, ``died`` (``hospital_expire_flag``) and
    ``weight_kg``.

    ``weight_kg`` is the minimum ``weight_durations.weight`` in [30, 300]
    among records overlapping the first 24 hours after ``intime``, falling
    back to 75.0 when no such record exists.
    NOTE(review): the CTE is named ``weight_first`` but takes ``MIN(weight)``,
    not the chronologically first weight — confirm which is intended.
    """
    return f"""
WITH weight_first AS (
SELECT wd.icustay_id, MIN(wd.weight) AS weight_kg
FROM {DERIVED_SCHEMA}.weight_durations wd
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.icustay_id = wd.icustay_id
WHERE wd.weight BETWEEN 30 AND 300
AND wd.starttime <= icu.intime + INTERVAL '24 hours'
AND wd.endtime >= icu.intime
GROUP BY wd.icustay_id
)
SELECT icu.icustay_id AS stay_id, icu.subject_id, icu.intime,
LEAST(120.0, EXTRACT(EPOCH FROM (icu.intime - pat.dob)) / 31556952.0) AS age,
pat.gender,
saps.sapsii, adm.hospital_expire_flag AS died,
COALESCE(wf.weight_kg, 75.0) AS weight_kg
FROM {DERIVED_SCHEMA}.sepsis3 s3
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.icustay_id = s3.icustay_id
JOIN {MIMIC_SCHEMA}.admissions adm ON adm.hadm_id = icu.hadm_id
JOIN {MIMIC_SCHEMA}.patients pat ON pat.subject_id = icu.subject_id
LEFT JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
LEFT JOIN weight_first wf ON wf.icustay_id = icu.icustay_id
WHERE s3.sepsis3 = TRUE
AND EXTRACT(EPOCH FROM (icu.outtime - icu.intime)) / 3600.0 >= {H_SNAPSHOT}
AND saps.sapsii IS NOT NULL AND saps.sapsii >= 48
"""
def q_ne():
    """Return the SQL text that fetches norepinephrine infusion intervals.

    One row per ``norepinephrine_dose`` record with a positive ``vaso_rate``
    that starts within the first 30 hours after ICU ``intime``, restricted to
    ``sepsis3 = TRUE`` stays with SAPS-II >= 48. ``start_min`` and ``end_min``
    are minutes relative to ``intime``; ``rate`` is ``vaso_rate``.
    """
    return f"""
SELECT nd.icustay_id AS stay_id,
EXTRACT(EPOCH FROM (nd.starttime - icu.intime)) / 60.0 AS start_min,
EXTRACT(EPOCH FROM (nd.endtime - icu.intime)) / 60.0 AS end_min,
nd.vaso_rate AS rate
FROM {DERIVED_SCHEMA}.norepinephrine_dose nd
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.icustay_id = nd.icustay_id
JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = nd.icustay_id
JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = nd.icustay_id
WHERE s3.sepsis3 = TRUE AND saps.sapsii >= 48
AND nd.vaso_rate > 0
AND nd.starttime BETWEEN icu.intime AND icu.intime + INTERVAL '30 hours'
"""
def q_fluid_out():
    """Return the SQL text that sums fluid output per ICU stay.

    ``fluid_out_ml`` is the sum of positive ``outputevents.value`` entries
    charted between ICU ``intime`` and ``intime + H_SNAPSHOT`` hours, for
    ``sepsis3 = TRUE`` stays with SAPS-II >= 48. Stays with no qualifying
    output rows do not appear in the result.
    """
    return f"""
SELECT oe.icustay_id AS stay_id, SUM(oe.value) AS fluid_out_ml
FROM {MIMIC_SCHEMA}.outputevents oe
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.icustay_id = oe.icustay_id
JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = oe.icustay_id
JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE s3.sepsis3 = TRUE AND saps.sapsii >= 48
AND oe.value > 0
AND oe.charttime BETWEEN icu.intime AND icu.intime + INTERVAL '{H_SNAPSHOT} hours'
GROUP BY oe.icustay_id
"""
def q_vitals():
    """Return the SQL text that averages MAP and HR around the snapshot.

    One row per (stay, itemid) with ``val`` = mean of the positive, non-null
    ``chartevents.valuenum`` readings charted between 20 and 28 hours after
    ICU ``intime``, for the itemids in ``MAP_ITEMIDS`` and ``HR_ITEMIDS`` and
    for ``sepsis3 = TRUE`` stays with SAPS-II >= 48.

    The 20-28 hour window is written literally in the SQL; it is not derived
    from ``H_SNAPSHOT``.
    """
    # Comma-separated itemid list interpolated into the IN (...) clause.
    ids = ",".join(str(x) for x in MAP_ITEMIDS + HR_ITEMIDS)
    return f"""
SELECT ce.icustay_id AS stay_id, ce.itemid, AVG(ce.valuenum) AS val
FROM {MIMIC_SCHEMA}.chartevents ce
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.icustay_id = ce.icustay_id
JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = ce.icustay_id
JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE s3.sepsis3 = TRUE AND saps.sapsii >= 48
AND ce.itemid IN ({ids})
AND ce.valuenum IS NOT NULL AND ce.valuenum > 0
AND ce.charttime BETWEEN icu.intime + INTERVAL '20 hours'
AND icu.intime + INTERVAL '28 hours'
GROUP BY ce.icustay_id, ce.itemid
"""
def q_lactate():
    """Return the SQL text that fetches lactate measurements.

    One row per non-null ``labevents`` value with itemid ``LACTATE_ID``
    charted within the first 30 hours after ICU ``intime``, for
    ``sepsis3 = TRUE`` stays with SAPS-II >= 48. ``offset_min`` is minutes
    relative to ``intime``.

    Lab rows are joined to ICU stays through ``hadm_id``, so a measurement is
    paired with every ICU stay of that admission whose 30-hour window
    contains it.
    """
    return f"""
SELECT icu.icustay_id AS stay_id,
EXTRACT(EPOCH FROM (le.charttime - icu.intime)) / 60.0 AS offset_min,
le.valuenum AS val
FROM {MIMIC_SCHEMA}.labevents le
JOIN {MIMIC_SCHEMA}.icustays icu ON icu.hadm_id = le.hadm_id
JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = icu.icustay_id
JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE s3.sepsis3 = TRUE AND saps.sapsii >= 48
AND le.itemid = {LACTATE_ID}
AND le.valuenum IS NOT NULL
AND le.charttime BETWEEN icu.intime AND icu.intime + INTERVAL '30 hours'
"""
# ── Primitives ──────────────────────────────────────────────────────────────
def build_primitives(cohort, ne_rows, fout_rows, vital_rows, lac_rows):
    """Derive the per-stay primitives that feed the formula and the score.

    Args:
        cohort: Mapping ``stay_id -> cohort row`` (reads ``weight_kg``).
        ne_rows: Rows with ``stay_id``, ``start_min``, ``end_min``, ``rate``.
        fout_rows: Rows with ``stay_id`` and ``fluid_out_ml``.
        vital_rows: Rows with ``stay_id``, ``itemid`` and ``val``; one row per
            (stay, itemid).
        lac_rows: Rows with ``stay_id``, ``offset_min`` and ``val``.

    Returns:
        Mapping ``stay_id -> dict`` with keys ``ne_at_h24``,
        ``fluid_out_per_kg``, ``lactate_h24``, ``map_h24``, ``hr_h24`` and
        ``pressor_resistance``. ``lactate_h24``, ``map_h24``, ``hr_h24`` and
        ``pressor_resistance`` are ``None`` when no source data exists.

    MAP and HR are the arithmetic mean of the per-itemid values. The previous
    running ``(cur + val) / 2`` merge depended on row order and weighted three
    or more itemids unequally.
    """
    print(f"\n[3] Building primitives...")

    def _num(value):
        # psycopg2 returns NUMERIC columns as Decimal, which cannot be mixed
        # with float arithmetic; normalise everything numeric to float.
        return None if value is None else float(value)

    ne_by = defaultdict(list)
    for r in ne_rows:
        ne_by[r["stay_id"]].append(r)

    fout = {r["stay_id"]: _num(r["fluid_out_ml"]) or 0 for r in fout_rows}

    # Collect every per-itemid value first so the merge is order-independent.
    map_ids = set(MAP_ITEMIDS)
    vital_vals = defaultdict(lambda: defaultdict(list))
    for r in vital_rows:
        if r["stay_id"] is None or r["val"] is None:
            continue
        key = "map" if r["itemid"] in map_ids else "hr"
        vital_vals[r["stay_id"]][key].append(float(r["val"]))
    vital_by = {
        sid: {key: sum(vals) / len(vals) for key, vals in by_key.items()}
        for sid, by_key in vital_vals.items()
    }

    lac_by = defaultdict(list)
    for r in lac_rows:
        lac_by[r["stay_id"]].append(r)

    snapshot_min = H_SNAPSHOT * 60
    prim = {}
    for sid, c in cohort.items():
        weight = _num(c.get("weight_kg")) or 75.0

        # Highest infusion rate among intervals that span the snapshot minute.
        ne_h24 = 0.0
        for ev in ne_by.get(sid, []):
            sm, em, rate = ev["start_min"], ev["end_min"], ev["rate"]
            if sm is None or em is None or rate is None:
                continue
            if sm <= snapshot_min <= em:
                ne_h24 = max(ne_h24, float(rate))

        # Latest lactate in the 18-28 h window; otherwise the latest value
        # available for the stay.
        lacs = sorted(lac_by.get(sid, []), key=lambda x: x["offset_min"] or 0)
        lac_h24 = None
        if lacs:
            near = [r for r in lacs if r["offset_min"] is not None
                    and 18 * 60 <= r["offset_min"] <= 28 * 60]
            lac_h24 = _num((near[-1] if near else lacs[-1])["val"])

        v = vital_by.get(sid, {})
        map_h24 = v.get("map")
        hr_h24 = v.get("hr")
        # The +0.01 keeps the ratio finite for stays that are off pressors.
        pr = map_h24 / (ne_h24 + 0.01) if map_h24 is not None else None

        prim[sid] = {
            "ne_at_h24": ne_h24,
            "fluid_out_per_kg": fout.get(sid, 0) / weight,
            "lactate_h24": lac_h24,
            "map_h24": map_h24, "hr_h24": hr_h24,
            "pressor_resistance": pr,
        }
    return prim
# ── 5-term continuous formula ──────────────────────────────────────────────
def formula_features(p):
    """Turn one stay's primitives into the 5-term feature vector.

    Args:
        p: Primitive dict; reads ``lactate_h24``, ``fluid_out_per_kg``,
            ``ne_at_h24``, ``hr_h24`` and ``pressor_resistance``.

    Returns:
        ``[lactate hinge, oliguria hinge, NE indicator, HR deviation,
        log pressor efficiency]``, or ``None`` when lactate, fluid output,
        heart rate or pressor resistance is missing. A missing NE rate is
        treated as 0.
    """
    needed = ("lactate_h24", "fluid_out_per_kg", "hr_h24", "pressor_resistance")
    lactate, urine_per_kg, heart_rate, resistance = (p.get(k) for k in needed)
    if any(v is None for v in (lactate, urine_per_kg, heart_rate, resistance)):
        return None

    ne_rate = p.get("ne_at_h24", 0.0) or 0.0
    lactate_hinge = max(0, lactate - 2.5)
    oliguria_hinge = max(0, 20 - urine_per_kg)
    ne_persistence = 1.0 if ne_rate > 0.08 else 0.0
    hr_deviation = abs(heart_rate - 85) / 20
    pressor_efficiency = math.log(resistance + 1.0)
    return [lactate_hinge, oliguria_hinge, ne_persistence,
            hr_deviation, pressor_efficiency]
# Human-readable term names, index-aligned with the list returned by
# formula_features(). They are printed and stored in the JSON output, so the
# subtraction signs must be present for the formula to be readable.
FEATURE_LABELS = [
    "max(0, lactate_h24 - 2.5)",
    "max(0, 20 - fluid_out_per_kg)",
    "I(ne_at_h24 > 0.08)",
    "|hr_h24 - 85| / 20",
    "log(pressor_resistance + 1)",
]
# ── Ordinal bedside score (0-13 pts) ──────────────────────────────────────
def bedside_score(p):
    """Compute the ordinal bedside score for one stay.

    Args:
        p: Primitive dict; reads ``lactate_h24``, ``fluid_out_per_kg``,
            ``ne_at_h24``, ``hr_h24`` and ``pressor_resistance``.

    Returns:
        ``(total, breakdown)`` where ``breakdown`` maps each component
        (``lactate``, ``oliguria``, ``ne_persist``, ``hr_dev``,
        ``pressor_eff``) to its points and ``total`` is their sum (0-13).
        ``(None, None)`` when lactate, fluid output, heart rate or pressor
        resistance is missing. A missing NE rate is treated as 0.
    """
    lactate = p.get("lactate_h24")
    urine_per_kg = p.get("fluid_out_per_kg")
    ne_rate = p.get("ne_at_h24", 0.0) or 0.0
    heart_rate = p.get("hr_h24")
    resistance = p.get("pressor_resistance")
    if any(v is None for v in (lactate, urine_per_kg, heart_rate, resistance)):
        return None, None

    breakdown = {
        # 0 below 2.5, 2 up to and including 4.0, 4 above.
        "lactate": 0 if lactate < 2.5 else (2 if lactate <= 4.0 else 4),
        # 0 at 20 or more, 1 from 10 up to 20, 2 below 10.
        "oliguria": 0 if urine_per_kg >= 20 else (1 if urine_per_kg >= 10 else 2),
        # 3 when the rate at the snapshot exceeds 0.08.
        "ne_persist": 3 if ne_rate > 0.08 else 0,
        # 0 inside 70-100, 1 in the rest of 60-120, 2 outside.
        "hr_dev": 0 if 70 <= heart_rate <= 100 else (1 if 60 <= heart_rate <= 120 else 2),
        # 0 above 3000, 1 from 1000 to 3000, 2 below 1000.
        "pressor_eff": 0 if resistance > 3000 else (1 if resistance >= 1000 else 2),
    }
    return sum(breakdown.values()), breakdown
def build_matrix(ids, primitives, cohort):
    """Assemble the modelling arrays for the given stays.

    Stays are skipped when they have no primitives, when either
    ``formula_features`` or ``bedside_score`` cannot be computed, or when the
    cohort row has no ``sapsii``.

    Args:
        ids: Stay ids to include, in output order.
        primitives: Mapping ``stay_id -> primitive dict``.
        cohort: Mapping ``stay_id -> cohort row`` (reads ``sapsii``, ``died``).

    Returns:
        ``(X, y, scores, saps, valid)``: feature matrix, 0/1 outcome, ordinal
        scores and SAPS-II values as NumPy arrays, plus the list of stay ids
        that were kept.
    """
    import numpy as np

    kept = []
    for sid in ids:
        prim = primitives.get(sid)
        if prim is None:
            continue
        feats = formula_features(prim)
        score, _ = bedside_score(prim)
        if feats is None or score is None:
            continue
        record = cohort[sid]
        sapsii = record.get("sapsii")
        if sapsii is None:
            continue
        outcome = int(record.get("died") or 0)
        kept.append((feats, outcome, score, float(sapsii), sid))

    X = [row[0] for row in kept]
    y = [row[1] for row in kept]
    scores = [row[2] for row in kept]
    saps = [row[3] for row in kept]
    valid = [row[4] for row in kept]
    return np.array(X), np.array(y), np.array(scores), np.array(saps), valid
# ── Main ────────────────────────────────────────────────────────────────────
def _raw_coefficients(model, mu, sd):
    """Map a logistic model fitted on ``(X - mu) / sd`` back to raw units.

    Returns ``[intercept, beta_1, ..., beta_k]`` such that
    ``intercept + sum(beta_i * x_i)`` equals the model's linear predictor on
    unstandardised features.
    """
    beta = model.coef_[0]
    intercept = model.intercept_[0] - (beta * mu / sd).sum()
    return [intercept] + list(beta / sd)


def main():
    """Run the Phase 5b analysis end to end and write ``OUT_FILE``.

    Steps: fetch the data, build primitives, subject-level cross-validation of
    the 5-term formula / ordinal score / SAPS-II, bootstrap coefficient CIs,
    score distribution and risk bands, holdout calibration, headline summary,
    JSON export. Exits with status 1 when scikit-learn / NumPy are missing or
    when no usable stays are available.
    """
    try:
        import numpy as np
        from sklearn.linear_model import LogisticRegression
        from sklearn.metrics import roc_auc_score, brier_score_loss
        from sklearn.model_selection import KFold
    except ImportError as e:
        print(f"\nERROR: {e}")
        print("Install: pip install scikit-learn numpy")
        sys.exit(1)

    rule = "=" * 78
    band_ranges = {"low": "0-3", "mid": "4-7", "high": "8+"}
    # chi-square 0.05 critical value for 8 degrees of freedom (10 deciles - 2).
    hl_critical = 15.51

    def fit_logit(X, y, mu, sd, max_iter=1000):
        # Same estimator settings everywhere; only max_iter differs for the
        # bootstrap refits.
        model = LogisticRegression(C=1.0, max_iter=max_iter, random_state=42)
        model.fit((X - mu) / sd, y)
        return model

    print("\n" + rule)
    print(" PAPER 3 — PHASE 5b: 5-term formula + ordinal bedside score")
    print(rule)

    print("\n[1] Fetching data...")
    cohort_rows = run_pg(q_cohort(), "cohort")
    ne_rows = run_pg(q_ne(), "NE events")
    fout_rows = run_pg(q_fluid_out(), "Fluid out")
    vital_rows = run_pg(q_vitals(), "Vitals h20-28")
    lac_rows = run_pg(q_lactate(), "Lactate")
    cohort = {r["stay_id"]: dict(r) for r in cohort_rows}
    print(f"\n[2] Cohort: {len(cohort):,} SAPS Q4 sepsis-3")
    if not cohort:
        # run_pg() returns [] on any database error; stop here instead of
        # failing later on zero-size arrays.
        print("\nERROR: cohort query returned no rows (see any [PG ERROR] above).")
        sys.exit(1)

    primitives = build_primitives(cohort, ne_rows, fout_rows, vital_rows, lac_rows)
    all_ids = [s for s in cohort if primitives.get(s)
               and formula_features(primitives[s]) is not None]
    print(f" usable: {len(all_ids):,}")
    if not all_ids:
        print("\nERROR: no stay has all five formula inputs; nothing to analyse.")
        sys.exit(1)
    X_all, y_all, S_all, SAPS_all, _ = build_matrix(all_ids, primitives, cohort)
    print(f" mortality: {100*y_all.mean():.1f}%")
    print(f" SAPS-II: mean={SAPS_all.mean():.1f} "
          f"min={SAPS_all.min():.0f} max={SAPS_all.max():.0f}")

    # ══════════════════════════════════════════════════════════════════════
    # [4] Subject-level CV — 5-term formula
    # ══════════════════════════════════════════════════════════════════════
    print(f"\n[4] {N_FOLDS}-fold CV — 5-term continuous formula")
    # Folds are formed over subjects so that all stays of one patient fall on
    # the same side of each split.
    subject_ids = sorted(set(cohort[s]["subject_id"] for s in all_ids))
    rng = np.random.default_rng(42)
    rng.shuffle(subject_ids)
    subj_arr = np.array(subject_ids)
    kf = KFold(n_splits=N_FOLDS, shuffle=False)
    fold_aucs, fold_aucs_score, fold_aucs_saps, fold_coefs = [], [], [], []
    for k, (tr_idx, _te_idx) in enumerate(kf.split(subj_arr)):
        tr_subs = set(subj_arr[tr_idx].tolist())
        tr_ids = [s for s in all_ids if cohort[s]["subject_id"] in tr_subs]
        te_ids = [s for s in all_ids if cohort[s]["subject_id"] not in tr_subs]
        X_tr, y_tr, _, _, _ = build_matrix(tr_ids, primitives, cohort)
        X_te, y_te, S_te, SAPS_te, _ = build_matrix(te_ids, primitives, cohort)
        # Standardise with training-fold statistics only.
        mu = X_tr.mean(0)
        sd = X_tr.std(0) + 1e-9
        lr = fit_logit(X_tr, y_tr, mu, sd)
        pred = lr.predict_proba((X_te - mu) / sd)[:, 1]
        auc_f = roc_auc_score(y_te, pred)
        auc_s = roc_auc_score(y_te, S_te)        # Ordinal score AUROC
        auc_saps = roc_auc_score(y_te, SAPS_te)  # SAPS-II baseline AUROC
        fold_aucs.append(auc_f)
        fold_aucs_score.append(auc_s)
        fold_aucs_saps.append(auc_saps)
        fold_coefs.append(_raw_coefficients(lr, mu, sd))
        print(f" fold {k+1}: formula AUROC={auc_f:.4f} "
              f"ordinal AUROC={auc_s:.4f} SAPS-II AUROC={auc_saps:.4f}")
    fa = np.array(fold_aucs)
    fas = np.array(fold_aucs_score)
    fsap = np.array(fold_aucs_saps)
    print(f"\n 5-term formula CV AUROC: {fa.mean():.4f} ± {fa.std():.4f} "
          f"(range {fa.min():.4f}-{fa.max():.4f})")
    print(f" Ordinal score CV AUROC: {fas.mean():.4f} ± {fas.std():.4f} "
          f"(range {fas.min():.4f}-{fas.max():.4f})")
    print(f" SAPS-II CV AUROC: {fsap.mean():.4f} ± {fsap.std():.4f} "
          f"(range {fsap.min():.4f}-{fsap.max():.4f})")
    print(f" Ordinal loss: {fa.mean()-fas.mean():+.4f}")
    print(f" Δ vs SAPS-II (formula): {fa.mean()-fsap.mean():+.4f}")
    print(f" Δ vs SAPS-II (ordinal): {fas.mean()-fsap.mean():+.4f}")
    # SAPS-II is a fixed pre-computed score (no parameters fit here), so an
    # in-cohort AUROC is not optimistic — report it on the full Kepler cohort
    # for a single, directly comparable headline number.
    saps_auc_overall = roc_auc_score(y_all, SAPS_all)
    print(f"\n SAPS-II AUROC on full Kepler cohort (n={len(y_all):,}): "
          f"{saps_auc_overall:.4f}")
    fc = np.array(fold_coefs)
    print("\n Coefficient stability (5 terms):")
    names = ["intercept"] + FEATURE_LABELS
    for i, name in enumerate(names):
        col = fc[:, i]
        # Sign changes between consecutive folds.
        flips = sum(1 for prev, cur in zip(col[:-1], col[1:]) if prev * cur < 0)
        print(f" {name:35s} {col.mean():>+9.4f} ± {col.std():>7.4f} flips={flips}")

    # ══════════════════════════════════════════════════════════════════════
    # [5] Bootstrap CIs — 5-term formula
    # ══════════════════════════════════════════════════════════════════════
    print(f"\n[5] Bootstrap CIs — 5-term formula ({N_BOOTSTRAP} resamples)")
    mu_all = X_all.mean(0)
    sd_all = X_all.std(0) + 1e-9
    n = len(y_all)
    boot_coefs = []
    n_single_class = 0
    n_fit_failed = 0
    last_fit_error = None
    rng_b = np.random.default_rng(42)
    for b in range(N_BOOTSTRAP):
        idx = rng_b.integers(0, n, n)
        X_b = X_all[idx]
        y_b = y_all[idx]
        if len(set(y_b.tolist())) < 2:
            n_single_class += 1
            continue
        try:
            lr_b = fit_logit(X_b, y_b, mu_all, sd_all, max_iter=500)
        except Exception as exc:
            # Best effort: a failed resample is dropped, but counted and
            # reported below instead of disappearing silently.
            n_fit_failed += 1
            last_fit_error = exc
            continue
        boot_coefs.append(_raw_coefficients(lr_b, mu_all, sd_all))
        if (b+1) % 250 == 0:
            print(f" {b+1}/{N_BOOTSTRAP}...")
    if n_single_class or n_fit_failed:
        print(f" dropped resamples: {n_single_class} single-class, "
              f"{n_fit_failed} failed fits"
              + (f" (last error: {last_fit_error})" if last_fit_error else ""))
    if not boot_coefs:
        print("\nERROR: every bootstrap resample was dropped; cannot compute CIs.")
        sys.exit(1)
    bc = np.array(boot_coefs)
    lr_full = fit_logit(X_all, y_all, mu_all, sd_all)
    point_all = _raw_coefficients(lr_full, mu_all, sd_all)
    print(f"\n {'term':35s} {'point':>9s} {'95% CI':>22s}")
    ci_results = []
    for i, name in enumerate(names):
        col = bc[:, i]
        lo = np.percentile(col, 2.5)
        hi = np.percentile(col, 97.5)
        crosses = bool(lo < 0 < hi)
        ci_str = f"({lo:+.4f}, {hi:+.4f}) {'*' if crosses else ' '}"
        print(f" {name:35s} {point_all[i]:>+9.4f} {ci_str:>22s}")
        ci_results.append({"term": name, "point": float(point_all[i]),
                           "ci_lo": float(lo), "ci_hi": float(hi),
                           "crosses_zero": crosses})
    if any(r["crosses_zero"] for r in ci_results):
        print(" * = 95% CI crosses zero")

    # ══════════════════════════════════════════════════════════════════════
    # [6] Ordinal score distribution + mortality per score
    # ══════════════════════════════════════════════════════════════════════
    print("\n[6] Ordinal bedside score distribution + mortality per score")
    print(f" {'score':>5s} {'n':>5s} {'mortality':>10s} {'cum n':>6s}")
    score_buckets = defaultdict(lambda: {"n": 0, "died": 0})
    for s, died in zip(S_all, y_all):
        score_buckets[int(s)]["n"] += 1
        score_buckets[int(s)]["died"] += int(died)
    cum_n = 0
    score_rows = []
    for s in sorted(score_buckets):
        bucket = score_buckets[s]
        mort = 100 * bucket["died"] / bucket["n"] if bucket["n"] > 0 else 0
        cum_n += bucket["n"]
        bar = "#" * int(mort / 3)  # one mark per 3 percentage points
        print(f" {s:>5d} {bucket['n']:>5d} {mort:>8.1f}% {cum_n:>6d} {bar}")
        score_rows.append({"score": s, "n": bucket["n"], "mortality_pct": mort})

    # ══════════════════════════════════════════════════════════════════════
    # [7] Risk bands from ordinal score (clinical cut-points)
    # ══════════════════════════════════════════════════════════════════════
    print("\n[7] Suggested risk bands (clinically meaningful cutpoints)")

    def band(s):
        # LOW / MID / HIGH by score; ranges mirror band_ranges above.
        if s <= 3:
            return "low"
        if s <= 7:
            return "mid"
        return "high"

    bands = defaultdict(lambda: {"n": 0, "died": 0})
    for s, died in zip(S_all, y_all):
        bname = band(int(s))
        bands[bname]["n"] += 1
        bands[bname]["died"] += int(died)
    print(f" {'band':6s} {'range':>7s} {'n':>5s} {'mort%':>7s}")
    band_results = {}
    for bname in ["low", "mid", "high"]:
        bucket = bands[bname]
        if bucket["n"] == 0:
            continue
        mort = 100 * bucket["died"] / bucket["n"]
        print(f" {bname:6s} {band_ranges[bname]:>7s} {bucket['n']:>5d} {mort:>6.1f}%")
        band_results[bname] = {"n": bucket["n"], "mortality_pct": mort}

    # ══════════════════════════════════════════════════════════════════════
    # [8] Calibration on holdout — both formula and ordinal
    # ══════════════════════════════════════════════════════════════════════
    print("\n[8] Calibration on 30% holdout (both versions)")
    # Sorted before the seeded shuffle so the split does not depend on set
    # iteration order.
    subs = sorted(set(cohort[s]["subject_id"] for s in all_ids))
    random.Random(42).shuffle(subs)
    n_tr = int(len(subs) * TRAIN_FRAC)
    tr_subs = set(subs[:n_tr])
    tr_ids = [s for s in all_ids if cohort[s]["subject_id"] in tr_subs]
    te_ids = [s for s in all_ids if cohort[s]["subject_id"] not in tr_subs]
    X_tr, y_tr, _, _, _ = build_matrix(tr_ids, primitives, cohort)
    X_te, y_te, S_te, _, _ = build_matrix(te_ids, primitives, cohort)
    mu = X_tr.mean(0)
    sd = X_tr.std(0) + 1e-9
    lr_cal = fit_logit(X_tr, y_tr, mu, sd)
    pred_cal = lr_cal.predict_proba((X_te - mu) / sd)[:, 1]
    order = np.argsort(pred_cal)
    deciles = np.array_split(order, 10)
    hl_stat = 0.0
    print(" Formula deciles:")
    print(f" {'d':>2s} {'n':>4s} {'pred':>7s} {'obs':>7s}")
    calib_bins = []
    for d, idx in enumerate(deciles):
        n_d = len(idx)
        if n_d == 0:
            # array_split yields empty groups when the holdout has < 10 rows.
            continue
        p_mean = float(pred_cal[idx].mean())
        obs = int(y_te[idx].sum())
        exp = float(pred_cal[idx].sum())
        if exp > 0 and (n_d - exp) > 0:
            hl_stat += (obs - exp)**2 / exp + ((n_d - obs) - (n_d - exp))**2 / (n_d - exp)
        print(f" {d+1:>2d} {n_d:>4d} {p_mean:>6.3f} {obs/n_d:>6.3f}")
        calib_bins.append({"decile": d+1, "n": n_d,
                           "predicted": p_mean, "observed": obs/n_d})
    hl_ok = hl_stat < hl_critical
    hl_label = "calibrated" if hl_ok else "poorly calibrated"
    print(f"\n Hosmer-Lemeshow χ²: {hl_stat:.2f} (critical {hl_critical})")
    if hl_ok:
        print(" → Well calibrated (p > 0.05)")
    else:
        print(" → Poorly calibrated (p ≤ 0.05)")
    brier = brier_score_loss(y_te, pred_cal)
    print(f" Brier: {brier:.4f}")

    # ══════════════════════════════════════════════════════════════════════
    # [9] FINAL HEADLINE
    # ══════════════════════════════════════════════════════════════════════
    print("\n[9] ══════════ FINAL HEADLINE (Phase 5b) ══════════")
    print(f"\n Cohort: n={len(all_ids):,} sepsis-3 Q4 patients")
    print(f" Mortality: {100*y_all.mean():.1f}%")
    print("\n 5-term formula (continuous):")
    print(f" {N_FOLDS}-fold CV AUROC: {fa.mean():.4f} ± {fa.std():.4f}")
    print(f" Hosmer-Lemeshow: χ² = {hl_stat:.2f} ({hl_label})")
    print(f" Brier score: {brier:.4f}")
    print("\n Ordinal bedside score (0-13 pts):")
    print(f" {N_FOLDS}-fold CV AUROC: {fas.mean():.4f} ± {fas.std():.4f}")
    print(f" AUROC loss vs continuous: {fa.mean()-fas.mean():+.4f}")
    print("\n SAPS-II baseline (same cohort):")
    print(f" Overall AUROC: {saps_auc_overall:.4f}")
    print(f" {N_FOLDS}-fold CV AUROC: {fsap.mean():.4f} ± {fsap.std():.4f}")
    print(f" Δ formula - SAPS-II: {fa.mean()-fsap.mean():+.4f}")
    print(f" Δ ordinal - SAPS-II: {fas.mean()-fsap.mean():+.4f}")
    print("\n Risk bands (ordinal score):")
    for bn in ["low", "mid", "high"]:
        br = band_results.get(bn, {})
        if br:
            print(f" {bn:6s} ({band_ranges[bn]:>4s}): "
                  f"n={br['n']:>5d} mortality={br['mortality_pct']:.1f}%")

    # ── Save ────────────────────────────────────────────────────────────────
    output = {
        "cohort": {"n": len(all_ids), "mortality": float(y_all.mean())},
        "formula_5term": {
            "cv_auroc_mean": float(fa.mean()),
            "cv_auroc_std": float(fa.std()),
            "cv_auroc_range": [float(fa.min()), float(fa.max())],
            "coefficient_cis": ci_results,
            "calibration": {
                "hl_chi2": float(hl_stat),
                "brier": float(brier),
                "deciles": calib_bins,
            },
        },
        "ordinal_score": {
            "cv_auroc_mean": float(fas.mean()),
            "cv_auroc_std": float(fas.std()),
            "auroc_loss_vs_continuous": float(fa.mean() - fas.mean()),
            "score_distribution": score_rows,
            "risk_bands": band_results,
        },
        "sapsii_baseline": {
            "n": int(len(SAPS_all)),
            "sapsii_mean": float(SAPS_all.mean()),
            "sapsii_min": float(SAPS_all.min()),
            "sapsii_max": float(SAPS_all.max()),
            "auroc_overall": float(saps_auc_overall),
            "cv_auroc_mean": float(fsap.mean()),
            "cv_auroc_std": float(fsap.std()),
            "cv_auroc_range": [float(fsap.min()), float(fsap.max())],
            "delta_formula_minus_sapsii": float(fa.mean() - fsap.mean()),
            "delta_ordinal_minus_sapsii": float(fas.mean() - fsap.mean()),
        },
    }
    with open(OUT_FILE, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, default=str)
    print(f"\n → Saved: {OUT_FILE}")
    print("\n" + rule + "\n")


if __name__ == "__main__":
    main()