Feature: support SurvSHAP computation with {treeshap} #85

Merged
merged 37 commits into from
Oct 2, 2023
a796a99
feat: adding support for treeshap calculation of survshap for ranger …
kapsner Apr 4, 2023
d500fd6
fix: added constraint that new-observation for predict surv_shap has e…
kapsner Apr 5, 2023
44f56db
chore: updated description
kapsner Apr 5, 2023
bb47009
chore: merged multirow-support into main
kapsner Apr 5, 2023
b25c468
chore: comment for clarifying implementation of treeshap
kapsner Apr 5, 2023
dc7d33c
chore: moved treeshap and kernelshap to suggests and added respective…
kapsner Apr 6, 2023
b8135ee
fix: fixed issues when providing matrix to kernelshap
kapsner Apr 6, 2023
de65c4b
chore: merged feat-global-survshap into feat-treeshap
kapsner Apr 6, 2023
ad886bb
feat: fully functional treeshap integration
kapsner Apr 6, 2023
d2c16a6
feat: code adaptions to treeshap computation using pre-defined surviva…
kapsner Apr 8, 2023
1329254
refactor: implemented quality checks and now explicitly control row o…
kapsner Apr 9, 2023
60e3b91
chore: rounding time-points for row names to a maximum of 2 digits
kapsner Apr 9, 2023
5b1a810
chore: more informative error messages in stopifnot statements
kapsner Apr 13, 2023
f88f794
chore: merge into main
kapsner Jul 24, 2023
6cfe2fb
chore: merged latest global-survshap devs into feat-branch
kapsner Jul 24, 2023
6a4c926
feat: adaptions of treeshap feature branch to new global survshap
kapsner Jul 25, 2023
bf5222d
chore: merged dev-global-survshap into treeshap feature branch
kapsner Aug 6, 2023
b08f4d5
chore: merged upstream changes into kapsner-fork
kapsner Aug 31, 2023
c251f3b
chore: merged latest upstream changes into treeshap branch
kapsner Aug 31, 2023
15c689c
chore: updated description, namespace and typos
kapsner Aug 31, 2023
a9269f0
chore: updated messages
kapsner Aug 31, 2023
0e8c85c
fix: added missing output_type arguments to switch of calculation met…
kapsner Aug 31, 2023
40a076e
fix: new_observation now as.matrix for kernelshap
kapsner Aug 31, 2023
a496a67
fix: removed data.frame conversion of explainer in kernelshap
kapsner Aug 31, 2023
23e57da
fix: another try to fix kernelshap data, now X and bg_X as data.frame
kapsner Aug 31, 2023
ec6d20e
Merge branch 'ModelOriented:main' into main
kapsner Sep 4, 2023
5013883
Merge pull request #2 from kapsner/main
kapsner Sep 4, 2023
2661ef8
Merge branch 'ModelOriented:main' into main
kapsner Sep 5, 2023
160053d
Merge pull request #3 from kapsner/main
kapsner Sep 5, 2023
5481675
Merge pull request #4 from ModelOriented/main
kapsner Sep 7, 2023
7752820
Merge branch 'feat-treeshap' into main
kapsner Sep 7, 2023
6b0e337
Merge pull request #5 from kapsner/main
kapsner Sep 7, 2023
468a2d3
Merge branch 'main' into feat-treeshap
krzyzinskim Oct 2, 2023
f61aad5
Update surv_shap.R description
krzyzinskim Oct 2, 2023
a3a8311
remove treeshap from Remotes
krzyzinskim Oct 2, 2023
885a27b
Update surv_shap.R
krzyzinskim Oct 2, 2023
1d275eb
Update test-predict_parts.R
krzyzinskim Oct 2, 2023
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -7,6 +7,7 @@ Authors@R:
person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")),
person("Sophie", "Langbein", role = c("aut")),
person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")),
person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")),
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823"))
)
Description: Survival analysis models are commonly used in medicine and other areas. Many of them
@@ -25,6 +26,7 @@ Imports:
DALEX (>= 2.2.1),
ggplot2 (>= 3.4.0),
kernelshap,
treeshap,
pec,
survival,
patchwork
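
Note on optional dependencies: the commit list mentions moving `treeshap` and `kernelshap` to Suggests, and `surv_shap()` in this PR checks for each package at run time before using it. A minimal sketch of that guard, factored into a helper. The helper name `require_optional` and the package name `notARealPackage12345` are hypothetical and not part of the PR:

```r
# Hypothetical helper (not in the PR): stop with an informative message
# when an optional package is not installed.
require_optional <- function(pkg, method) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop(
      paste0(
        "Package \"", pkg, "\" must be installed to use ",
        "'calculation_method = \"", method, "\"'."
      ),
      call. = FALSE
    )
  }
  invisible(TRUE)
}

# "stats" ships with R, so this check passes silently.
ok <- require_optional("stats", "kernelshap")

# A package that is assumed not to exist triggers the error path.
msg <- tryCatch(
  require_optional("notARealPackage12345", "treeshap"),
  error = function(e) conditionMessage(e)
)
```

With a helper like this, the two near-identical `if` blocks in `surv_shap()` could probably collapse into one call keyed on `calculation_method`, though the PR keeps them explicit.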
142 changes: 128 additions & 14 deletions R/surv_shap.R
@@ -5,7 +5,7 @@
#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
#' @param ... additional parameters, passed to internal functions
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
#'
#' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -19,10 +19,15 @@ surv_shap <- function(explainer,
output_type,
...,
y_true = NULL,
calculation_method = "kernelshap",
aggregation_method = "integral") {
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
aggregation_method = c("integral", "integral_absolute", "mean_absolute", "max_absolute", "sum_of_squares")
) {
calculation_method <- match.arg(calculation_method)
aggregation_method <- match.arg(aggregation_method)

# make this code work for multiple observations
stopifnot(
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
!is.null(y_true),
ifelse(
is.matrix(y_true),
@@ -33,14 +38,40 @@ surv_shap <- function(explainer,
)
)

if (calculation_method == "kernelshap") {
if (!requireNamespace("kernelshap", quietly = TRUE)) {
stop(
paste0(
"Package \"kernelshap\" must be installed to use ",
"'calculation_method = \"kernelshap\"'."
),
call. = FALSE
)
}
}
if (calculation_method == "treeshap") {
if (!requireNamespace("treeshap", quietly = TRUE)) {
stop(
paste0(
"Package \"treeshap\" must be installed to use ",
"'calculation_method = \"treeshap\"'."
),
call. = FALSE
)
}
}

test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)

# make this code also work for 1-row matrix
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
new_observation <- as.matrix(t(new_observation[, col_index]))
new_observation <- data.frame(as.matrix(t(new_observation[, col_index])))
} else {
new_observation <- new_observation[, col_index]
if (!inherits(new_observation, "data.frame")) {
new_observation <- data.frame(new_observation)
}
}

if (ncol(explainer$data) != ncol(new_observation)) {
@@ -59,14 +90,25 @@ surv_shap <- function(explainer,
}
}

if (calculation_method == "treeshap") {
if (!inherits(explainer$model, "ranger")) {
stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.")
}
}

res <- list()
res$eval_times <- explainer$times
# to display final object correctly, when is.matrix(new_observation) == TRUE
res$variable_values <- as.data.frame(new_observation)
res$result <- switch(calculation_method,
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
"treeshap" = use_treeshap(explainer, new_observation, ...),
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
# quality-check here
stopifnot(
"Number of rows of the SurvSHAP table is not identical to length(eval_times)" =
nrow(res$result) == length(res$eval_times)
)

if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind)
@@ -86,7 +128,7 @@ surv_shap <- function(explainer,
return(res)
}
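
Note on argument handling: with a character-vector default, `match.arg()` selects the first element when the argument is omitted and raises an error for values outside the list, so the trailing `stop()` in the `switch()` should no longer be reachable. A minimal sketch of the pattern. `pick_method` and its return strings are illustrative only and do not appear in the PR:

```r
# Illustrative only: validate a choice with match.arg(), then dispatch.
pick_method <- function(calculation_method = c("kernelshap", "exact_kernel", "treeshap")) {
  # Omitted argument -> first choice; unknown value -> error.
  calculation_method <- match.arg(calculation_method)
  switch(calculation_method,
    "exact_kernel" = "exact",
    "kernelshap" = "sampled",
    "treeshap" = "tree"
  )
}

default_choice <- pick_method()
tree_choice <- pick_method("treeshap")

# An unknown method is rejected before the switch is reached.
bad_choice <- tryCatch(pick_method("nonexistent"), error = function(e) e)
```

One side effect worth knowing: `match.arg()` accepts unambiguous partial matches, so `pick_method("tree")` would also resolve to `"treeshap"`.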

use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
@@ -123,11 +165,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
timestamps
)



shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data))
colnames(shap_values) <- paste("t=", timestamps, sep = "")

return(t(shap_values))
}

@@ -204,19 +243,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
times = explainer$times
)
}
}

stopifnot(
"new_observation must be a data.frame" = inherits(
new_observation, "data.frame")
)

# get explainer data to be able to make class checks and transformations
explainer_data <- explainer$data
# ensure that classes of explainer$data and new_observation are equal
if (!inherits(explainer_data, "data.frame")) {
explainer_data <- data.frame(explainer_data)
}

shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
tmp_res <- kernelshap::kernelshap(
object = explainer$model,
X = new_observation[as.integer(i), ],
bg_X = explainer$data,
X = new_observation[as.integer(i), ], # data.frame
bg_X = explainer_data, # data.frame
pred_fun = predfun,
verbose = FALSE
)
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
colnames(tmp_shap_values) <- colnames(tmp_res$X)
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
@@ -229,6 +280,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
return(shap_values)
}

use_treeshap <- function(explainer, new_observation, ...){

stopifnot(
"new_observation must be a data.frame" = inherits(
new_observation, "data.frame")
)

# init unify_append_args
unify_append_args <- list()

if (inherits(explainer$model, "ranger")) {
# UNIFY_FUN prepares the code for easy integration of other ML algorithms
# that are supported by treeshap
UNIFY_FUN <- treeshap::ranger_surv.unify
unify_append_args <- list(type = "survival", times = explainer$times)
} else {
stop("Support for `treeshap` is currently only implemented for `ranger`.")
}

unify_args <- list(
rf_model = explainer$model,
data = explainer$data
)

if (length(unify_append_args) > 0) {
unify_args <- c(unify_args, unify_append_args)
}

tmp_unified <- do.call(UNIFY_FUN, unify_args)

shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
tmp_res <- do.call(
rbind,
lapply(
tmp_unified,
function(m) {
new_obs_mat <- new_observation[as.integer(i), ]
# ensure that the 1-row data.frame has the expected dimensions; as.integer
# keeps the comparison with "identical" type-exact
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
treeshap::treeshap(
unified_model = m,
x = new_obs_mat
)$shaps
}
)
)

tmp_shap_values <- data.frame(tmp_res)
colnames(tmp_shap_values) <- colnames(tmp_res)
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
tmp_shap_values
},
USE.NAMES = TRUE,
simplify = FALSE
)

return(shap_values)

}
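
Note on the result layout: `use_treeshap()` appears to run one `treeshap` call per unified per-time-point model and row-bind the results, giving one table per observation with one row per evaluation time. A sketch of that assembly using a stand-in for the attribution call. `fake_shaps` and the toy data are assumptions for illustration, not `treeshap` output:

```r
# Hypothetical stand-in for one per-time-point attribution call;
# returns a 1-row data.frame with one column per feature.
fake_shaps <- function(time, obs) {
  data.frame(
    age = obs$age * time / 100,
    karno = -obs$karno * time / 1000
  )
}

times <- c(10, 50, 100)
new_observation <- data.frame(age = c(60, 45), karno = c(70, 90))

# One list element per observation, named "1", "2", ...
shap_values <- sapply(
  X = as.character(seq_len(nrow(new_observation))),
  FUN = function(i) {
    obs <- new_observation[as.integer(i), ]
    # Stack the per-time-point rows into a times x features table.
    res <- do.call(rbind, lapply(times, fake_shaps, obs = obs))
    rownames(res) <- paste("t=", times, sep = "")
    res
  },
  USE.NAMES = TRUE,
  simplify = FALSE
)

# Per-observation quality check: one row per evaluation time.
stopifnot(
  "each table must have one row per time point" =
    all(vapply(shap_values, nrow, integer(1)) == length(times))
)
```

Checking the row count per list element, as in the last statement, may be more informative than calling `nrow()` on the list itself, since `nrow()` of a plain list is `NULL`.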

#' @keywords internal
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) {
if (length(shap_res_list) > 1) {
3 changes: 2 additions & 1 deletion man/model_survshap.surv_explainer.Rd

7 changes: 4 additions & 3 deletions man/surv_shap.Rd

23 changes: 23 additions & 0 deletions tests/testthat/test-model_survshap.R
@@ -1,4 +1,5 @@

# create objects here so that they do not have to be created redundantly
veteran <- survival::veteran
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
@@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex
expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times))
expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data)))
})

# testing if matrix works as input
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)

test_that("global survshap explanations with treeshap work for ranger", {

new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))])
ranger_global_survshap_tree <- model_survshap(
rsf_ranger_exp_matrix,
new_observation = new_obs,
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
)
plot(ranger_global_survshap_tree)

expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap"))
expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times))
expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data)))

})
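
Note on matrix input: the 1-row branch in `surv_shap()` wraps the subset in `t()` because column-subsetting a 1-row matrix drops it to a named vector. A small base-R sketch of the effect, with `drop = FALSE` shown as an alternative. The toy values are made up:

```r
# A 1-row matrix with named columns (toy values).
m <- matrix(
  c(60, 1, 70),
  nrow = 1,
  dimnames = list(NULL, c("age", "trt", "karno"))
)
cols <- c("age", "trt", "karno")

# Column subsetting drops the single row: the result is a named vector.
dropped <- m[, cols]
dropped_has_dim <- !is.null(dim(dropped))

# t() turns the named vector back into a 1 x n matrix before conversion.
restored <- data.frame(t(dropped))

# Alternative: keep the matrix shape during subsetting.
kept <- data.frame(m[, cols, drop = FALSE])
```

Both routes give a 1-row data.frame with the original column names; `drop = FALSE` avoids the round trip through a vector.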
42 changes: 41 additions & 1 deletion tests/testthat/test-predict_parts.R
@@ -6,7 +6,7 @@ test_that("survshap explanations work", {
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

cph_exp <- explain(cph, verbose = FALSE)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
rsf_src_exp <- explain(rsf_src, verbose = FALSE)

parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares")
@@ -19,6 +19,19 @@ test_that("survshap explanations work", {
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute")
plot(parts_ranger)

# test ranger with kernelshap when using a matrix as input for data and new observation
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")])
parts_ranger_kernelshap <- predict_parts(
rsf_ranger_exp_matrix,
new_observation = new_obs,
y_true = c(100, 1),
aggregation_method = "mean_absolute",
calculation_method = "kernelshap"
)
plot(parts_ranger_kernelshap)

parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
plot(parts_src)

@@ -46,6 +59,29 @@ test_that("survshap explanations work", {
expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent"))
expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent"))

})

test_that("local survshap explanations with treeshap work for ranger", {

veteran <- survival::veteran

rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)


new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]))
parts_ranger <- model_survshap(
rsf_ranger_exp_matrix,
new_obs,
y_true = c(veteran$time[2], veteran$status[2]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
)
plot(parts_ranger)

expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times))
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data)))

})

@@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
plot(parts_cph, rug = "censors")
plot(parts_cph, rug = "none")

# test global exact
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
plot(parts_cph_glob)

parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
plot(parts_ranger)
