scikit-survival 0.22.0 released
I am pleased to announce the release of scikit-survival 0.22.0. The highlights for this release include:
- Compatibility with scikit-learn 1.3.
- Missing value support for SurvivalTree.
- A reduced memory mode for RandomSurvivalForest, ExtraSurvivalTrees, and SurvivalTree.
- Support for predict_cumulative_hazard_function() and predict_survival_function() in Stacking.
Missing Values Support in SurvivalTree
Based on the missing value support in scikit-learn 1.3, a SurvivalTree can now handle missing values if it is fit with splitter='best'.
If the training data contained no missing values, then during prediction missing values are mapped to the child node with the most samples:
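import matplotlib.pyplot as plt
import numpy as np
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.tree import SurvivalTree

X, y = load_veterans_lung_cancer()
# Fit a tree of depth 1 on a single feature without missing values
X_train = np.asarray(X.loc[:, ["Karnofsky_score"]], dtype=np.float32)
est = SurvivalTree(max_depth=1)
est.fit(X_train, y)

# Predict the survival function of a sample whose only feature is missing
X_test = np.array([[np.nan]])
surv_fn = est.predict_survival_function(X_test, return_array=True)

# Kaplan-Meier estimate of the training samples in the right child node
mask = X_train[:, 0] > est.tree_.threshold[0]
km_x, km_y = kaplan_meier_estimator(
    y[mask]["Status"], y[mask]["Survival_in_days"]
)

# Compare the Kaplan-Meier curve (solid) with the tree's prediction (dotted)
plt.step(km_x, km_y, where="post", linewidth=5)
plt.step(
    est.unique_times_, surv_fn[0], where="post", linewidth=3, linestyle="dotted"
)
plt.ylim(0, 1)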
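If a tree is fit to training data with missing values, the splitter evaluates each candidate split twice: once with all samples with missing values going to the left child node, and once with them going to the right child node, and keeps the better of the two. During prediction, samples with missing values are sent to the child node chosen during training: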
X, y = load_veterans_lung_cancer()
X_train = np.asarray(X.loc[:, ["Age_in_years"]], dtype=np.float32)
# Mark the feature of the last 50 training samples as missing
X_train[-50:, :] = np.nan
est = SurvivalTree(max_depth=1)
est.fit(X_train, y)

X_test = np.array([[np.nan]])
surv_fn = est.predict_survival_function(X_test, return_array=True)

# Training samples in the right child node, including those with a missing
# value (here, the splitter sent missing values to the right)
mask = X_train[:, 0] > est.tree_.threshold[0]
mask |= np.isnan(X_train[:, 0])
km_x, km_y = kaplan_meier_estimator(
    y[mask]["Status"], y[mask]["Survival_in_days"]
)

# Compare the Kaplan-Meier curve (solid) with the tree's prediction (dotted)
plt.step(km_x, km_y, where="post", linewidth=5)
plt.step(
    est.unique_times_, surv_fn[0], where="post", linewidth=3, linestyle="dotted"
)
plt.ylim(0, 1)
These rules are identical to those of scikit-learn’s missing value support.
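Both routing rules can be illustrated with a simplified, single-feature split search in plain NumPy. This is a sketch under stated assumptions, not the scikit-learn or scikit-survival implementation: it scores splits by the decrease in the sum of squared errors instead of the log-rank statistic that SurvivalTree uses. The helper names best_split_with_missing and route are hypothetical. Sending ties between equally sized children to the left is our own choice.

```python
import numpy as np


def sse(values):
    """Sum of squared errors around the mean (0 for an empty node)."""
    return float(((values - values.mean()) ** 2).sum()) if values.size else 0.0


def best_split_with_missing(x, y):
    """Hypothetical helper: best threshold on one feature with NaN routing.

    Uses SSE reduction as a stand-in for the log-rank criterion.
    Returns (threshold, missing_go_left, gain), or (None, None, -inf)
    if there are fewer than two distinct observed values.
    """
    missing = np.isnan(x)
    x_obs, y_obs, y_miss = x[~missing], y[~missing], y[missing]
    parent = sse(y)
    best = (None, None, -np.inf)
    values = np.unique(x_obs)  # sorted unique observed values
    # Candidate thresholds: midpoints between consecutive observed values
    for threshold in (values[:-1] + values[1:]) / 2:
        left = y_obs[x_obs <= threshold]
        right = y_obs[x_obs > threshold]
        if y_miss.size == 0:
            # No missing values in training: future NaNs go to the larger
            # child (ties go left in this sketch)
            directions = (left.size >= right.size,)
        else:
            # Evaluate the split twice: all missing samples left, then right
            directions = (True, False)
        for missing_go_left in directions:
            if missing_go_left:
                cost = sse(np.concatenate([left, y_miss])) + sse(right)
            else:
                cost = sse(left) + sse(np.concatenate([right, y_miss]))
            gain = parent - cost
            if gain > best[2]:
                best = (float(threshold), missing_go_left, gain)
    return best


def route(value, threshold, missing_go_left):
    """Hypothetical helper: child node for a single feature value."""
    if np.isnan(value):
        return "left" if missing_go_left else "right"
    return "left" if value <= threshold else "right"


# Missing samples resemble the high-valued group, so they should go right
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0, np.nan, np.nan])
y = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 5.0, 5.1])
threshold, missing_go_left, gain = best_split_with_missing(x, y)
print(threshold, missing_go_left)  # 6.5 False
print(route(np.nan, threshold, missing_go_left))  # right
```

Without missing values in the training data, the same helper falls back to the larger child: for x = [1, 2, 3, 10] it picks threshold 6.5 and sends NaNs left, where three of the four samples are.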
Low-memory Mode for SurvivalTree and RandomSurvivalForest
The last release already brought performance improvements to SurvivalTree and RandomSurvivalForest.
This release adds the low_memory option to RandomSurvivalForest, ExtraSurvivalTrees, and SurvivalTree.
If low-memory mode is disabled, which is the default, calling predict on a sample requires memory proportional to the number of unique event times in the training data, because the cumulative hazard function is computed as an intermediate value.
If low-memory mode is enabled, the risk score is computed directly, without computing the cumulative hazard function. However, in low-memory mode, predict_cumulative_hazard_function and predict_survival_function cannot be used.
Install
Pre-built packages are available for Linux, macOS, and Windows, either via pip:
pip install scikit-survival
or via conda:
conda install -c sebp scikit-survival