Skip to content

Commit

Permalink
Merge pull request #89 from ModelOriented/survlime-kernel
Browse files Browse the repository at this point in the history
Add Silverman kernel width estimation to the `surv_lime()` function
  • Loading branch information
krzyzinskim authored Nov 6, 2023
2 parents 20bb14e + 2b82685 commit 49a378a
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 12 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: survex
Title: Explainable Machine Learning in Survival Analysis
Version: 1.2.0
Version: 1.2.0.9001
Authors@R:
c(
person("Mikołaj", "Spytek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# survex (development)
* added a new estimation method (`"silverman"`) for the `kernel_width` parameter of the SurvLIME method.

# survex 1.2.0
* added new `calculation_method` for `surv_shap()` called `"treeshap"` that uses the `treeshap` package ([#75](https://github.com/ModelOriented/survex/issues/75))
* enable to calculate SurvSHAP(t) explanations based on subsample of the explainer's data
Expand Down
3 changes: 2 additions & 1 deletion R/predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' * for `survlime`
#' * `N` - a positive integer, number of observations generated in the neighbourhood
#' * `distance_metric` - character, name of the distance metric to be used, only `"euclidean"` is implemented
#' * `kernel_width` - a numeric, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`
#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)) * 0.75`. If `"silverman"`, the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package.
#' * `sampling_method` - character, name of the method of generating neighbourhood, only `"gaussian"` is implemented
#' * `sample_around_instance` - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center
#' * `max_iter` - a numeric, maximal number of iteration for the optimization problem
Expand All @@ -34,6 +34,7 @@
#' @section References:
#' - \[1\] Krzyziński, Mateusz, et al. ["SurvSHAP(t): Time-dependent explanations of machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705122013302) Knowledge-Based Systems 262 (2023): 110234
#' - \[2\] Kovalev, Maxim S., et al. ["SurvLIME: A method for explaining machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg) Knowledge-Based Systems 203 (2020): 106164.
#' - \[3\] Pachón-García, Cristian, et al. ["SurvLIMEpy: A Python package implementing SurvLIME."](https://www.sciencedirect.com/science/article/pii/S095741742302122X) Expert Systems with Applications 237 (2024): 121620.
#'
#' @examples
#' \donttest{
Expand Down
33 changes: 27 additions & 6 deletions R/surv_lime.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @param ... additional parameters, passed to internal functions
#' @param N a positive integer, number of observations generated in the neighbourhood
#' @param distance_metric character, name of the distance metric to be used, only `"euclidean"` is implemented
#' @param kernel_width a numeric, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`
#' @param kernel_width a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)) * 0.75`. If `"silverman"`, the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package.
#' @param sampling_method character, name of the method of generating neighbourhood, only `"gaussian"` is implemented
#' @param sample_around_instance logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center
#' @param max_iter a numeric, maximal number of iteration for the optimization problem
Expand All @@ -17,7 +17,7 @@
#'
#' @section References:
#' - \[1\] Kovalev, Maxim S., et al. ["SurvLIME: A method for explaining machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg) Knowledge-Based Systems 203 (2020): 106164.
#'
#' - \[2\] Pachón-García, Cristian, et al. ["SurvLIMEpy: A Python package implementing SurvLIME."](https://www.sciencedirect.com/science/article/pii/S095741742302122X) Expert Systems with Applications 237 (2024): 121620.
#' @keywords internal
#' @importFrom stats optim
surv_lime <- function(explainer, new_observation,
Expand All @@ -43,7 +43,8 @@ surv_lime <- function(explainer, new_observation,
N,
categorical_variables,
sampling_method,
sample_around_instance
sample_around_instance,
kernel_width
)


Expand All @@ -60,6 +61,18 @@ surv_lime <- function(explainer, new_observation,

if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data)) * 0.75

if (is.character(kernel_width)) {
if (kernel_width == "silverman") {
n <- nrow(scaled_data)
p <- ncol(scaled_data)

kernel_width <- (4/(n*(p+2)))^(1/(p+4))

} else {
stop("`kernel_width` must be either NULL, numeric or \"silverman\"")
}
}

weights <- sqrt(exp(-(distances^2) / (kernel_width^2)))
na_est <- survival::basehaz(survival::coxph(explainer$y ~ 1))

Expand Down Expand Up @@ -110,7 +123,8 @@ generate_neighbourhood <- function(data_org,
n_samples = 100,
categorical_variables = NULL,
sampling_method = "gaussian",
sample_around_instance = TRUE) {
sample_around_instance = TRUE,
kernel_width = NULL) {

# change categorical_variables to column names
if (is.numeric(categorical_variables)) categorical_variables <- colnames(data_org)[categorical_variables]
Expand All @@ -119,6 +133,13 @@ generate_neighbourhood <- function(data_org,
categorical_variables <- unique(c(additional_categorical_variables, factor_variables))
data_row <- data_row[colnames(data_org)]

if (is.character(kernel_width) && kernel_width == "silverman"){
p <- ncol(data_org)
b <- (4/(n_samples*(p+2)))^(1/(p+4))
} else {
b <- 1
}

feature_frequencies <- list(length(categorical_variables))
scaled_data <- scale(data_org[, !colnames(data_org) %in% categorical_variables])

Expand All @@ -142,9 +163,9 @@ generate_neighbourhood <- function(data_org,

if (sample_around_instance) {
to_add <- data_row[, !colnames(data_row) %in% categorical_variables]
data <- data %*% diag(sc) + to_add[col(data)]
data <- data %*% (b * diag(sc)) + to_add[col(data)]
} else {
data <- data %*% diag(sc) + me[col(data)]
data <- data %*% (b * diag(sc)) + me[col(data)]
}

data <- as.data.frame(data)
Expand Down
3 changes: 2 additions & 1 deletion man/predict_parts.surv_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/surv_lime.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion tests/testthat/test-model_survshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ test_that("global survshap explanations with treeshap work for ranger", {
new_observation = new_obs,
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
calculation_method = "treeshap",
verbose = FALSE
)
plot(ranger_global_survshap_tree)

Expand Down
37 changes: 36 additions & 1 deletion tests/testthat/test-predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ test_that("local survshap explanations with treeshap work for ranger", {
new_obs,
y_true = c(veteran$time[2], veteran$status[2]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
calculation_method = "treeshap",
verbose = FALSE
)
plot(parts_ranger)

Expand Down Expand Up @@ -182,6 +183,40 @@ test_that("survlime explanations work", {

})


test_that("survlime silverman kernel works", {

  veteran <- survival::veteran

  # Only the Cox model is explained in this test; the ranger / rfsrc forests
  # fitted in the original version were never used, so they are dropped to
  # keep the test fast and self-contained.
  cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)

  cph_exp <- explain(cph, verbose = FALSE)

  # Silverman kernel width estimation must be accepted as a kernel_width value
  cph_survlime <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", kernel_width = "silverman")

  # any other character value for kernel_width must raise an error
  expect_error(predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", kernel_width = "nonexistent"))
  # error on too few columns
  # (fixed: the original called predict_parts(rsf_src_exp, ...) where
  # `rsf_src_exp` is not defined in this test, so expect_error passed on an
  # "object not found" error instead of the intended column check)
  expect_error(predict_parts(cph_exp, new_observation = veteran[1, -c(1, 2, 3, 4)], type = "survlime"))

  # plotting methods should work without error for a silverman-width result
  plot(cph_survlime, type = "coefficients")
  plot(cph_survlime, type = "local_importance")
  plot(cph_survlime, show_survival_function = FALSE)

  expect_error(plot(cph_survlime, type = "nonexistent"))

  expect_s3_class(cph_survlime, c("predict_parts_survival", "surv_lime"))

  # one coefficient per explanatory column (at least)
  expect_gte(length(cph_survlime$result), ncol(cph_exp$data))

  # the black-box survival function must be evaluated on the explainer's times
  expect_setequal(cph_survlime$black_box_sf_times, cph_exp$times)

  expect_output(print(cph_survlime))

})


test_that("default DALEX::predict_parts is ok", {

veteran <- survival::veteran
Expand Down

0 comments on commit 49a378a

Please sign in to comment.