scikit-survival 0.21.0 released
Today marks the release of scikit-survival 0.21.0. This release brings several exciting new features and significant performance improvements:
- Pointwise confidence intervals for the Kaplan-Meier estimator.
- Early stopping in GradientBoostingSurvivalAnalysis.
- Improved performance of fitting SurvivalTree and RandomSurvivalForest.
- Reduced memory footprint of concordance_index_censored.
Pointwise Confidence Intervals for the Kaplan-Meier Estimator
kaplan_meier_estimator() can now estimate pointwise confidence intervals by specifying the conf_type parameter.
import matplotlib.pyplot as plt
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator
_, y = load_veterans_lung_cancer()
time, survival_prob, conf_int = kaplan_meier_estimator(
y["Status"], y["Survival_in_days"], conf_type="log-log"
)
plt.step(time, survival_prob, where="post")
plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
Early Stopping in GradientBoostingSurvivalAnalysis
Early stopping enables us to determine when the model is sufficiently complex.
This is usually done by continuously evaluating the model on held-out data.
For GradientBoostingSurvivalAnalysis, the easiest way to achieve this is by setting n_iter_no_change and, optionally, validation_fraction (which defaults to 0.1).
from sksurv.datasets import load_whas500
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
X, y = load_whas500()
model = GradientBoostingSurvivalAnalysis(
n_estimators=1000, max_depth=2, subsample=0.8, n_iter_no_change=10, random_state=0,
)
model.fit(X, y)
print(model.n_estimators_)
In this example, model.n_estimators_ indicates that fitting stopped after 73 iterations, instead of the maximum of 1000 iterations.
Alternatively, one can provide a custom callback function to the fit method. If the callback returns True, training is stopped.
model = GradientBoostingSurvivalAnalysis(
n_estimators=1000, max_depth=2, subsample=0.8, random_state=0,
)
def early_stopping_monitor(iteration, model, args):
"""Stop training if there was no improvement in the last 10 iterations"""
start = max(0, iteration - 10)
end = iteration + 1
oob_improvement = model.oob_improvement_[start:end]
return all(oob_improvement < 0)
model.fit(X, y, monitor=early_stopping_monitor)
print(model.n_estimators_)
In the example above, early stopping is determined by checking the last 10 entries of the oob_improvement_ attribute. It contains the improvement in loss on the out-of-bag samples relative to the previous iteration. This requires setting subsample to a value smaller than 1, here 0.8.
Using this approach, training stopped after 114 iterations.
Improved Performance of SurvivalTree and RandomSurvivalForest
Another exciting feature of scikit-survival 0.21.0 is a rewrite of the training routine of SurvivalTree, which results in roughly 3x faster training times.
The plot above compares the time required to fit a single SurvivalTree on data with 25 features and a varying number of samples. The performance difference becomes notable for datasets with 1000 samples and above.
Note that this improvement also speeds up fitting RandomSurvivalForest and ExtraSurvivalTrees.
Improved concordance index
Another performance improvement is due to Christine Poerschke, who significantly reduced the memory footprint of concordance_index_censored(). With scikit-survival 0.21.0, memory usage scales linearly, instead of quadratically, in the number of samples, making performance evaluation on large datasets much more manageable.
For a full list of changes in scikit-survival 0.21.0, please see the release notes.
Install
scikit-survival 0.21.0 is available for Linux, macOS (Intel), and Windows. Install it either via pip:
pip install scikit-survival
or via conda, which provides pre-built packages:
conda install -c sebp scikit-survival