-
-
Notifications
You must be signed in to change notification settings - Fork 86
add fun_avg to ppc_avg functions #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
c96c7d6
ffe723f
c647d69
23865df
27c17d6
2deeca9
1acd544
4293eb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -469,3 +469,53 @@ grid_lines_y <- function(color = "gray50", size = 0.2) { | |
overlay_function <- function(...) { | ||
stat_function(..., inherit.aes = FALSE) | ||
} | ||
|
||
|
||
|
||
# Resolve a function name and store the expression passed in by the user | ||
#' @noRd | ||
#' @param f a function-like thing: a string naming a function, a function | ||
#' object, an anonymous function object, a formula-based lambda, and `NULL`. | ||
#' @param fallback character string providing a fallback function name | ||
#' @return the function named in `f` with an added `"tagged_expr"` attribute | ||
#' containing the expression to represent the function name and an | ||
#' `"is_anonymous_function"` attribute to flag if the expression is a call to | ||
#' `function()`. | ||
as_tagged_function <- function(f = NULL, fallback = "func") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function works as described, just one question: how come sometimes you call rlang functions with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops. Probably an artifact of me using typing |
||
qf <- enquo(f) | ||
f <- eval_tidy(qf) | ||
if (!is.null(attr(f, "tagged_expr"))) return(f) | ||
|
||
f_expr <- quo_get_expr(qf) | ||
f_fn <- f | ||
|
||
if (rlang::is_character(f)) { # f = "mean" | ||
# using sym() on the evaluated `f` means that a variable that names a | ||
# function string `x <- "mean"; as_tagged_function(x)` will be lost | ||
# but that seems okay | ||
f_expr <- rlang::sym(f) | ||
f_fn <- match.fun(f) | ||
} else if (is_null(f)) { # f = NULL | ||
f_fn <- identity | ||
f_expr <- rlang::sym(fallback) | ||
} else if (is_callable(f)) { # f = mean or f = function(x) mean(x) | ||
f_expr <- f_expr # or f = ~mean(.x) | ||
f_fn <- as_function(f) | ||
} | ||
|
||
# Setting attributes on primitive functions is deprecated, so wrap them | ||
# and then tag | ||
if (is_primitive(f_fn)) { | ||
f_fn_old <- f_fn | ||
f_factory <- function(f) { function(...) f(...) } | ||
f_fn <- f_factory(f_fn_old) | ||
} | ||
|
||
attr(f_fn, "tagged_expr") <- f_expr | ||
attr(f_fn, "is_anonymous_function") <- is_call(f_expr, name = "function") || | ||
is_formula(f_expr) | ||
f_fn | ||
} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,6 +11,11 @@ | |||||
#' @template args-group | ||||||
#' @template args-facet_args | ||||||
#' @param ... Currently unused. | ||||||
#' @param stat A function or a string naming a function for computing the | ||||||
#' posterior average. In both cases, the function should take a vector input | ||||||
#' and return a scalar statistic. The function name is displayed in the | ||||||
#' axis-label, and the underlying `$rep_label` for `ppc_scatter_avg_data()` | ||||||
#' includes the function name. Defaults to `"mean"`. | ||||||
#' @param size,alpha Arguments passed to [ggplot2::geom_point()] to control the | ||||||
#' appearance of the points. | ||||||
#' @param ref_line If `TRUE` (the default) a dashed line with intercept 0 and | ||||||
|
@@ -31,10 +36,10 @@ | |||||
#' } | ||||||
#' \item{`ppc_scatter_avg()`}{ | ||||||
#' A single scatterplot of `y` against the average values of `yrep`, i.e., | ||||||
#' the points `(x,y) = (mean(yrep[, n]), y[n])`, where each `yrep[, n]` is | ||||||
#' a vector of length equal to the number of posterior draws. Unlike | ||||||
#' for `ppc_scatter()`, for `ppc_scatter_avg()` `yrep` should contain many | ||||||
#' draws (rows). | ||||||
#' the points `(x,y) = (average(yrep[, n]), y[n])`, where each `yrep[, n]` is | ||||||
#' a vector of length equal to the number of posterior draws and `average()` | ||||||
#' is summary statistic. Unlike for `ppc_scatter()`, for `ppc_scatter_avg()` | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
#' `yrep` should contain many draws (rows). | ||||||
#' } | ||||||
#' \item{`ppc_scatter_avg_grouped()`}{ | ||||||
#' The same as `ppc_scatter_avg()`, but a separate plot is generated for | ||||||
|
@@ -59,6 +64,9 @@ | |||||
#' p1 + lims | ||||||
#' p2 + lims | ||||||
#' | ||||||
#' # "average" function is customizable | ||||||
#' ppc_scatter_avg(y, yrep, stat = "median", ref_line = FALSE) | ||||||
#' | ||||||
#' # for ppc_scatter_avg_grouped the default is to allow the facets | ||||||
#' # to have different x and y axes | ||||||
#' group <- example_group_data() | ||||||
|
@@ -116,16 +124,19 @@ ppc_scatter_avg <- | |||||
function(y, | ||||||
yrep, | ||||||
..., | ||||||
stat = "mean", | ||||||
size = 2.5, | ||||||
alpha = 0.8, | ||||||
ref_line = TRUE) { | ||||||
dots <- list(...) | ||||||
stat <- as_tagged_function({{ stat }}) | ||||||
|
||||||
if (!from_grouped(dots)) { | ||||||
check_ignored_arguments(...) | ||||||
dots$group <- NULL | ||||||
} | ||||||
|
||||||
data <- ppc_scatter_avg_data(y, yrep, group = dots$group) | ||||||
data <- ppc_scatter_avg_data(y, yrep, group = dots$group, stat = stat) | ||||||
if (is.null(dots$group) && nrow(yrep) == 1) { | ||||||
inform( | ||||||
"With only 1 row in 'yrep' ppc_scatter_avg is the same as ppc_scatter." | ||||||
|
@@ -143,7 +154,7 @@ ppc_scatter_avg <- | |||||
# ppd instead of ppc (see comment in ppc_scatter) | ||||||
scale_color_ppd() + | ||||||
scale_fill_ppd() + | ||||||
labs(x = yrep_avg_label(), y = y_label()) + | ||||||
labs(x = yrep_avg_label(stat), y = y_label()) + | ||||||
bayesplot_theme_get() | ||||||
} | ||||||
|
||||||
|
@@ -155,6 +166,7 @@ ppc_scatter_avg_grouped <- | |||||
yrep, | ||||||
group, | ||||||
..., | ||||||
stat = "mean", | ||||||
facet_args = list(), | ||||||
size = 2.5, | ||||||
alpha = 0.8, | ||||||
|
@@ -184,16 +196,19 @@ ppc_scatter_data <- function(y, yrep) { | |||||
|
||||||
#' @rdname PPC-scatterplots | ||||||
#' @export | ||||||
ppc_scatter_avg_data <- function(y, yrep, group = NULL) { | ||||||
ppc_scatter_avg_data <- function(y, yrep, group = NULL, stat = "mean") { | ||||||
y <- validate_y(y) | ||||||
yrep <- validate_predictions(yrep, length(y)) | ||||||
if (!is.null(group)) { | ||||||
group <- validate_group(group, length(y)) | ||||||
} | ||||||
stat <- as_tagged_function({{ stat }}) | ||||||
|
||||||
data <- ppc_scatter_data(y = y, yrep = t(colMeans(yrep))) | ||||||
data <- ppc_scatter_data(y = y, yrep = t(apply(yrep, 2, FUN = stat))) | ||||||
data$rep_id <- NA_integer_ | ||||||
levels(data$rep_label) <- "mean(italic(y)[rep]))" | ||||||
levels(data$rep_label) <- yrep_avg_label(stat) |> | ||||||
as.expression() |> | ||||||
as.character() | ||||||
Comment on lines
+209
to
+211
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fine by me to use the base R pipe, but we do import |
||||||
|
||||||
if (!is.null(group)) { | ||||||
data <- tibble::add_column(data, | ||||||
|
@@ -206,7 +221,19 @@ ppc_scatter_avg_data <- function(y, yrep, group = NULL) { | |||||
} | ||||||
|
||||||
# internal ---------------------------------------------------------------- | ||||||
yrep_avg_label <- function() expression(paste("Average ", italic(y)[rep])) | ||||||
|
||||||
yrep_avg_label <- function(stat = NULL) { | ||||||
stat <- as_tagged_function({{ stat }}, fallback = "stat") | ||||||
e <- attr(stat, "tagged_expr") | ||||||
if (attr(stat, "is_anonymous_function")) { | ||||||
e <- sym("stat") | ||||||
} | ||||||
de <- deparse1(e) | ||||||
# dummy globals to pass R check for globals | ||||||
italic <- sym("italic") | ||||||
y <- sym("y") | ||||||
expr(paste((!!de))*(italic(y)[rep])) | ||||||
} | ||||||
|
||||||
scatter_aes <- function(...) { | ||||||
aes(x = .data$value, y = .data$y_obs, ...) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think at this point it's been out long enough that bumping the required R version is fine.