Files
eval-keplertwin/paper3_phase5b_icd_sepsis_prevalence.py
2026-05-05 19:35:11 +02:00

448 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ICD-coded sepsis prevalence — inclusion vs exclusion cohort
═════════════════════════════════════════════════════════════════════════════
Companion analysis for paper3_phase5b_refined.py.
The Phase 5b cohort SQL (`q_cohort()`) keeps an ICU stay only when ALL of:
1. sepsis3.sepsis3 = TRUE (Sepsis-3 derived flag)
2. ICU length-of-stay ≥ 24h (H_SNAPSHOT)
3. sapsii IS NOT NULL AND sapsii ≥ 48 (SAPS-II Q4)
This script computes the prevalence of *explicit* ICD-coded sepsis on:
(a) the INCLUSION cohort — stays that satisfy all three filters,
(b) the EXCLUSION cohort — every other ICU stay in MIMIC-III (fails
at least one of the three filters above), and
(c) ALL ICU STAYS — the full mimiciii.icustays universe
(= inclusion ∪ exclusion).
Note: stay-level totals partition cleanly (incl + excl = all), but at the
admission and subject level a single hadm_id / subject can have ICU stays
in both buckets, so the "all" row is computed via SQL GROUPING SETS rather
than by summing.
ICD-coded sepsis is evaluated at the hospital-admission level (a stay is
"ICD-sepsis +" if its parent hadm_id carries any of the codes below):
- Explicit sepsis (ICD-9): 995.91, 995.92, 785.52
(matches paper2_festung_teil3.py SYNDROME_ICDS["sepsis"]["icd_9"])
- Angus septicemia (ICD-9): 038.* (any 038-prefixed code)
- Any of the above (union)
For each cohort × definition combination we report:
n positives, prevalence, Wilson 95% CI.
The inclusion vs exclusion difference is reported with a normal-approx
95% CI and a Pearson χ² statistic (no scipy dependency).
Usage:
python paper3_phase5b_icd_sepsis_prevalence.py
"""
import json, math, os, sys, time
# Connection and schema settings come from the environment, using the same
# variable names as paper3_phase5b_refined.py.
PG_DSN = os.getenv("MIMIC_PG_DSN", "dbname=mimic3")
MIMIC_SCHEMA = os.getenv("MIMIC_SCHEMA", "mimiciii")
# Derived tables (sepsis3, sapsii) default to living in the MIMIC schema.
DERIVED_SCHEMA = os.getenv("DERIVED_SCHEMA", MIMIC_SCHEMA)

# Inclusion thresholds (kept in step with paper3 phase 5b).
H_SNAPSHOT = 24  # minimum ICU length of stay, in hours
SAPSII_MIN = 48  # SAPS-II Q4 cutoff

# JSON report written by main().
OUT_FILE = "paper3_phase5b_icd_sepsis_prevalence.json"

# Explicit sepsis ICD-9 codes. MIMIC-III stores codes WITHOUT the decimal
# point:
#   995.91 -> '99591'  Sepsis
#   995.92 -> '99592'  Severe sepsis
#   785.52 -> '78552'  Septic shock
EXPLICIT_SEPSIS_ICD9 = ("99591", "99592", "78552")

# Angus-style broad septicemia bucket: every ICD-9 code that starts with 038.
SEPTICEMIA_PREFIX = "038"
# Process-wide cached connection; created lazily by _pg_conn().
_PG_CONN = None


def _pg_conn():
    """Return the shared psycopg2 connection, opening it on first use.

    A new connection is opened when none is cached yet or when the cached
    one reports a truthy `closed` attribute. Every new connection is put
    into read-only, autocommit mode before it is handed out.
    """
    global _PG_CONN
    cached_is_usable = (
        _PG_CONN is not None and not getattr(_PG_CONN, "closed", 0)
    )
    if cached_is_usable:
        return _PG_CONN

    # Imported lazily so the module can be loaded without psycopg2 installed.
    import psycopg2

    _PG_CONN = psycopg2.connect(PG_DSN)
    _PG_CONN.set_session(readonly=True, autocommit=True)
    return _PG_CONN
def run_pg(sql, label=""):
    """Execute `sql` on the shared connection and return the rows as dicts.

    Args:
        sql: Complete SQL text; it is executed without bound parameters.
        label: Short name shown in the one-line timing message.

    Returns:
        A list with one plain dict per result row, or an empty list when
        the statement produced no result set.
    """
    import psycopg2.extras

    connection = _pg_conn()
    started = time.time()
    dict_cursor = psycopg2.extras.RealDictCursor
    with connection.cursor(cursor_factory=dict_cursor) as cur:
        cur.execute(sql)
        if cur.description:
            rows = [dict(record) for record in cur.fetchall()]
        else:
            rows = []
    print(f" {label:40s} {len(rows):>8,d} rows ({time.time()-started:.1f}s)")
    return rows
# ── SQL ─────────────────────────────────────────────────────────────────────
#
# One pass: classify every ICU stay as "inclusion" or "exclusion" using the
# Phase 5b filter, then left-join the two ICD code sets at the hadm_id level
# and aggregate. This mirrors q_cohort() exactly for the inclusion bucket
# (sepsis3 = TRUE AND LOS ≥ 24h AND sapsii ≥ 48), and treats every other
# ICU stay in mimiciii.icustays as exclusion. GROUPING SETS adds a third
# row (cohort = NULL → 'all') aggregated over the full ICU universe so that
# admission- and subject-level distinct counts are correct (a single hadm_id
# may straddle both buckets, so we cannot just sum incl + excl).
def q_prevalence():
    """Build the cohort × ICD-coded-sepsis aggregation query.

    Every ICU stay with a non-NULL icustay_id and hadm_id is labelled
    'inclusion' (Sepsis-3 flag TRUE, LOS >= H_SNAPSHOT hours, SAPS-II not
    NULL and >= SAPSII_MIN) or 'exclusion' (everything else). The stays are
    then left-joined, at hadm_id level, to the admissions carrying an
    explicit sepsis code and to those carrying a SEPTICEMIA_PREFIX code.

    GROUPING SETS yields one row per cohort plus a grand-total row, which
    COALESCE relabels as 'all'. Each row carries stay, admission and subject
    counts and, for each ICD definition (explicit, septicemia, either), the
    number of positive stays and of distinct positive admissions.

    Returns:
        The SQL text, ready to pass to run_pg().
    """
    quoted_codes = ",".join("'{}'".format(code) for code in EXPLICIT_SEPSIS_ICD9)
    sql = f"""
WITH icu AS (
SELECT icu.icustay_id,
icu.hadm_id,
icu.subject_id,
EXTRACT(EPOCH FROM (icu.outtime - icu.intime)) / 3600.0 AS los_h,
COALESCE(s3.sepsis3, FALSE) AS is_sepsis3,
saps.sapsii AS sapsii
FROM {MIMIC_SCHEMA}.icustays icu
LEFT JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = icu.icustay_id
LEFT JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE icu.icustay_id IS NOT NULL
AND icu.hadm_id IS NOT NULL
),
classified AS (
SELECT icustay_id, hadm_id, subject_id,
is_sepsis3, los_h, sapsii,
CASE WHEN is_sepsis3 = TRUE
AND los_h >= {H_SNAPSHOT}
AND sapsii IS NOT NULL
AND sapsii >= {SAPSII_MIN}
THEN 'inclusion' ELSE 'exclusion' END AS cohort
FROM icu
),
explicit_sepsis AS (
SELECT DISTINCT hadm_id
FROM {MIMIC_SCHEMA}.diagnoses_icd
WHERE icd9_code IN ({quoted_codes})
),
septicemia AS (
SELECT DISTINCT hadm_id
FROM {MIMIC_SCHEMA}.diagnoses_icd
WHERE icd9_code LIKE '{SEPTICEMIA_PREFIX}%%'
)
SELECT
COALESCE(c.cohort, 'all') AS cohort,
COUNT(*) AS n_stays,
COUNT(DISTINCT c.hadm_id) AS n_admissions,
COUNT(DISTINCT c.subject_id) AS n_subjects,
SUM(CASE WHEN e.hadm_id IS NOT NULL
THEN 1 ELSE 0 END) AS n_stays_explicit,
COUNT(DISTINCT CASE WHEN e.hadm_id IS NOT NULL
THEN c.hadm_id END) AS n_adm_explicit,
SUM(CASE WHEN s.hadm_id IS NOT NULL
THEN 1 ELSE 0 END) AS n_stays_septicemia,
COUNT(DISTINCT CASE WHEN s.hadm_id IS NOT NULL
THEN c.hadm_id END) AS n_adm_septicemia,
SUM(CASE WHEN e.hadm_id IS NOT NULL
OR s.hadm_id IS NOT NULL
THEN 1 ELSE 0 END) AS n_stays_any,
COUNT(DISTINCT CASE WHEN e.hadm_id IS NOT NULL
OR s.hadm_id IS NOT NULL
THEN c.hadm_id END) AS n_adm_any
FROM classified c
LEFT JOIN explicit_sepsis e ON e.hadm_id = c.hadm_id
LEFT JOIN septicemia s ON s.hadm_id = c.hadm_id
GROUP BY GROUPING SETS ((c.cohort), ())
"""
    return sql
def q_exclusion_breakdown():
    """Build the query counting how many ICU stays fail each single filter.

    The counts run over every ICU stay with a non-NULL icustay_id and
    hadm_id and are NOT mutually exclusive: a stay that violates several
    criteria is counted once under each of them.

    Result columns: n_total, n_not_sepsis3, n_los_short, n_sapsii_null,
    n_sapsii_below. A stay whose LOS is NULL (missing intime or outtime) is
    part of n_total but not of n_los_short, because `los_h < threshold` is
    not true for NULL.

    Returns:
        The SQL text, ready to pass to run_pg().
    """
    sql = f"""
WITH icu AS (
SELECT icu.icustay_id,
EXTRACT(EPOCH FROM (icu.outtime - icu.intime)) / 3600.0 AS los_h,
COALESCE(s3.sepsis3, FALSE) AS is_sepsis3,
saps.sapsii AS sapsii
FROM {MIMIC_SCHEMA}.icustays icu
LEFT JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = icu.icustay_id
LEFT JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE icu.icustay_id IS NOT NULL
AND icu.hadm_id IS NOT NULL
)
SELECT
COUNT(*) AS n_total,
SUM(CASE WHEN is_sepsis3 = FALSE
THEN 1 ELSE 0 END) AS n_not_sepsis3,
SUM(CASE WHEN los_h < {H_SNAPSHOT}
THEN 1 ELSE 0 END) AS n_los_short,
SUM(CASE WHEN sapsii IS NULL
THEN 1 ELSE 0 END) AS n_sapsii_null,
SUM(CASE WHEN sapsii IS NOT NULL AND sapsii < {SAPSII_MIN}
THEN 1 ELSE 0 END) AS n_sapsii_below
FROM icu
"""
    return sql
def q_icd_sepsis_waterfall():
    """Build the mutually-exclusive waterfall query for ICD-coded sepsis stays.

    Only ICU stays whose admission carries an ICD sepsis code are considered,
    using the union definition (explicit sepsis codes ∪ 038.* septicemia).
    Each such stay is counted in exactly one bucket, namely the first
    inclusion gate that rejects it, walking the gates in filter order:

        n_fail_sepsis3      Sepsis-3 flag is not TRUE
        n_fail_los          LOS below H_SNAPSHOT hours, or LOS unknown
        n_fail_sapsii_null  SAPS-II is NULL
        n_fail_sapsii_below SAPS-II below SAPSII_MIN
        n_pass              all gates passed (= inclusion cohort)

    A NULL LOS (missing intime or outtime) is attributed to the LOS gate.
    Without that, such a stay would match neither `los_h < X` nor
    `los_h >= X` and would drop out of every bucket, so the buckets would
    no longer add up to n_total_stays.

    Returns:
        The SQL text, ready to pass to run_pg().
    """
    explicit = ",".join(f"'{c}'" for c in EXPLICIT_SEPSIS_ICD9)
    return f"""
WITH icu AS (
SELECT icu.icustay_id, icu.hadm_id,
EXTRACT(EPOCH FROM (icu.outtime - icu.intime)) / 3600.0 AS los_h,
COALESCE(s3.sepsis3, FALSE) AS is_sepsis3,
saps.sapsii AS sapsii
FROM {MIMIC_SCHEMA}.icustays icu
LEFT JOIN {DERIVED_SCHEMA}.sepsis3 s3 ON s3.icustay_id = icu.icustay_id
LEFT JOIN {DERIVED_SCHEMA}.sapsii saps ON saps.icustay_id = icu.icustay_id
WHERE icu.icustay_id IS NOT NULL
AND icu.hadm_id IS NOT NULL
),
icd_pos AS (
SELECT DISTINCT hadm_id
FROM {MIMIC_SCHEMA}.diagnoses_icd
WHERE icd9_code IN ({explicit})
OR icd9_code LIKE '{SEPTICEMIA_PREFIX}%%'
),
icd_stays AS (
SELECT i.* FROM icu i JOIN icd_pos x ON x.hadm_id = i.hadm_id
)
SELECT
COUNT(*) AS n_total_stays,
COUNT(DISTINCT hadm_id) AS n_total_adm,
-- Waterfall: each stay is counted exactly once, in the order of the
-- inclusion filter (sepsis3 → LOS → sapsii NULL → sapsii < 48 → pass).
-- A NULL LOS fails the LOS gate, matching the cohort classification.
SUM(CASE WHEN NOT is_sepsis3
THEN 1 ELSE 0 END) AS n_fail_sepsis3,
SUM(CASE WHEN is_sepsis3
AND (los_h IS NULL OR los_h < {H_SNAPSHOT})
THEN 1 ELSE 0 END) AS n_fail_los,
SUM(CASE WHEN is_sepsis3 AND los_h >= {H_SNAPSHOT}
AND sapsii IS NULL
THEN 1 ELSE 0 END) AS n_fail_sapsii_null,
SUM(CASE WHEN is_sepsis3 AND los_h >= {H_SNAPSHOT}
AND sapsii IS NOT NULL
AND sapsii < {SAPSII_MIN}
THEN 1 ELSE 0 END) AS n_fail_sapsii_below,
SUM(CASE WHEN is_sepsis3 AND los_h >= {H_SNAPSHOT}
AND sapsii IS NOT NULL
AND sapsii >= {SAPSII_MIN}
THEN 1 ELSE 0 END) AS n_pass
FROM icd_stays
"""
# ── Stats helpers (no scipy) ────────────────────────────────────────────────
def wilson_ci(k, n, z=1.959963984540054):
    """Wilson score interval for a binomial proportion k / n.

    Args:
        k: Number of positives.
        n: Number of trials.
        z: Normal quantile; the default gives a 95% interval.

    Returns:
        (lo, hi), clipped to [0, 1]. (nan, nan) when n <= 0.
    """
    if n <= 0:
        return (float("nan"), float("nan"))

    z_sq = z * z
    p_hat = k / n
    scale = 1.0 + z_sq / n
    midpoint = (p_hat + z_sq / (2.0 * n)) / scale
    spread = p_hat * (1.0 - p_hat) / n + z_sq / (4.0 * n * n)
    margin = (z * math.sqrt(spread)) / scale

    lower = max(0.0, midpoint - margin)
    upper = min(1.0, midpoint + margin)
    return (lower, upper)
def diff_ci(k1, n1, k2, n2, z=1.959963984540054):
    """Normal-approximation interval for the difference p1 − p2.

    Args:
        k1, n1: Positives and total of the first group.
        k2, n2: Positives and total of the second group.
        z: Normal quantile; the default gives a 95% interval.

    Returns:
        (delta, lo, hi) with delta = k1/n1 − k2/n2. Three NaNs when either
        group is empty (n <= 0).
    """
    if n1 <= 0 or n2 <= 0:
        nan = float("nan")
        return (nan, nan, nan)

    prop1 = k1 / n1
    prop2 = k2 / n2
    std_err = math.sqrt(
        prop1 * (1.0 - prop1) / n1 + prop2 * (1.0 - prop2) / n2
    )
    delta = prop1 - prop2
    return (delta, delta - z * std_err, delta + z * std_err)
def chi2_2x2(k1, n1, k2, n2):
    """Pearson χ² for the 2×2 table (sepsis± × cohort), without correction.

    Args:
        k1, n1: Sepsis-positive count and size of the first cohort.
        k2, n2: Sepsis-positive count and size of the second cohort.

    Returns:
        (chi2, dof) with dof always 1; chi2 is NaN when n1 + n2 == 0.
        Cells whose expected count is not positive are left out of the sum.
        The critical value at p = 0.05 is 3.841.
    """
    total = n1 + n2
    if total == 0:
        return (float("nan"), 1)

    # Rows are cohorts, columns are sepsis+ / sepsis−.
    observed = (
        (k1, n1 - k1),
        (k2, n2 - k2),
    )
    row_totals = (observed[0][0] + observed[0][1],
                  observed[1][0] + observed[1][1])
    col_totals = (observed[0][0] + observed[1][0],
                  observed[0][1] + observed[1][1])

    stat = 0.0
    for i, row in enumerate(observed):
        for j, obs in enumerate(row):
            expected = row_totals[i] * col_totals[j] / total
            if expected > 0:
                stat += (obs - expected) ** 2 / expected
    return (stat, 1)
def fmt_pct(p):
    """Render a proportion as a percentage with two decimals (0.5 -> '50.00%')."""
    return "{:5.2f}%".format(100.0 * p)
def fmt_ci(lo, hi):
    """Render interval bounds (as proportions) as '[lo%, hi%]' with two decimals."""
    lo_pct = 100.0 * lo
    hi_pct = 100.0 * hi
    return "[{:5.2f}, {:5.2f}]".format(lo_pct, hi_pct)
# ── Main ────────────────────────────────────────────────────────────────────
def main():
    """Query MIMIC-III, print the prevalence report and save it as JSON.

    Steps, in output order:
      [1]  run q_prevalence(), q_exclusion_breakdown() and
           q_icd_sepsis_waterfall() through run_pg();
      [2]  print stay / admission / subject counts per cohort;
      [3]  print the non-exclusive exclusion breakdown;
      [3b] print the mutually exclusive ICD-sepsis waterfall;
      [4]  print prevalence per ICD definition with ICU stays as denominator;
      [5]  the same with admissions as denominator.
    Tables [4] and [5] include a Wilson 95% CI per cohort and, per
    definition, the inclusion − exclusion difference with its CI and
    Pearson χ². Everything is finally written to OUT_FILE.

    Exits with status 1 when the prevalence query returns no rows.
    """
    rule = "═" * 78  # banner rule, same character as the module header

    print("\n" + rule)
    print(" ICD-coded sepsis prevalence — Phase 5b inclusion vs exclusion")
    print(rule)
    # NOTE(review): PG_DSN is echoed verbatim; a DSN that embeds a password
    # would end up in the console output.
    print(f"\n PG DSN: {PG_DSN}")
    print(f" MIMIC schema: {MIMIC_SCHEMA}")
    print(f" Derived schema: {DERIVED_SCHEMA}")
    print(f" Inclusion: sepsis3=TRUE AND LOS≥{H_SNAPSHOT}h AND SAPS-II≥{SAPSII_MIN}")
    print(f" Explicit ICD-9: {', '.join(EXPLICIT_SEPSIS_ICD9)} "
          f"(995.91 / 995.92 / 785.52)")
    print(f" Septicemia: ICD-9 {SEPTICEMIA_PREFIX}.*")

    print("\n[1] Querying MIMIC-III...")
    rows = run_pg(q_prevalence(), "cohort × ICD prevalence")
    bkdwn = run_pg(q_exclusion_breakdown(), "exclusion breakdown")
    wfall = run_pg(q_icd_sepsis_waterfall(), "ICD-sepsis waterfall")
    if not rows:
        print("\n[ERROR] no rows returned. Check PG_DSN / schema permissions.")
        sys.exit(1)

    by_cohort = {r["cohort"]: r for r in rows}
    incl = by_cohort.get("inclusion", {})
    excl = by_cohort.get("exclusion", {})
    allc = by_cohort.get("all", {})

    INT_KEYS = ("n_stays", "n_admissions", "n_subjects",
                "n_stays_explicit", "n_adm_explicit",
                "n_stays_septicemia", "n_adm_septicemia",
                "n_stays_any", "n_adm_any")
    # Normalise every count to a plain int; a missing or NULL value becomes 0.
    for c in (incl, excl, allc):
        for k in INT_KEYS:
            c[k] = int(c.get(k) or 0)
    COHORTS = (("inclusion", incl), ("exclusion", excl), ("all", allc))

    # ── [2] Cohort sizes ──────────────────────────────────────────────────
    print("\n[2] Cohort sizes")
    print(f" {'cohort':12s} {'stays':>8s} {'admissions':>11s} {'subjects':>9s}")
    for label, c in COHORTS:
        print(f" {label:12s} {c['n_stays']:>8,d} "
              f"{c['n_admissions']:>11,d} {c['n_subjects']:>9,d}")

    # ── [3] Why an ICU stay was excluded ──────────────────────────────────
    if bkdwn:
        b = bkdwn[0]
        print("\n[3] Exclusion breakdown (non-exclusive: a stay can fail >1 filter)")
        n_total = int(b["n_total"] or 0)
        for lbl, k in (("not Sepsis-3", "n_not_sepsis3"),
                       (f"ICU LOS < {H_SNAPSHOT}h", "n_los_short"),
                       ("SAPS-II is NULL", "n_sapsii_null"),
                       (f"SAPS-II < {SAPSII_MIN}", "n_sapsii_below")):
            n = int(b[k] or 0)
            pct = 100.0 * n / n_total if n_total else 0.0
            print(f" {lbl:24s} {n:>8,d} ({pct:5.2f}% of all ICU stays)")
        print(f" {'all ICU stays':24s} {n_total:>8,d}")

    # ── [3b] ICD-sepsis-positive waterfall ────────────────────────────────
    # Diagnostic: of the ICD-coded sepsis stays (explicit ∪ 038.*), which
    # inclusion gate eliminated them? Mutually exclusive: each stay is
    # counted in the FIRST gate that would reject it, walking in the
    # inclusion-filter order (sepsis3 → LOS → SAPS-II NULL → SAPS-II<48).
    if wfall:
        w = wfall[0]
        n_w = int(w["n_total_stays"] or 0)
        n_adm = int(w["n_total_adm"] or 0)
        print("\n[3b] ICD-sepsis-positive waterfall (mutually exclusive,"
              " inclusion-filter order)")
        print(f" {'gate':28s} {'n stays':>9s} {'pct':>6s} cumulative")
        steps = (
            ("rejected: not Sepsis-3", "n_fail_sepsis3"),
            (f"rejected: LOS < {H_SNAPSHOT}h", "n_fail_los"),
            ("rejected: SAPS-II is NULL", "n_fail_sapsii_null"),
            (f"rejected: SAPS-II < {SAPSII_MIN}", "n_fail_sapsii_below"),
            ("PASS (= inclusion)", "n_pass"),
        )
        cum = 0
        for lbl, kn in steps:
            n = int(w[kn] or 0)
            cum += n
            pct = 100.0 * n / n_w if n_w else 0.0
            print(f" {lbl:28s} {n:>9,d} {pct:5.2f}% {cum:>9,d}")
        print(f" {'TOTAL ICD-sepsis stays':28s} {n_w:>9,d} "
              f"({n_adm:,} admissions)")

    # ── [4] Prevalence per definition ─────────────────────────────────────
    # (display name, stay-level positive key, admission-level positive key)
    DEFS = (
        ("Explicit sepsis (995.91 / 995.92 / 785.52)",
         "n_stays_explicit", "n_adm_explicit"),
        ("Angus septicemia (038.*)",
         "n_stays_septicemia", "n_adm_septicemia"),
        ("Any of the above (union)",
         "n_stays_any", "n_adm_any"),
    )

    def _table(title, denom_key):
        """Print one prevalence table and return its rows for the JSON dump.

        `denom_key` is "n_stays" or "n_admissions". It names the denominator
        column and also selects the matching positive-count column of DEFS.
        """
        print(f"\n{title}")
        print(f" {'definition':45s} {'cohort':10s} "
              f"{'n+':>7s} {'N':>7s} {'prev':>7s} {'95% CI (Wilson)':>18s}")
        out = []
        for name, stay_key, adm_key in DEFS:
            pos_key = stay_key if denom_key == "n_stays" else adm_key
            for label, c in COHORTS:
                k_, n_ = c[pos_key], c[denom_key]
                p = k_ / n_ if n_ else float("nan")
                lo, hi = wilson_ci(k_, n_)
                print(f" {name:45s} {label:10s} "
                      f"{k_:>7,d} {n_:>7,d} {fmt_pct(p):>7s} "
                      f"{fmt_ci(lo,hi):>18s}")
                out.append({"definition": name, "cohort": label,
                            "k": k_, "n": n_, "prevalence": p,
                            "ci_lo": lo, "ci_hi": hi})
            # Inclusion vs exclusion comparison (the "all" row is just a
            # weighted average of the two so a Δ against it isn't meaningful).
            k1, n1 = incl[pos_key], incl[denom_key]
            k2, n2 = excl[pos_key], excl[denom_key]
            d, dlo, dhi = diff_ci(k1, n1, k2, n2)
            chi2, dof = chi2_2x2(k1, n1, k2, n2)
            # 3.841 is the χ² critical value for dof = 1 at p = 0.05.
            sig = "p<0.05" if chi2 > 3.841 else "n.s."
            print(f" {' Δ (incl − excl)':45s} {'':10s} "
                  f"{'':>7s} {'':>7s} {fmt_pct(d):>7s} "
                  f"{fmt_ci(dlo,dhi):>18s} χ²={chi2:6.2f} ({sig})")
            out.append({"definition": name, "cohort": "delta_incl_minus_excl",
                        "delta": d, "ci_lo": dlo, "ci_hi": dhi,
                        "chi2": chi2, "dof": dof})
        return out

    results = {
        "by_stays": _table(
            "[4] ICD-coded sepsis prevalence (denominator = ICU STAYS)",
            "n_stays"),
        "by_admissions": _table(
            "[5] ICD-coded sepsis prevalence (denominator = ADMISSIONS)",
            "n_admissions"),
    }

    # ── Save ────────────────────────────────────────────────────────────
    output = {
        "filters": {
            "h_snapshot_hours": H_SNAPSHOT,
            "sapsii_min": SAPSII_MIN,
            "explicit_icd9": list(EXPLICIT_SEPSIS_ICD9),
            "septicemia_prefix": SEPTICEMIA_PREFIX,
        },
        "cohorts": {
            "inclusion": {k: incl[k] for k in INT_KEYS},
            "exclusion": {k: excl[k] for k in INT_KEYS},
            "all": {k: allc[k] for k in INT_KEYS},
        },
        "exclusion_breakdown": (bkdwn[0] if bkdwn else None),
        "icd_sepsis_waterfall": (wfall[0] if wfall else None),
        "results": results,
    }
    # default=str stringifies any value json cannot encode natively.
    with open(OUT_FILE, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, default=str)
    print(f"\n → Saved: {OUT_FILE}")
    print("\n" + rule + "\n")
# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()