From 119dd8a2bb25503501464fe6d09e780913db5f6b Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 23 Feb 2024 17:19:11 +0100 Subject: [PATCH] fix colors usage in `plot.aggregated_surv_shap` when geom="importance" --- R/plot_surv_shap.R | 21 ++++++++++----------- man/plot.aggregated_surv_shap.Rd | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 94b6df1..86728d4 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -97,7 +97,7 @@ plot.surv_shap <- function(x, #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations #' @param max_vars maximum number of variables to be plotted (least important variables are ignored), by default 7 -#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). If `geom = "importance"`, the first color will be used for the barplot, the rest for the lines. #' #' @return An object of the class `ggplot`. #' @@ -229,12 +229,19 @@ plot_shap_global_importance <- function(x, x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x))) x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x))) + if (!is.null(colors)){ + bar_color <- colors[1] + colors <- colors[-1] + } else { + bar_color <- "#46bac2" + } + right_plot <- plot.surv_shap( x = x, title = NULL, subtitle = NULL, max_vars = max_vars, - colors = NULL, + colors = colors, rug = rug, rug_colors = rug_colors ) + @@ -255,17 +262,9 @@ plot_shap_global_importance <- function(x, ) } - if (is.null(colors)) { - colors <- c( - low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3" - ) - } - left_plot <- with(long_df, { ggplot(long_df, aes(x = values, y = reorder(ind, values))) + - geom_col(fill = colors[2]) + + geom_col(fill = bar_color) + theme_default_survex() + labs(x = xlab_left, y = "variable") + theme(axis.title.y = element_blank()) diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index 6b3585e..9e981b9 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -27,7 +27,7 @@ \item{max_vars}{maximum number of variables to be plotted (least important variables are ignored), by default 7} -\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). If \code{geom = "importance"}, the first color will be used for the barplot, the rest for the lines.} } \value{ An object of the class \code{ggplot}.