Skip to content

add fun_avg to ppc_avg functions #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Title: Plotting for Bayesian Models
Version: 1.12.0.9000
Date: 2025-04-09
Authors@R: c(person("Jonah", "Gabry", role = c("aut", "cre"), email = "jsg2201@columbia.edu"),
person("Tristan", "Mahr", role = "aut"),
person("Tristan", "Mahr", role = "aut", comment = c(ORCID = "0000-0002-8890-5116")),
person("Paul-Christian", "Bürkner", role = "ctb"),
person("Martin", "Modrák", role = "ctb"),
person("Malcolm", "Barrett", role = "ctb"),
Expand All @@ -26,7 +26,7 @@ URL: https://mc-stan.org/bayesplot/
BugReports: https://github.com/stan-dev/bayesplot/issues/
SystemRequirements: pandoc (>= 1.12.3), pandoc-citeproc
Depends:
R (>= 3.1.0)
R (>= 4.1.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think at this point it's been out long enough that bumping the required R version is fine.

Imports:
dplyr (>= 0.8.0),
ggplot2 (>= 3.4.0),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* Add possibility for left-truncation to `ppc_km_overlay()` and `ppc_km_overlay_grouped()` by @Sakuski
* Added `ppc_loo_pit_ecdf()` by @TeemuSailynoja
* PPC "avg" functions (`ppc_scatter_avg()`, `ppc_error_scatter_avg()`, etc.) gain a `stat` argument to set the averaging function. (Suggestion of #348, @kruschke).
* `ppc_error_scatter_avg_vs_x(x = some_expression)` labels the *x* axis with `some_expression`.

# bayesplot 1.12.0

Expand Down
50 changes: 50 additions & 0 deletions R/bayesplot-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,53 @@ grid_lines_y <- function(color = "gray50", size = 0.2) {
overlay_function <- function(...) {
stat_function(..., inherit.aes = FALSE)
}



# Resolve a function name and store the expression passed in by the user
#' @noRd
#' @param f a function-like thing: a string naming a function, a function
#' object, an anonymous function object, a formula-based lambda, and `NULL`.
#' @param fallback character string providing a fallback function name
#' @return the function named in `f` with an added `"tagged_expr"` attribute
#' containing the expression to represent the function name and an
#' `"is_anonymous_function"` attribute to flag if the expression is a call to
#' `function()`.
as_tagged_function <- function(f = NULL, fallback = "func") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function works as described, just one question: how come sometimes you call rlang functions with rlang::foo() and other times just foo()? I think either way works since we have @import rlang in bayesplot-package.R, just curious if your choices were intentional.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Probably an artifact of me using typing rlang::[TAB] to use autocomplete to find the right rlang function.

qf <- enquo(f)
f <- eval_tidy(qf)
if (!is.null(attr(f, "tagged_expr"))) return(f)

f_expr <- quo_get_expr(qf)
f_fn <- f

if (rlang::is_character(f)) { # f = "mean"
# using sym() on the evaluated `f` means that a variable that names a
# function string `x <- "mean"; as_tagged_function(x)` will be lost
# but that seems okay
f_expr <- rlang::sym(f)
f_fn <- match.fun(f)
} else if (is_null(f)) { # f = NULL
f_fn <- identity
f_expr <- rlang::sym(fallback)
} else if (is_callable(f)) { # f = mean or f = function(x) mean(x)
f_expr <- f_expr # or f = ~mean(.x)
f_fn <- as_function(f)
}

# Setting attributes on primitive functions is deprecated, so wrap them
# and then tag
if (is_primitive(f_fn)) {
f_fn_old <- f_fn
f_factory <- function(f) { function(...) f(...) }
f_fn <- f_factory(f_fn_old)
}

attr(f_fn, "tagged_expr") <- f_expr
attr(f_fn, "is_anonymous_function") <- is_call(f_expr, name = "function") ||
is_formula(f_expr)
f_fn
}



3 changes: 1 addition & 2 deletions R/bayesplot-package.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#' **bayesplot**: Plotting for Bayesian Models
#'
#' @docType package
#' @name bayesplot-package
#' @aliases bayesplot
#'
Expand Down Expand Up @@ -96,7 +95,7 @@
#' ppd_hist(ypred[1:8, ])
#' }
#'
NULL
"_PACKAGE"


# internal ----------------------------------------------------------------
Expand Down
85 changes: 58 additions & 27 deletions R/ppc-errors.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#' @template args-group
#' @template args-facet_args
#' @param ... Currently unused.
#' @param stat A function or a string naming a function for computing the
#' posterior average. In both cases, the function should take a vector input and
#' return a scalar statistic. The function name is displayed in the axis-label.
#' Defaults to `"mean"`.
#' @param size,alpha For scatterplots, arguments passed to
#' [ggplot2::geom_point()] to control the appearance of the points. For the
#' binned error plot, arguments controlling the size of the outline and
Expand Down Expand Up @@ -209,21 +213,26 @@ ppc_error_scatter_avg <-
function(y,
yrep,
...,
stat = "mean",
size = 2.5,
alpha = 0.8) {
check_ignored_arguments(...)

y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
errors <- compute_errors(y, yrep)

stat <- as_tagged_function({{ stat }})

ppc_scatter_avg(
y = y,
yrep = errors,
size = size,
alpha = alpha,
ref_line = FALSE
ref_line = FALSE,
stat = stat
) +
labs(x = error_avg_label(), y = y_label())
labs(x = error_avg_label(stat), y = y_label())
}


Expand All @@ -234,13 +243,16 @@ ppc_error_scatter_avg_grouped <-
yrep,
group,
...,
stat = "mean",
facet_args = list(),
size = 2.5,
alpha = 0.8) {
check_ignored_arguments(...)

y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
stat <- as_tagged_function({{ stat }})

errors <- compute_errors(y, yrep)
ppc_scatter_avg_grouped(
y = y,
Expand All @@ -249,9 +261,10 @@ ppc_error_scatter_avg_grouped <-
size = size,
alpha = alpha,
facet_args = facet_args,
ref_line = FALSE
ref_line = FALSE,
stat = stat
) +
labs(x = error_avg_label(), y = y_label())
labs(x = error_avg_label(stat), y = y_label())
}


Expand All @@ -260,29 +273,37 @@ ppc_error_scatter_avg_grouped <-
#' @param x A numeric vector the same length as `y` to use as the x-axis
#' variable.
#'
ppc_error_scatter_avg_vs_x <-
function(y,
yrep,
x,
...,
size = 2.5,
alpha = 0.8) {
check_ignored_arguments(...)
ppc_error_scatter_avg_vs_x <- function(
y,
yrep,
x,
...,
stat = "mean",
size = 2.5,
alpha = 0.8
) {
check_ignored_arguments(...)

y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
x <- validate_x(x, y)
errors <- compute_errors(y, yrep)
ppc_scatter_avg(
y = x,
yrep = errors,
size = size,
alpha = alpha,
ref_line = FALSE
y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
qx <- enquo(x)
x <- validate_x(x, y)
stat <- as_tagged_function({{ stat }})
errors <- compute_errors(y, yrep)
ppc_scatter_avg(
y = x,
yrep = errors,
size = size,
alpha = alpha,
ref_line = FALSE,
stat = stat
) +
labs(
x = error_avg_label(stat),
y = as_label((qx))
) +
labs(x = error_avg_label(), y = expression(italic(x))) +
coord_flip()
}
coord_flip()
}


#' @rdname PPC-errors
Expand Down Expand Up @@ -414,8 +435,18 @@ error_hist_facets <-
error_label <- function() {
expression(italic(y) - italic(y)[rep])
}
error_avg_label <- function() {
expression(paste("Average ", italic(y) - italic(y)[rep]))

error_avg_label <- function(stat = NULL) {
stat <- as_tagged_function({{ stat }}, fallback = "stat")
e <- attr(stat, "tagged_expr")
if (attr(stat, "is_anonymous_function")) {
e <- sym("stat")
}
de <- deparse1(e)
# dummy globals to pass R check for globals
italic <- sym("italic")
y <- sym("y")
expr(paste((!!de))*(italic(y) - italic(y)[rep]))
}


Expand Down
47 changes: 37 additions & 10 deletions R/ppc-scatterplots.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
#' @template args-group
#' @template args-facet_args
#' @param ... Currently unused.
#' @param stat A function or a string naming a function for computing the
#' posterior average. In both cases, the function should take a vector input
#' and return a scalar statistic. The function name is displayed in the
#' axis-label, and the underlying `$rep_label` for `ppc_scatter_avg_data()`
#' includes the function name. Defaults to `"mean"`.
#' @param size,alpha Arguments passed to [ggplot2::geom_point()] to control the
#' appearance of the points.
#' @param ref_line If `TRUE` (the default) a dashed line with intercept 0 and
Expand All @@ -31,10 +36,10 @@
#' }
#' \item{`ppc_scatter_avg()`}{
#' A single scatterplot of `y` against the average values of `yrep`, i.e.,
#' the points `(x,y) = (mean(yrep[, n]), y[n])`, where each `yrep[, n]` is
#' a vector of length equal to the number of posterior draws. Unlike
#' for `ppc_scatter()`, for `ppc_scatter_avg()` `yrep` should contain many
#' draws (rows).
#' the points `(x,y) = (average(yrep[, n]), y[n])`, where each `yrep[, n]` is
#' a vector of length equal to the number of posterior draws and `average()`
#' is summary statistic. Unlike for `ppc_scatter()`, for `ppc_scatter_avg()`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' is summary statistic. Unlike for `ppc_scatter()`, for `ppc_scatter_avg()`
#' is a summary statistic. Unlike for `ppc_scatter()`, for `ppc_scatter_avg()`

#' `yrep` should contain many draws (rows).
#' }
#' \item{`ppc_scatter_avg_grouped()`}{
#' The same as `ppc_scatter_avg()`, but a separate plot is generated for
Expand All @@ -59,6 +64,9 @@
#' p1 + lims
#' p2 + lims
#'
#' # "average" function is customizable
#' ppc_scatter_avg(y, yrep, stat = "median", ref_line = FALSE)
#'
#' # for ppc_scatter_avg_grouped the default is to allow the facets
#' # to have different x and y axes
#' group <- example_group_data()
Expand Down Expand Up @@ -116,16 +124,19 @@ ppc_scatter_avg <-
function(y,
yrep,
...,
stat = "mean",
size = 2.5,
alpha = 0.8,
ref_line = TRUE) {
dots <- list(...)
stat <- as_tagged_function({{ stat }})

if (!from_grouped(dots)) {
check_ignored_arguments(...)
dots$group <- NULL
}

data <- ppc_scatter_avg_data(y, yrep, group = dots$group)
data <- ppc_scatter_avg_data(y, yrep, group = dots$group, stat = stat)
if (is.null(dots$group) && nrow(yrep) == 1) {
inform(
"With only 1 row in 'yrep' ppc_scatter_avg is the same as ppc_scatter."
Expand All @@ -143,7 +154,7 @@ ppc_scatter_avg <-
# ppd instead of ppc (see comment in ppc_scatter)
scale_color_ppd() +
scale_fill_ppd() +
labs(x = yrep_avg_label(), y = y_label()) +
labs(x = yrep_avg_label(stat), y = y_label()) +
bayesplot_theme_get()
}

Expand All @@ -155,6 +166,7 @@ ppc_scatter_avg_grouped <-
yrep,
group,
...,
stat = "mean",
facet_args = list(),
size = 2.5,
alpha = 0.8,
Expand Down Expand Up @@ -184,16 +196,19 @@ ppc_scatter_data <- function(y, yrep) {

#' @rdname PPC-scatterplots
#' @export
ppc_scatter_avg_data <- function(y, yrep, group = NULL) {
ppc_scatter_avg_data <- function(y, yrep, group = NULL, stat = "mean") {
y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
if (!is.null(group)) {
group <- validate_group(group, length(y))
}
stat <- as_tagged_function({{ stat }})

data <- ppc_scatter_data(y = y, yrep = t(colMeans(yrep)))
data <- ppc_scatter_data(y = y, yrep = t(apply(yrep, 2, FUN = stat)))
data$rep_id <- NA_integer_
levels(data$rep_label) <- "mean(italic(y)[rep]))"
levels(data$rep_label) <- yrep_avg_label(stat) |>
as.expression() |>
as.character()
Comment on lines +209 to +211
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine by me to use the base R pipe, but we do import %>% from dplyr so could also use that.


if (!is.null(group)) {
data <- tibble::add_column(data,
Expand All @@ -206,7 +221,19 @@ ppc_scatter_avg_data <- function(y, yrep, group = NULL) {
}

# internal ----------------------------------------------------------------
yrep_avg_label <- function() expression(paste("Average ", italic(y)[rep]))

yrep_avg_label <- function(stat = NULL) {
stat <- as_tagged_function({{ stat }}, fallback = "stat")
e <- attr(stat, "tagged_expr")
if (attr(stat, "is_anonymous_function")) {
e <- sym("stat")
}
de <- deparse1(e)
# dummy globals to pass R check for globals
italic <- sym("italic")
y <- sym("y")
expr(paste((!!de))*(italic(y)[rep]))
}

scatter_aes <- function(...) {
aes(x = .data$value, y = .data$y_obs, ...)
Expand Down
Loading