ICU Preprocessing Pipeline with ReciPies
This notebook demonstrates a complete end-to-end preprocessing pipeline for ICU time-series data using ReciPies (see https://github.com/rvandewater/YAIB for more info). We'll cover:
- Data Loading: Loading dynamic measurements, static features, and outcomes from parquet files
- Train/Test Split: Proper group-level splitting to prevent data leakage
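- Multi-Step Pipeline:
  - Missing value imputation (forward fill + zero fill)
  - Feature scaling (standardization)
  - Historical feature engineering (cumulative mean, min, max, and variance within each stay)
  - Custom domain-specific features
  - Categorical encoding (most-frequent imputation + label encoding)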
- Baking the Data: Applying the preprocessing pipeline to both training and test sets
- Model Training: Using the preprocessed data to train a machine learning model
The pipeline uses Polars for high-performance data processing, with ReciPies handling all preprocessing steps while maintaining column role information throughout the transformation pipeline.
1. Load ICU Data
We start by loading the ICU demo data, which consists of three components:
- Dynamic data: Time-varying measurements (vitals, lab values) recorded on an hourly grid (one row per stay and hour)
- Static data: Patient-level features that don't change over time (age, sex, height, weight)
- Outcome data: The target variable we want to predict (mortality at 24 hours)
Let's examine the structure of each dataset.
```python
import numpy as np
import polars as pl
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer
from IPython.display import display
from recipies import Ingredients, Recipe
from recipies.selector import all_predictors, all_numeric_predictors, has_role, has_type, all_of
from recipies.step import StepImputeFill, StepHistorical, StepScale, StepFunction, Accumulator, StepSklearn

dynamic_data = pl.read_parquet("../../examples/icu_demo_data/mortality24/eicu_demo/dyn.parquet")
static_data = pl.read_parquet("../../examples/icu_demo_data/mortality24/eicu_demo/sta.parquet")
outcome = pl.read_parquet("../../examples/icu_demo_data/mortality24/eicu_demo/outc.parquet")

print("Columns:")
print(f"dynamic: {dynamic_data.columns}")
print(f"static: {static_data.columns}")
print(f"outcome: {outcome.columns}")
print("Shapes:")
print(f"dynamic: {dynamic_data.shape}")
print(f"static: {static_data.shape}")
print(f"outcome: {outcome.shape}")
print("Heads:")
display(dynamic_data.head())
display(static_data.head())
display(outcome.head())
```
```
Columns:
dynamic: ['stay_id', 'time', 'alb', 'alp', 'alt', 'ast', 'be', 'bicar', 'bili', 'bili_dir', 'bnd', 'bun', 'ca', 'cai', 'ck', 'ckmb', 'cl', 'crea', 'crp', 'dbp', 'fgn', 'fio2', 'glu', 'hgb', 'hr', 'inr_pt', 'k', 'lact', 'lymph', 'map', 'mch', 'mchc', 'mcv', 'methb', 'mg', 'na', 'neut', 'o2sat', 'pco2', 'ph', 'phos', 'plt', 'po2', 'ptt', 'resp', 'sbp', 'temp', 'tnt', 'urine', 'wbc']
static: ['stay_id', 'age', 'sex', 'height', 'weight']
outcome: ['stay_id', 'label']
Shapes:
dynamic: (34175, 50)
static: (1367, 5)
outcome: (1367, 2)
Heads:
```
| stay_id | time | alb | alp | alt | ast | be | bicar | bili | bili_dir | bnd | bun | ca | cai | ck | ckmb | cl | crea | crp | dbp | fgn | fio2 | glu | hgb | hr | inr_pt | k | lact | lymph | map | mch | mchc | mcv | methb | mg | na | neut | o2sat | pco2 | ph | phos | plt | po2 | ptt | resp | sbp | temp | tnt | urine | wbc |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i32 | duration[ms] | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| 141765 | 0ms | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | 87.0 | null | null | null | null | 83.0 | null | null | null | null | 108.0 | null | null | null | null | null | null | null | 96.0 | null | null | null | null | null | null | 18.0 | 142.0 | 36.711111 | null | null | null |
| 141765 | 1h | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | 70.0 | null | null | null | null | 80.0 | null | null | null | null | 99.0 | null | null | null | null | null | null | null | 96.0 | null | null | null | null | null | null | 23.5 | 144.0 | 36.894444 | null | null | null |
| 141765 | 2h | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | 67.0 | null | null | null | null | 77.0 | null | null | null | null | 97.0 | null | null | null | null | null | null | null | 96.0 | null | null | null | null | null | null | 23.0 | 139.0 | null | null | null | null |
| 141765 | 3h | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | 67.0 | null | null | null | null | 79.0 | null | null | null | null | 99.0 | null | null | null | null | null | null | null | 96.0 | null | null | null | null | null | null | 24.0 | 133.0 | null | null | null | null |
| 141765 | 4h | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | 67.0 | null | null | null | null | 76.0 | null | null | null | null | 89.0 | null | null | null | null | null | null | null | 96.0 | null | null | null | null | null | null | 25.0 | 120.0 | null | null | null | null |
| stay_id | age | sex | height | weight |
|---|---|---|---|---|
| i32 | f64 | str | f64 | f64 |
| 141765 | 87.0 | "Female" | 157.5 | 46.5 |
| 147784 | 60.0 | "Female" | 154.9 | 95.6 |
| 151179 | 59.0 | "Female" | 149.9 | null |
| 151867 | 44.0 | "Male" | 172.7 | null |
| 151900 | 66.0 | "Female" | 165.1 | 86.8 |
| stay_id | label |
|---|---|
| i32 | i32 |
| 141765 | 0 |
| 147784 | 0 |
| 151179 | 0 |
| 151867 | 0 |
| 151900 | 0 |
2. Train/Test Split
Critical: we split at the stay_id level (a group-level split). All records of a given ICU stay therefore end up in either the training set or the test set, never both, so rows from a test stay cannot leak into training.
We shuffle the unique stay_ids with a fixed seed and assign 80% of them to training and 20% to testing. Note that this split is grouped by stay_id but not stratified by outcome:
```python
# Train/test split at the stay_id level (group-level split)
# This ensures all records for a given stay go to either train or test

# Get the unique stay_ids in a deterministic order, then shuffle them with a fixed seed.
# shuffle=True is needed: with the default shuffle=False, sampling every row
# (fraction=1.0) is not guaranteed to change the row order.
unique_stays = (
    outcome.select("stay_id")
    .unique(maintain_order=True)
    .sample(fraction=1.0, shuffle=True, seed=42)
)
n_train = int(len(unique_stays) * 0.8)
train_stay_ids = unique_stays.head(n_train)["stay_id"].to_list()
test_stay_ids = unique_stays.tail(len(unique_stays) - n_train)["stay_id"].to_list()
assert set(train_stay_ids).isdisjoint(test_stay_ids)  # no stay appears in both splits

# Split dynamic, static, and outcome data
dynamic_train = dynamic_data.filter(pl.col("stay_id").is_in(train_stay_ids))
dynamic_test = dynamic_data.filter(pl.col("stay_id").is_in(test_stay_ids))
static_train = static_data.filter(pl.col("stay_id").is_in(train_stay_ids))
static_test = static_data.filter(pl.col("stay_id").is_in(test_stay_ids))
outcome_train = outcome.filter(pl.col("stay_id").is_in(train_stay_ids))
outcome_test = outcome.filter(pl.col("stay_id").is_in(test_stay_ids))

# Join train data: static features and the stay-level label are attached to every hourly row
df_train = dynamic_train.join(static_train, on="stay_id", how="left")
df_train = df_train.join(outcome_train.select(["stay_id", "label"]), on="stay_id", how="left")

# Join test data
df_test = dynamic_test.join(static_test, on="stay_id", how="left")
df_test = df_test.join(outcome_test.select(["stay_id", "label"]), on="stay_id", how="left")

print(f"Train: {len(df_train)} rows, {len(train_stay_ids)} stays")
print(f"Test: {len(df_test)} rows, {len(test_stay_ids)} stays")
```
```
Train: 27325 rows, 1093 stays
Test: 6850 rows, 274 stays
```
```python
# Quick check: verify we have the expected columns after joining
print(f"Train dataframe columns: {len(df_train.columns)}")
print(f"Test dataframe columns: {len(df_test.columns)}")
```
```
Train dataframe columns: 55
Test dataframe columns: 55
```
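Because the split above is grouped but not stratified, the share of positive stays in each split is left to chance, which matters when positive stays are relatively rare. A minimal pure-Python sketch of a split that is both grouped and stratified by the stay-level label is shown below. `stratified_group_split` is a hypothetical helper, not part of ReciPies, and it assumes exactly one label per stay, as in `outcome`.

```python
import random
from collections import defaultdict


def stratified_group_split(stay_labels, test_fraction=0.2, seed=42):
    """Split stay IDs into train/test, stratified by the stay-level label.

    stay_labels: dict mapping stay_id -> label (one label per stay).
    Every stay lands entirely in one split, and each label class is split
    separately so both sets keep roughly the same outcome prevalence.
    """
    by_label = defaultdict(list)
    for stay_id, label in sorted(stay_labels.items()):  # sorted for reproducibility
        by_label[label].append(stay_id)

    rng = random.Random(seed)
    train_ids, test_ids = [], []
    for label in sorted(by_label):
        ids = by_label[label]
        rng.shuffle(ids)
        n_test = round(len(ids) * test_fraction)
        test_ids.extend(ids[:n_test])
        train_ids.extend(ids[n_test:])
    return train_ids, test_ids


# Toy example: 100 stays, 10 of them positive (label 1)
labels = {stay_id: int(stay_id < 10) for stay_id in range(100)}
train_ids, test_ids = stratified_group_split(labels)
print(len(train_ids), len(test_ids))     # 80 20
print(sum(labels[s] for s in test_ids))  # 2
```

With the notebook's data, the input could be built as `dict(zip(outcome["stay_id"].to_list(), outcome["label"].to_list()))`. For cross-validation rather than a single split, scikit-learn's `StratifiedGroupKFold` implements the same idea.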
3. Build Preprocessing Pipeline
Now we'll create a comprehensive preprocessing pipeline using ReciPies. The pipeline includes:
- Role Assignment: Define which columns are outcomes, predictors, groups (stay_id), and sequences (time)
- Imputation: Forward fill followed by zero fill for any remaining missing values
- Feature Scaling: Standardize numeric predictors (mean=0, std=1)
- Historical Features: Create cumulative (expanding-window) mean, min, max, and variance aggregations over time within each stay
- Custom Features: Add domain-specific features such as the heart rate to temperature ratio
- Categorical Encoding: Impute missing categories with the most frequent value, then label-encode them
The key advantage of ReciPies is that all transformations maintain column role information, ensuring proper handling of grouped time-series data.
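Forward fill only makes sense within a stay: the last measurement of one stay must not be carried into the first hours of the next. The hypothetical helper below (pure Python, not a ReciPies API) spells out the intended behaviour of the two imputation steps for a single column, assuming rows are sorted by stay_id and time. Whether StepImputeFill applies the fill per group in your ReciPies version is worth confirming; in Polars the same logic could be written as `pl.col(c).fill_null(strategy="forward").over("stay_id").fill_null(0)`.

```python
def impute_forward_then_zero(stay_ids, values):
    """Forward-fill missing values (None) within each stay, then zero-fill.

    stay_ids and values are parallel lists sorted by (stay_id, time). A value is
    only carried forward inside the same stay, so the last measurement of one
    stay never leaks into the first rows of the next one.
    """
    filled = []
    current_stay, last_value = object(), None  # sentinel: no stay seen yet
    for stay_id, value in zip(stay_ids, values):
        if stay_id != current_stay:  # new stay: drop the carried-forward value
            current_stay, last_value = stay_id, None
        if value is None:
            value = last_value  # forward fill (still None if nothing observed yet)
        else:
            last_value = value
        filled.append(0.0 if value is None else value)  # zero-fill what is left
    return filled


stay_ids = [1, 1, 1, 2, 2]
temp = [None, 36.7, None, None, 37.1]
print(impute_forward_then_zero(stay_ids, temp))  # [0.0, 36.7, 36.7, 0.0, 37.1]
```

Note the fourth row: an ungrouped forward fill would have copied 36.7 from stay 1 into stay 2, whereas the grouped version zero-fills it.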
```python
# Initialize Ingredients from the training data
ing = Ingredients(df_train)

# Define the recipe and assign column roles
rec = Recipe(
    ing,
    outcomes=["label"],
    predictors=[c for c in ing.columns if c not in {"label", "stay_id", "time"}],
    groups=["stay_id"],
    sequences=["time"],
)

# Impute missing values: forward fill first, then zero-fill anything still missing
rec.add_step(StepImputeFill(sel=all_predictors(), strategy="forward"))
rec.add_step(StepImputeFill(sel=all_predictors(), strategy="zero"))


# Add a custom domain feature (example: hr/temp ratio) via StepFunction.
# It runs before scaling so the ratio is computed from the original units, not z-scores.
def add_custom_features(ingr: Ingredients, columns) -> Ingredients:
    df_ = ingr.get_df()
    if all(col in df_.columns for col in ["hr", "temp"]):
        # Avoid division by zero where temp was zero-filled
        df_ = df_.with_columns(
            pl.when(pl.col("temp") != 0)
            .then(pl.col("hr") / pl.col("temp"))
            .otherwise(0.0)
            .alias("hr_temp_ratio")
        )
        ingr.set_df(df_)
        # Register the new column as a predictor only if it was actually created
        ingr.update_role("hr_temp_ratio", "predictor")
    return ingr


rec.add_step(StepFunction(sel=has_role(["predictor"]), function=add_custom_features))

# Standardize numeric predictors (after imputation)
rec.add_step(StepScale(sel=all_numeric_predictors(), with_mean=True, with_std=True))

# Impute missing categories with the most frequent value, then label-encode them
types = ["String", "Object", "Categorical"]
rec.add_step(StepSklearn(SimpleImputer(missing_values=np.nan, strategy="most_frequent"), sel=has_type(types)))
rec.add_step(StepSklearn(LabelEncoder(), sel=has_type(types), columnwise=True))

# Capture the fixed list of original numeric predictors
# (hr_temp_ratio and the encoded categorical columns are not included)
original_predictors = all_of(list(all_numeric_predictors()(ing)))

# Historical features: cumulative mean, min, max, and variance within each stay
rec.add_step(StepHistorical(sel=original_predictors, fun=Accumulator.MEAN, suffix="_mean_hist"))
rec.add_step(StepHistorical(sel=original_predictors, fun=Accumulator.MIN, suffix="_min_hist"))
rec.add_step(StepHistorical(sel=original_predictors, fun=Accumulator.MAX, suffix="_max_hist"))
rec.add_step(StepHistorical(sel=original_predictors, fun=Accumulator.VAR, suffix="_var_hist"))

# Fit the recipe on the training data; prep() also returns the transformed training data
train_baked = rec.prep()
display(train_baked.head())
print(train_baked.columns)
print(len(train_baked.columns))
```
sel: roles: ['predictor']
| stay_id | time | alb | alp | alt | ast | be | bicar | bili | bili_dir | bnd | bun | ca | cai | ck | ckmb | cl | crea | crp | dbp | fgn | fio2 | glu | hgb | hr | inr_pt | k | lact | lymph | map | mch | mchc | mcv | methb | mg | na | neut | … | cl_var_hist | crea_var_hist | crp_var_hist | dbp_var_hist | fgn_var_hist | fio2_var_hist | glu_var_hist | hgb_var_hist | hr_var_hist | inr_pt_var_hist | k_var_hist | lact_var_hist | lymph_var_hist | map_var_hist | mch_var_hist | mchc_var_hist | mcv_var_hist | methb_var_hist | mg_var_hist | na_var_hist | neut_var_hist | o2sat_var_hist | pco2_var_hist | ph_var_hist | phos_var_hist | plt_var_hist | po2_var_hist | ptt_var_hist | resp_var_hist | sbp_var_hist | temp_var_hist | tnt_var_hist | urine_var_hist | wbc_var_hist | age_var_hist | height_var_hist | weight_var_hist |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i32 | duration[ms] | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | … | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| 141765 | 0ms | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 1.267354 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | 0.01027 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 1.32194 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null |
| 141765 | 1h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 0.392267 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.113637 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 0.925237 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.382889 | 0.0 | 0.0 | 0.0 | 0.0 | 0.007677 | 0.0 | 0.0 | 0.0 | 0.0 | 0.078687 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.214599 | 0.001843 | 0.000166 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 141765 | 2h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 0.237839 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.237545 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 0.837081 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.308254 | 0.0 | 0.0 | 0.0 | 0.0 | 0.015353 | 0.0 | 0.0 | 0.0 | 0.0 | 0.066706 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.131242 | 0.005837 | 0.000111 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 141765 | 3h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 0.237839 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.15494 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 0.925237 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.244439 | 0.0 | 0.0 | 0.0 | 0.0 | 0.010662 | 0.0 | 0.0 | 0.0 | 0.0 | 0.047115 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.109664 | 0.021198 | 0.000083 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 141765 | 4h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 0.237839 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.278847 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 0.484455 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.200851 | 0.0 | 0.0 | 0.0 | 0.0 | 0.012794 | 0.0 | 0.0 | 0.0 | 0.0 | 0.088984 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.105703 | 0.085989 | 0.000066 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
['stay_id', 'time', 'alb', 'alp', 'alt', 'ast', 'be', 'bicar', 'bili', 'bili_dir', 'bnd', 'bun', 'ca', 'cai', 'ck', 'ckmb', 'cl', 'crea', 'crp', 'dbp', 'fgn', 'fio2', 'glu', 'hgb', 'hr', 'inr_pt', 'k', 'lact', 'lymph', 'map', 'mch', 'mchc', 'mcv', 'methb', 'mg', 'na', 'neut', 'o2sat', 'pco2', 'ph', 'phos', 'plt', 'po2', 'ptt', 'resp', 'sbp', 'temp', 'tnt', 'urine', 'wbc', 'age', 'sex', 'height', 'weight', 'label', 'hr_temp_ratio', 'alb_mean_hist', 'alp_mean_hist', 'alt_mean_hist', 'ast_mean_hist', 'be_mean_hist', 'bicar_mean_hist', 'bili_mean_hist', 'bili_dir_mean_hist', 'bnd_mean_hist', 'bun_mean_hist', 'ca_mean_hist', 'cai_mean_hist', 'ck_mean_hist', 'ckmb_mean_hist', 'cl_mean_hist', 'crea_mean_hist', 'crp_mean_hist', 'dbp_mean_hist', 'fgn_mean_hist', 'fio2_mean_hist', 'glu_mean_hist', 'hgb_mean_hist', 'hr_mean_hist', 'inr_pt_mean_hist', 'k_mean_hist', 'lact_mean_hist', 'lymph_mean_hist', 'map_mean_hist', 'mch_mean_hist', 'mchc_mean_hist', 'mcv_mean_hist', 'methb_mean_hist', 'mg_mean_hist', 'na_mean_hist', 'neut_mean_hist', 'o2sat_mean_hist', 'pco2_mean_hist', 'ph_mean_hist', 'phos_mean_hist', 'plt_mean_hist', 'po2_mean_hist', 'ptt_mean_hist', 'resp_mean_hist', 'sbp_mean_hist', 'temp_mean_hist', 'tnt_mean_hist', 'urine_mean_hist', 'wbc_mean_hist', 'age_mean_hist', 'height_mean_hist', 'weight_mean_hist', 'alb_min_hist', 'alp_min_hist', 'alt_min_hist', 'ast_min_hist', 'be_min_hist', 'bicar_min_hist', 'bili_min_hist', 'bili_dir_min_hist', 'bnd_min_hist', 'bun_min_hist', 'ca_min_hist', 'cai_min_hist', 'ck_min_hist', 'ckmb_min_hist', 'cl_min_hist', 'crea_min_hist', 'crp_min_hist', 'dbp_min_hist', 'fgn_min_hist', 'fio2_min_hist', 'glu_min_hist', 'hgb_min_hist', 'hr_min_hist', 'inr_pt_min_hist', 'k_min_hist', 'lact_min_hist', 'lymph_min_hist', 'map_min_hist', 'mch_min_hist', 'mchc_min_hist', 'mcv_min_hist', 'methb_min_hist', 'mg_min_hist', 'na_min_hist', 'neut_min_hist', 'o2sat_min_hist', 'pco2_min_hist', 'ph_min_hist', 'phos_min_hist', 'plt_min_hist', 'po2_min_hist', 
'ptt_min_hist', 'resp_min_hist', 'sbp_min_hist', 'temp_min_hist', 'tnt_min_hist', 'urine_min_hist', 'wbc_min_hist', 'age_min_hist', 'height_min_hist', 'weight_min_hist', 'alb_max_hist', 'alp_max_hist', 'alt_max_hist', 'ast_max_hist', 'be_max_hist', 'bicar_max_hist', 'bili_max_hist', 'bili_dir_max_hist', 'bnd_max_hist', 'bun_max_hist', 'ca_max_hist', 'cai_max_hist', 'ck_max_hist', 'ckmb_max_hist', 'cl_max_hist', 'crea_max_hist', 'crp_max_hist', 'dbp_max_hist', 'fgn_max_hist', 'fio2_max_hist', 'glu_max_hist', 'hgb_max_hist', 'hr_max_hist', 'inr_pt_max_hist', 'k_max_hist', 'lact_max_hist', 'lymph_max_hist', 'map_max_hist', 'mch_max_hist', 'mchc_max_hist', 'mcv_max_hist', 'methb_max_hist', 'mg_max_hist', 'na_max_hist', 'neut_max_hist', 'o2sat_max_hist', 'pco2_max_hist', 'ph_max_hist', 'phos_max_hist', 'plt_max_hist', 'po2_max_hist', 'ptt_max_hist', 'resp_max_hist', 'sbp_max_hist', 'temp_max_hist', 'tnt_max_hist', 'urine_max_hist', 'wbc_max_hist', 'age_max_hist', 'height_max_hist', 'weight_max_hist', 'alb_var_hist', 'alp_var_hist', 'alt_var_hist', 'ast_var_hist', 'be_var_hist', 'bicar_var_hist', 'bili_var_hist', 'bili_dir_var_hist', 'bnd_var_hist', 'bun_var_hist', 'ca_var_hist', 'cai_var_hist', 'ck_var_hist', 'ckmb_var_hist', 'cl_var_hist', 'crea_var_hist', 'crp_var_hist', 'dbp_var_hist', 'fgn_var_hist', 'fio2_var_hist', 'glu_var_hist', 'hgb_var_hist', 'hr_var_hist', 'inr_pt_var_hist', 'k_var_hist', 'lact_var_hist', 'lymph_var_hist', 'map_var_hist', 'mch_var_hist', 'mchc_var_hist', 'mcv_var_hist', 'methb_var_hist', 'mg_var_hist', 'na_var_hist', 'neut_var_hist', 'o2sat_var_hist', 'pco2_var_hist', 'ph_var_hist', 'phos_var_hist', 'plt_var_hist', 'po2_var_hist', 'ptt_var_hist', 'resp_var_hist', 'sbp_var_hist', 'temp_var_hist', 'tnt_var_hist', 'urine_var_hist', 'wbc_var_hist', 'age_var_hist', 'height_var_hist', 'weight_var_hist'] 260
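The "historical" features are cumulative (expanding-window) aggregates rather than fixed-width rolling ones. The output above is consistent with each row aggregating all earlier rows of the stay plus the current one, with the sample variance (ddof=1): that is why every `*_var_hist` column is null at 0ms, and why constant static columns such as age get a variance of 0.0 from the second hour on. For example, dbp_var_hist at 1h for stay 141765 is (1.267354 - 0.392267)^2 / 2 ≈ 0.382889. The pure-Python sketch below reproduces these numbers; it reflects our reading of the output, not ReciPies' actual implementation, and `expanding_stats` is a hypothetical name.

```python
import statistics


def expanding_stats(values):
    """Expanding-window mean, min, max and sample variance for one stay.

    Row t aggregates values[0..t] (the current row included). The variance uses
    ddof=1, so it is undefined (None) when only one observation is available.
    """
    rows = []
    for t in range(1, len(values) + 1):
        window = values[:t]
        rows.append({
            "mean": statistics.fmean(window),
            "min": min(window),
            "max": max(window),
            "var": statistics.variance(window) if t > 1 else None,
        })
    return rows


# Scaled dbp values of stay 141765 at 0ms, 1h and 2h, copied from the output above
dbp = [1.267354, 0.392267, 0.237839]
for row in expanding_stats(dbp):
    print(row["var"])  # None, then ~0.382889 and ~0.308254 (dbp_var_hist in the output)
```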
4. Apply Pipeline to Test Data
Once the recipe has been fitted on the training data with prep(), bake() applies the same transformations to the test data without refitting anything. This ensures:
- No data leakage: test-set statistics are never used to fit the pipeline (for example, scaling uses the training means and standard deviations)
- Consistent transformations: both datasets go through identical preprocessing steps
- Reusability: the fitted recipe can be applied to any new data with the same columns
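To make the prep()/bake() contract concrete, here is a hypothetical, dependency-free scaler that follows the same pattern (`StandardScalerSketch` is not a ReciPies class). Statistics are computed once from the training data in `prep()` and only reused in `bake()`, so a test value outside the training range is scaled with the training mean and standard deviation rather than its own. We have not checked whether StepScale uses the population or the sample standard deviation; the sketch uses the population one, as scikit-learn's `StandardScaler` does.

```python
import statistics


class StandardScalerSketch:
    """Minimal prep/bake pattern: statistics are learned once, then reused."""

    def prep(self, train_column):
        # Fit: estimate mean and (population) standard deviation on training rows only
        self.mean = statistics.fmean(train_column)
        self.std = statistics.pstdev(train_column) or 1.0  # guard constant columns
        return self.bake(train_column)

    def bake(self, column):
        # Transform: apply the stored training statistics; nothing is refitted
        return [(x - self.mean) / self.std for x in column]


scaler = StandardScalerSketch()
train_scaled = scaler.prep([10.0, 20.0, 30.0])
test_scaled = scaler.bake([40.0])
print(train_scaled)  # approximately [-1.22, 0.0, 1.22]
print(test_scaled)   # approximately [2.45]
```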
```python
test_baked = rec.bake(df_test)
display(test_baked.head())
print(test_baked.columns)
print(len(test_baked.columns))
```
| stay_id | time | alb | alp | alt | ast | be | bicar | bili | bili_dir | bnd | bun | ca | cai | ck | ckmb | cl | crea | crp | dbp | fgn | fio2 | glu | hgb | hr | inr_pt | k | lact | lymph | map | mch | mchc | mcv | methb | mg | na | neut | … | cl_var_hist | crea_var_hist | crp_var_hist | dbp_var_hist | fgn_var_hist | fio2_var_hist | glu_var_hist | hgb_var_hist | hr_var_hist | inr_pt_var_hist | k_var_hist | lact_var_hist | lymph_var_hist | map_var_hist | mch_var_hist | mchc_var_hist | mcv_var_hist | methb_var_hist | mg_var_hist | na_var_hist | neut_var_hist | o2sat_var_hist | pco2_var_hist | ph_var_hist | phos_var_hist | plt_var_hist | po2_var_hist | ptt_var_hist | resp_var_hist | sbp_var_hist | temp_var_hist | tnt_var_hist | urine_var_hist | wbc_var_hist | age_var_hist | height_var_hist | weight_var_hist |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i32 | duration[ms] | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | … | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| 157016 | 0ms | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 1.370305 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | 0.258085 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 1.277862 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null | null |
| 157016 | 1h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 1.112927 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | 0.237434 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 0.969315 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.033122 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000213 | 0.0 | 0.0 | 0.0 | 0.0 | 0.047601 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.001423 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.113507 | 0.01947 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 157016 | 2h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 1.164402 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.340801 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 1.123588 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.018548 | 0.0 | 0.0 | 0.0 | 0.0 | 0.115574 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0238 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000948 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.288496 | 0.010215 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 157016 | 3h | -0.702261 | -0.506281 | -0.144356 | -0.11373 | 0.022402 | -1.191386 | -0.239623 | -0.07978 | -0.157898 | -0.829268 | -1.302892 | -0.154937 | -0.071739 | -0.08439 | -1.342402 | -0.662689 | -0.054799 | 1.524732 | -0.172039 | -0.776926 | -1.399734 | -1.221039 | -0.361452 | -0.507669 | -1.364243 | -0.322496 | -0.529953 | 1.410096 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | -1.396505 | -0.69738 | … | 0.0 | 0.0 | 0.0 | 0.036213 | 0.0 | 0.0 | 0.0 | 0.0 | 0.119697 | 0.0 | 0.0 | 0.0 | 0.0 | 0.036389 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000711 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.269578 | 0.019201 | 0.001646 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 157016 | 4h | 0.842301 | 2.159296 | -0.030123 | -0.047428 | 0.022402 | 0.725248 | -0.070302 | -0.07978 | -0.157898 | -0.550159 | 0.854627 | -0.154937 | -0.071739 | -0.08439 | 0.768036 | -0.25071 | -0.054799 | 1.524732 | -0.172039 | -0.776926 | -0.288906 | -1.221039 | -0.051684 | -0.507669 | 0.684495 | -0.322496 | -0.529953 | 1.410096 | -1.123966 | -1.173064 | -1.169936 | -0.232168 | -0.766994 | 0.735023 | -0.69738 | … | 0.89079 | 0.033945 | 0.0 | 0.037891 | 0.0 | 0.0 | 0.246788 | 0.0 | 0.089773 | 0.0 | 0.839465 | 0.0 | 0.0 | 0.036526 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.908683 | 0.0 | 0.000569 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.236945 | 0.019977 | 0.001975 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
['stay_id', 'time', 'alb', 'alp', 'alt', 'ast', 'be', 'bicar', 'bili', 'bili_dir', 'bnd', 'bun', 'ca', 'cai', 'ck', 'ckmb', 'cl', 'crea', 'crp', 'dbp', 'fgn', 'fio2', 'glu', 'hgb', 'hr', 'inr_pt', 'k', 'lact', 'lymph', 'map', 'mch', 'mchc', 'mcv', 'methb', 'mg', 'na', 'neut', 'o2sat', 'pco2', 'ph', 'phos', 'plt', 'po2', 'ptt', 'resp', 'sbp', 'temp', 'tnt', 'urine', 'wbc', 'age', 'sex', 'height', 'weight', 'label', 'hr_temp_ratio', 'alb_mean_hist', 'alp_mean_hist', 'alt_mean_hist', 'ast_mean_hist', 'be_mean_hist', 'bicar_mean_hist', 'bili_mean_hist', 'bili_dir_mean_hist', 'bnd_mean_hist', 'bun_mean_hist', 'ca_mean_hist', 'cai_mean_hist', 'ck_mean_hist', 'ckmb_mean_hist', 'cl_mean_hist', 'crea_mean_hist', 'crp_mean_hist', 'dbp_mean_hist', 'fgn_mean_hist', 'fio2_mean_hist', 'glu_mean_hist', 'hgb_mean_hist', 'hr_mean_hist', 'inr_pt_mean_hist', 'k_mean_hist', 'lact_mean_hist', 'lymph_mean_hist', 'map_mean_hist', 'mch_mean_hist', 'mchc_mean_hist', 'mcv_mean_hist', 'methb_mean_hist', 'mg_mean_hist', 'na_mean_hist', 'neut_mean_hist', 'o2sat_mean_hist', 'pco2_mean_hist', 'ph_mean_hist', 'phos_mean_hist', 'plt_mean_hist', 'po2_mean_hist', 'ptt_mean_hist', 'resp_mean_hist', 'sbp_mean_hist', 'temp_mean_hist', 'tnt_mean_hist', 'urine_mean_hist', 'wbc_mean_hist', 'age_mean_hist', 'height_mean_hist', 'weight_mean_hist', 'alb_min_hist', 'alp_min_hist', 'alt_min_hist', 'ast_min_hist', 'be_min_hist', 'bicar_min_hist', 'bili_min_hist', 'bili_dir_min_hist', 'bnd_min_hist', 'bun_min_hist', 'ca_min_hist', 'cai_min_hist', 'ck_min_hist', 'ckmb_min_hist', 'cl_min_hist', 'crea_min_hist', 'crp_min_hist', 'dbp_min_hist', 'fgn_min_hist', 'fio2_min_hist', 'glu_min_hist', 'hgb_min_hist', 'hr_min_hist', 'inr_pt_min_hist', 'k_min_hist', 'lact_min_hist', 'lymph_min_hist', 'map_min_hist', 'mch_min_hist', 'mchc_min_hist', 'mcv_min_hist', 'methb_min_hist', 'mg_min_hist', 'na_min_hist', 'neut_min_hist', 'o2sat_min_hist', 'pco2_min_hist', 'ph_min_hist', 'phos_min_hist', 'plt_min_hist', 'po2_min_hist', 
'ptt_min_hist', 'resp_min_hist', 'sbp_min_hist', 'temp_min_hist', 'tnt_min_hist', 'urine_min_hist', 'wbc_min_hist', 'age_min_hist', 'height_min_hist', 'weight_min_hist', 'alb_max_hist', 'alp_max_hist', 'alt_max_hist', 'ast_max_hist', 'be_max_hist', 'bicar_max_hist', 'bili_max_hist', 'bili_dir_max_hist', 'bnd_max_hist', 'bun_max_hist', 'ca_max_hist', 'cai_max_hist', 'ck_max_hist', 'ckmb_max_hist', 'cl_max_hist', 'crea_max_hist', 'crp_max_hist', 'dbp_max_hist', 'fgn_max_hist', 'fio2_max_hist', 'glu_max_hist', 'hgb_max_hist', 'hr_max_hist', 'inr_pt_max_hist', 'k_max_hist', 'lact_max_hist', 'lymph_max_hist', 'map_max_hist', 'mch_max_hist', 'mchc_max_hist', 'mcv_max_hist', 'methb_max_hist', 'mg_max_hist', 'na_max_hist', 'neut_max_hist', 'o2sat_max_hist', 'pco2_max_hist', 'ph_max_hist', 'phos_max_hist', 'plt_max_hist', 'po2_max_hist', 'ptt_max_hist', 'resp_max_hist', 'sbp_max_hist', 'temp_max_hist', 'tnt_max_hist', 'urine_max_hist', 'wbc_max_hist', 'age_max_hist', 'height_max_hist', 'weight_max_hist', 'alb_var_hist', 'alp_var_hist', 'alt_var_hist', 'ast_var_hist', 'be_var_hist', 'bicar_var_hist', 'bili_var_hist', 'bili_dir_var_hist', 'bnd_var_hist', 'bun_var_hist', 'ca_var_hist', 'cai_var_hist', 'ck_var_hist', 'ckmb_var_hist', 'cl_var_hist', 'crea_var_hist', 'crp_var_hist', 'dbp_var_hist', 'fgn_var_hist', 'fio2_var_hist', 'glu_var_hist', 'hgb_var_hist', 'hr_var_hist', 'inr_pt_var_hist', 'k_var_hist', 'lact_var_hist', 'lymph_var_hist', 'map_var_hist', 'mch_var_hist', 'mchc_var_hist', 'mcv_var_hist', 'methb_var_hist', 'mg_var_hist', 'na_var_hist', 'neut_var_hist', 'o2sat_var_hist', 'pco2_var_hist', 'ph_var_hist', 'phos_var_hist', 'plt_var_hist', 'po2_var_hist', 'ptt_var_hist', 'resp_var_hist', 'sbp_var_hist', 'temp_var_hist', 'tnt_var_hist', 'urine_var_hist', 'wbc_var_hist', 'age_var_hist', 'height_var_hist', 'weight_var_hist'] 260
5. Train a Machine Learning Model
With our preprocessed data ready, we can now train a machine learning model. The preprocessed dataframes contain:
- All original numeric features (imputed and standardized)
- Historical aggregated features (cumulative mean, min, max, and variance)
- The label-encoded categorical variable (sex)
- Custom domain features (e.g., hr/temp ratio)
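The categorical column is label-encoded rather than one-hot encoded: the recipe first imputes missing values with the most frequent category (SimpleImputer) and then maps each category to an integer (LabelEncoder). The sketch below illustrates that idea in pure Python; `CategoryEncoderSketch` is hypothetical, and mapping unseen categories to -1 is a choice made for this sketch (scikit-learn's `LabelEncoder.transform` raises a `ValueError` on labels it has not seen). Label codes impose an arbitrary order, which is harmless for a binary column such as sex.

```python
from collections import Counter


class CategoryEncoderSketch:
    """Impute missing categories with the training mode, then map to integers."""

    def prep(self, train_values):
        observed = [v for v in train_values if v is not None]
        self.mode = Counter(observed).most_common(1)[0][0]  # most frequent category
        # Sorted categories -> integer codes (LabelEncoder also sorts its classes)
        self.codes = {cat: i for i, cat in enumerate(sorted(set(observed)))}
        return self.bake(train_values)

    def bake(self, values):
        # Missing values get the training mode; unseen categories get -1
        return [self.codes.get(self.mode if v is None else v, -1) for v in values]


enc = CategoryEncoderSketch()
print(enc.prep(["Female", "Male", None, "Female"]))  # [0, 1, 0, 0]
print(enc.bake(["Male", "Unknown", None]))           # [1, -1, 0]
```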
For demonstration, we'll use a simple logistic regression model, but you can use any scikit-learn compatible model or more advanced methods like XGBoost, LightGBM, or neural networks.
```python
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np

# Extract features and labels
# Exclude outcome, group, and sequence columns from features
feature_cols = [c for c in train_baked.columns if c not in ["label", "stay_id", "time"]]
X_train = train_baked.select(feature_cols).to_numpy()
y_train = train_baked.select("label").to_numpy().ravel()
X_test = test_baked.select(feature_cols).to_numpy()
y_test = test_baked.select("label").to_numpy().ravel()

# Replace remaining NaNs with 0: the *_var_hist columns are null at the first
# time step of every stay (the sample variance of a single value is undefined)
X_train = np.nan_to_num(X_train, nan=0.0)
X_test = np.nan_to_num(X_test, nan=0.0)

print(f"Training set: {X_train.shape[0]} samples, {X_train.shape[1]} features")
print(f"Test set: {X_test.shape[0]} samples, {X_test.shape[1]} features")
print(f"Class distribution (train): {np.bincount(y_train)}")
print(f"Class distribution (test): {np.bincount(y_test)}")

# Train model (class_weight="balanced" compensates for the rare positive class)
model = LogisticRegression(max_iter=1000, random_state=42, class_weight="balanced")
model.fit(X_train, y_train)

# Predicted probability of the positive class (mortality)
y_train_pred = model.predict_proba(X_train)[:, 1]
y_test_pred = model.predict_proba(X_test)[:, 1]

# Evaluate
train_auc = roc_auc_score(y_train, y_train_pred)
test_auc = roc_auc_score(y_test, y_test_pred)
print("\nModel Performance:")
print(f"Train AUC: {train_auc:.4f}")
print(f"Test AUC: {test_auc:.4f}")
print("\nClassification Report (Test Set):")
print(classification_report(y_test, model.predict(X_test), target_names=["No Mortality", "Mortality"]))
```
```
Training set: 27325 samples, 257 features
Test set: 6850 samples, 257 features
Class distribution (train): [25825 1500]
Class distribution (test): [6600 250]

Model Performance:
Train AUC: 0.9463
Test AUC: 0.5448

Classification Report (Test Set):
              precision    recall  f1-score   support

No Mortality       0.97      0.81      0.88      6600
   Mortality       0.07      0.36      0.11       250

    accuracy                           0.79      6850
   macro avg       0.52      0.58      0.50      6850
weighted avg       0.94      0.79      0.85      6850
```
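These row-level numbers need care. Every hourly row inherits its stay's label and each stay has 25 rows (6850 / 274 = 25), so the 250 positive test rows come from only 250 / 25 = 10 positive stays, and rows from the same stay are strongly correlated. The test AUC is therefore noisy, and the gap between train AUC (0.9463) and test AUC (0.5448) points to overfitting. One option is to score each stay once. The hypothetical helper below keeps the prediction of each stay's latest time step (the mean or max over the stay are alternatives) and computes AUC with the Mann-Whitney formulation, so it needs no extra dependencies.

```python
def stay_level_auc(stay_ids, times, probs, labels):
    """ROC AUC over stays instead of rows.

    Keeps the prediction of the latest time step of each stay, then computes AUC
    as the probability that a random positive stay is scored above a random
    negative stay (ties count as half), i.e. the Mann-Whitney U formulation.
    """
    last = {}  # stay_id -> (time, prob, label) of the latest row seen so far
    for sid, t, p, y in zip(stay_ids, times, probs, labels):
        if sid not in last or t > last[sid][0]:
            last[sid] = (t, p, y)
    pos = [p for _, p, y in last.values() if y == 1]
    neg = [p for _, p, y in last.values() if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative stay")
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


# Toy check: three stays with two hourly rows each; stay 2 is the only positive
print(stay_level_auc(
    stay_ids=[1, 1, 2, 2, 3, 3],
    times=[0, 1, 0, 1, 0, 1],
    probs=[0.9, 0.2, 0.1, 0.8, 0.3, 0.4],
    labels=[0, 0, 1, 1, 0, 0],
))  # 1.0
```

In the notebook it could be called as `stay_level_auc(test_baked["stay_id"].to_list(), test_baked["time"].to_list(), y_test_pred.tolist(), y_test.tolist())`, since the rows of `X_test` follow the row order of `test_baked`.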