From 7e5c81c0f1aaab54c2bb508655859613525baae4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 6 Nov 2023 07:31:41 +0100 Subject: [PATCH 1/6] Add silverman kernel to survlime --- R/predict_parts.R | 3 ++- R/surv_lime.R | 16 ++++++++++++-- man/predict_parts.surv_explainer.Rd | 3 ++- man/surv_lime.Rd | 3 ++- tests/testthat/test-predict_parts.R | 34 +++++++++++++++++++++++++++++ 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/R/predict_parts.R b/R/predict_parts.R index 10c84386..4745e714 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -20,7 +20,7 @@ #' * for `survlime` #' * `N` - a positive integer, number of observations generated in the neighbourhood #' * `distance_metric` - character, name of the distance metric to be used, only `"euclidean"` is implemented -#' * `kernel_width` - a numeric, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)` +#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)\*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[3\]. #' * `sampling_method` - character, name of the method of generating neighbourhood, only `"gaussian"` is implemented #' * `sample_around_instance` - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center #' * `max_iter` - a numeric, maximal number of iteration for the optimization problem @@ -34,6 +34,7 @@ #' @section References: #' - \[1\] Krzyziński, Mateusz, et al. ["SurvSHAP(t): Time-dependent explanations of machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705122013302) Knowledge-Based Systems 262 (2023): 110234 #' - \[2\] Kovalev, Maxim S., et al. 
["SurvLIME: A method for explaining machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg) Knowledge-Based Systems 203 (2020): 106164. +#' - \[3\] Pachón-García, Cristian, et al. ["SurvLIMEpy: A Python package implementing SurvLIME."](https://www.sciencedirect.com/science/article/pii/S095741742302122X) Expert Systems with Applications 237 (2024): 121620. #' #' @examples #' \donttest{ diff --git a/R/surv_lime.R b/R/surv_lime.R index 3cb7fb9d..6b0d67ee 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -6,7 +6,7 @@ #' @param ... additional parameters, passed to internal functions #' @param N a positive integer, number of observations generated in the neighbourhood #' @param distance_metric character, name of the distance metric to be used, only `"euclidean"` is implemented -#' @param kernel_width a numeric, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)` +#' @param kernel_width a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[2\]. #' @param sampling_method character, name of the method of generating neighbourhood, only `"gaussian"` is implemented #' @param sample_around_instance logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center #' @param max_iter a numeric, maximal number of iteration for the optimization problem @@ -17,7 +17,7 @@ #' #' @section References: #' - \[1\] Kovalev, Maxim S., et al. 
["SurvLIME: A method for explaining machine learning survival models."](https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg) Knowledge-Based Systems 203 (2020): 106164. -#' +#' - \[2\] Pachón-García, Cristian, et al. ["SurvLIMEpy: A Python package implementing SurvLIME."](https://www.sciencedirect.com/science/article/pii/S095741742302122X) Expert Systems with Applications 237 (2024): 121620. #' @keywords internal #' @importFrom stats optim surv_lime <- function(explainer, new_observation, @@ -60,6 +60,18 @@ surv_lime <- function(explainer, new_observation, if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data)) * 0.75 + if (is.character(kernel_width)) { + if (kernel_width == "silverman") { + n <- nrow(scaled_data) + p <- ncol(scaled_data) + + kernel_width <- (4/(n*(p+2)))^(1/(p+4)) + + } else { + stop("`kernel_width` must be either NULL, numeric or \"silverman\"") + } + } + weights <- sqrt(exp(-(distances^2) / (kernel_width^2))) na_est <- survival::basehaz(survival::coxph(explainer$y ~ 1)) diff --git a/man/predict_parts.surv_explainer.Rd b/man/predict_parts.surv_explainer.Rd index c79249e5..c7c9a54c 100644 --- a/man/predict_parts.surv_explainer.Rd +++ b/man/predict_parts.surv_explainer.Rd @@ -47,7 +47,7 @@ There are additional parameters that are passed to internal functions \itemize{ \item \code{N} - a positive integer, number of observations generated in the neighbourhood \item \code{distance_metric} - character, name of the distance metric to be used, only \code{"euclidean"} is implemented -\item \code{kernel_width} - a numeric, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)} +\item \code{kernel_width} - a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \verb{sqrt(ncol(data)\\*0.75)}. 
If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [3]. \item \code{sampling_method} - character, name of the method of generating neighbourhood, only \code{"gaussian"} is implemented \item \code{sample_around_instance} - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center \item \code{max_iter} - a numeric, maximal number of iteration for the optimization problem @@ -68,6 +68,7 @@ There are additional parameters that are passed to internal functions \itemize{ \item [1] Krzyziński, Mateusz, et al. \href{https://www.sciencedirect.com/science/article/pii/S0950705122013302}{"SurvSHAP(t): Time-dependent explanations of machine learning survival models."} Knowledge-Based Systems 262 (2023): 110234 \item [2] Kovalev, Maxim S., et al. \href{https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg}{"SurvLIME: A method for explaining machine learning survival models."} Knowledge-Based Systems 203 (2020): 106164. +\item [3] Pachón-García, Cristian, et al. \href{https://www.sciencedirect.com/science/article/pii/S095741742302122X}{"SurvLIMEpy: A Python package implementing SurvLIME."} Expert Systems with Applications 237 (2024): 121620. } } diff --git a/man/surv_lime.Rd b/man/surv_lime.Rd index ec93a668..35b13442 100644 --- a/man/surv_lime.Rd +++ b/man/surv_lime.Rd @@ -29,7 +29,7 @@ surv_lime( \item{distance_metric}{character, name of the distance metric to be used, only \code{"euclidean"} is implemented} -\item{kernel_width}{a numeric, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}} +\item{kernel_width}{a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. 
If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [2].} \item{sampling_method}{character, name of the method of generating neighbourhood, only \code{"gaussian"} is implemented} @@ -51,6 +51,7 @@ Helper functions for \code{predict_parts.R} \itemize{ \item [1] Kovalev, Maxim S., et al. \href{https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=6e9cyk_ji3AAAAAA:tbqo33MsZvNC9nrSGabZdLfPtZTsvsvZTHYQCM2aEhumLI5D46U7ovhr37EaYUhmKZrw45JzDhg}{"SurvLIME: A method for explaining machine learning survival models."} Knowledge-Based Systems 203 (2020): 106164. +\item [2] Pachón-García, Cristian, et al. \href{https://www.sciencedirect.com/science/article/pii/S095741742302122X}{"SurvLIMEpy: A Python package implementing SurvLIME."} Expert Systems with Applications 237 (2024): 121620. } } diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 6e8c733f..c7928543 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -182,6 +182,40 @@ test_that("survlime explanations work", { }) + +test_that("survlime silverman kernel works", { + + veteran <- survival::veteran + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + + cph_survlime <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", kernel_width = "silverman") + + expect_error(predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", kernel_width = "nonexistent")) + # error on too few columns + expect_error(predict_parts(rsf_src_exp, new_observation = veteran[1, -c(1, 2 
,3, 4)], type = "survlime")) + + plot(cph_survlime, type = "coefficients") + plot(cph_survlime, type = "local_importance") + plot(cph_survlime, show_survival_function = FALSE) + + expect_error(plot(cph_survlime, type = "nonexistent")) + + expect_s3_class(cph_survlime, c("predict_parts_survival", "surv_lime")) + + expect_gte(length(cph_survlime$result), ncol(cph_exp$data)) + + expect_setequal(cph_survlime$black_box_sf_times, cph_exp$times) + + expect_output(print(cph_survlime)) + +}) + + test_that("default DALEX::predict_parts is ok", { veteran <- survival::veteran From a6292ac34f53a0278163027ffaafecd5858ffac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 6 Nov 2023 07:34:08 +0100 Subject: [PATCH 2/6] Bump version, update news --- DESCRIPTION | 2 +- NEWS.md | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index a958d5fc..35d0eed6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.2.0 +Version: 1.2.0.9001 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), diff --git a/NEWS.md b/NEWS.md index 9329a134..627ea39e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,6 @@ +# survex (development) +* added a new estimation method for the `kernel_width` parameter for the SurvLIME method. 
+ # survex 1.2.0 * added new `calculation_method` for `surv_shap()` called `"treeshap"` that uses the `treeshap` package ([#75](https://github.com/ModelOriented/survex/issues/75)) * enable to calculate SurvSHAP(t) explanations based on subsample of the explainer's data From 6682a31b531444d81a8300ac8540f60817d94313 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 6 Nov 2023 11:07:37 +0100 Subject: [PATCH 3/6] use cov matrix with kernel width for sampling --- R/surv_lime.R | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/R/surv_lime.R b/R/surv_lime.R index 6b0d67ee..31b2cc3c 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -43,7 +43,8 @@ surv_lime <- function(explainer, new_observation, N, categorical_variables, sampling_method, - sample_around_instance + sample_around_instance, + kernel_width ) @@ -122,7 +123,8 @@ generate_neighbourhood <- function(data_org, n_samples = 100, categorical_variables = NULL, sampling_method = "gaussian", - sample_around_instance = TRUE) { + sample_around_instance = TRUE, + kernel_width = NULL) { # change categorical_variables to column names if (is.numeric(categorical_variables)) categorical_variables <- colnames(data_org)[categorical_variables] @@ -131,6 +133,13 @@ generate_neighbourhood <- function(data_org, categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) data_row <- data_row[colnames(data_org)] + if (is.character(kernel_width) && kernel_width == "silverman"){ + p <- ncol(data_org) + b <- (4/(n_samples*(p+2)))^(1/(p+4)) + } else { + b <- 1 + } + feature_frequencies <- list(length(categorical_variables)) scaled_data <- scale(data_org[, !colnames(data_org) %in% categorical_variables]) @@ -154,9 +163,9 @@ generate_neighbourhood <- function(data_org, if (sample_around_instance) { to_add <- data_row[, !colnames(data_row) %in% categorical_variables] - data <- data %*% diag(sc) + to_add[col(data)] + data <- data %*% (b * diag(sc)) + to_add[col(data)] } else { - 
data <- data %*% diag(sc) + me[col(data)] + data <- data %*% (b * diag(sc)) + me[col(data)] } data <- as.data.frame(data) From ffc33d6482088c349dbb0a8f66e21218bf0bc709 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 6 Nov 2023 11:09:11 +0100 Subject: [PATCH 4/6] fix typo --- R/predict_parts.R | 2 +- man/predict_parts.surv_explainer.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/predict_parts.R b/R/predict_parts.R index 4745e714..16bfa2f7 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -20,7 +20,7 @@ #' * for `survlime` #' * `N` - a positive integer, number of observations generated in the neighbourhood #' * `distance_metric` - character, name of the distance metric to be used, only `"euclidean"` is implemented -#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)\*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[3\]. +#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[3\]. 
#' * `sampling_method` - character, name of the method of generating neighbourhood, only `"gaussian"` is implemented #' * `sample_around_instance` - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center #' * `max_iter` - a numeric, maximal number of iteration for the optimization problem diff --git a/man/predict_parts.surv_explainer.Rd b/man/predict_parts.surv_explainer.Rd index c7c9a54c..80809412 100644 --- a/man/predict_parts.surv_explainer.Rd +++ b/man/predict_parts.surv_explainer.Rd @@ -47,7 +47,7 @@ There are additional parameters that are passed to internal functions \itemize{ \item \code{N} - a positive integer, number of observations generated in the neighbourhood \item \code{distance_metric} - character, name of the distance metric to be used, only \code{"euclidean"} is implemented -\item \code{kernel_width} - a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \verb{sqrt(ncol(data)\\*0.75)}. If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [3]. +\item \code{kernel_width} - a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [3]. 
\item \code{sampling_method} - character, name of the method of generating neighbourhood, only \code{"gaussian"} is implemented \item \code{sample_around_instance} - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center \item \code{max_iter} - a numeric, maximal number of iteration for the optimization problem From c820fe961ec9c6cdb335918a1dddefb6eb20836c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 6 Nov 2023 11:23:42 +0100 Subject: [PATCH 5/6] change kernel_width description --- R/predict_parts.R | 2 +- R/surv_lime.R | 2 +- man/predict_parts.surv_explainer.Rd | 2 +- man/surv_lime.Rd | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/predict_parts.R b/R/predict_parts.R index 16bfa2f7..d8ce0cc4 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -20,7 +20,7 @@ #' * for `survlime` #' * `N` - a positive integer, number of observations generated in the neighbourhood #' * `distance_metric` - character, name of the distance metric to be used, only `"euclidean"` is implemented -#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[3\]. +#' * `kernel_width` - a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package. 
#' * `sampling_method` - character, name of the method of generating neighbourhood, only `"gaussian"` is implemented #' * `sample_around_instance` - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center #' * `max_iter` - a numeric, maximal number of iteration for the optimization problem diff --git a/R/surv_lime.R b/R/surv_lime.R index 31b2cc3c..b9e05e30 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -6,7 +6,7 @@ #' @param ... additional parameters, passed to internal functions #' @param N a positive integer, number of observations generated in the neighbourhood #' @param distance_metric character, name of the distance metric to be used, only `"euclidean"` is implemented -#' @param kernel_width a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the Python implementation \[2\]. +#' @param kernel_width a numeric or `"silverman"`, parameter used for calculating weights, by default it's `sqrt(ncol(data)*0.75)`. If `"silverman"` the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package. 
#' @param sampling_method character, name of the method of generating neighbourhood, only `"gaussian"` is implemented #' @param sample_around_instance logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center #' @param max_iter a numeric, maximal number of iteration for the optimization problem diff --git a/man/predict_parts.surv_explainer.Rd b/man/predict_parts.surv_explainer.Rd index 80809412..58071d79 100644 --- a/man/predict_parts.surv_explainer.Rd +++ b/man/predict_parts.surv_explainer.Rd @@ -47,7 +47,7 @@ There are additional parameters that are passed to internal functions \itemize{ \item \code{N} - a positive integer, number of observations generated in the neighbourhood \item \code{distance_metric} - character, name of the distance metric to be used, only \code{"euclidean"} is implemented -\item \code{kernel_width} - a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [3]. +\item \code{kernel_width} - a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package. 
\item \code{sampling_method} - character, name of the method of generating neighbourhood, only \code{"gaussian"} is implemented \item \code{sample_around_instance} - logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center \item \code{max_iter} - a numeric, maximal number of iteration for the optimization problem diff --git a/man/surv_lime.Rd b/man/surv_lime.Rd index 35b13442..e64d1339 100644 --- a/man/surv_lime.Rd +++ b/man/surv_lime.Rd @@ -29,7 +29,7 @@ surv_lime( \item{distance_metric}{character, name of the distance metric to be used, only \code{"euclidean"} is implemented} -\item{kernel_width}{a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the Python implementation [2].} +\item{kernel_width}{a numeric or \code{"silverman"}, parameter used for calculating weights, by default it's \code{sqrt(ncol(data)*0.75)}. 
If \code{"silverman"} the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package.} \item{sampling_method}{character, name of the method of generating neighbourhood, only \code{"gaussian"} is implemented} From 2b826856e13df7d35923ec1220b66e1a04d623c4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 6 Nov 2023 11:24:25 +0100 Subject: [PATCH 6/6] `verbose=FALSE` for treeshap tests --- tests/testthat/test-model_survshap.R | 3 ++- tests/testthat/test-predict_parts.R | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 4ee3be45..42b61b5d 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -82,7 +82,8 @@ test_that("global survshap explanations with treeshap work for ranger", { new_observation = new_obs, y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", - calculation_method = "treeshap" + calculation_method = "treeshap", + verbose = FALSE ) plot(ranger_global_survshap_tree) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index c7928543..6c02780a 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -75,7 +75,8 @@ test_that("local survshap explanations with treeshap work for ranger", { new_obs, y_true = c(veteran$time[2], veteran$status[2]), aggregation_method = "mean_absolute", - calculation_method = "treeshap" + calculation_method = "treeshap", + verbose = FALSE ) plot(parts_ranger)