Skip to content

Commit

Permalink
fix colors usage in plot.aggregated_surv_shap when geom="importance"
Browse files Browse the repository at this point in the history
  • Loading branch information
krzyzinskim committed Feb 23, 2024
1 parent e2a76ed commit 119dd8a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
21 changes: 10 additions & 11 deletions R/plot_surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ plot.surv_shap <- function(x,
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations
#' @param max_vars maximum number of variables to be plotted (least important variables are ignored), by default 7
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). If `geom = "importance"`, the first color will be used for the barplot, the rest for the lines.
#'
#' @return An object of the class `ggplot`.
#'
Expand Down Expand Up @@ -229,12 +229,19 @@ plot_shap_global_importance <- function(x,
x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x)))
x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x)))

if (!is.null(colors)){
bar_color <- colors[1]
colors <- colors[-1]
} else {
bar_color <- "#46bac2"
}

right_plot <- plot.surv_shap(
x = x,
title = NULL,
subtitle = NULL,
max_vars = max_vars,
colors = NULL,
colors = colors,
rug = rug,
rug_colors = rug_colors
) +
Expand All @@ -255,17 +262,9 @@ plot_shap_global_importance <- function(x,
)
}

if (is.null(colors)) {
colors <- c(
low = "#9fe5bd",
mid = "#46bac2",
high = "#371ea3"
)
}

left_plot <- with(long_df, {
ggplot(long_df, aes(x = values, y = reorder(ind, values))) +
geom_col(fill = colors[2]) +
geom_col(fill = bar_color) +
theme_default_survex() +
labs(x = xlab_left, y = "variable") +
theme(axis.title.y = element_blank())
Expand Down
2 changes: 1 addition & 1 deletion man/plot.aggregated_surv_shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 119dd8a

Please sign in to comment.