Skip to content

Commit e1cf737

Browse files
Vishal Batchucopybara-github
authored andcommitted
Update CLU metrics to support dictionary returns for computed metric results to allow for additional flexibility
PiperOrigin-RevId: 745477695
1 parent 7baf9cc commit e1cf737

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

clu/metrics.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
9292

9393

9494
M = TypeVar("M", bound="Metric")
95+
R = TypeVar("R", jnp.ndarray, dict[str, jnp.ndarray])
96+
V = TypeVar("V", clu.values.Value, dict[str, clu.values.Value])
9597

9698

9799
class Metric:
@@ -160,7 +162,7 @@ def merge(self: M, other: M) -> M:
160162
def _reduce_merge(self: M, other: M) -> M:
161163
return self.merge(other)
162164

163-
def compute(self) -> jnp.ndarray:
165+
def compute(self) -> R:
164166
"""Computes final metrics from intermediate values."""
165167
raise NotImplementedError("Must override compute()")
166168

@@ -169,9 +171,13 @@ def empty(cls: type[M]) -> M:
169171
"""Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
170172
raise NotImplementedError("Must override empty()")
171173

172-
def compute_value(self) -> clu.values.Value:
173-
"""Wraps compute() and returns a values.Value."""
174-
return clu.values.Scalar(self.compute())
174+
def compute_value(self) -> V:
175+
"""Wraps compute() and returns a values.Value or dict of values.Value."""
176+
result = self.compute()
177+
if isinstance(result, dict):
178+
return {k: clu.values.Scalar(v) for k, v in result.items()}
179+
else:
180+
return clu.values.Scalar(result)
175181

176182
def reduce(self: M) -> M:
177183
"""Reduces the metric along it first axis by calling `_reduce_merge()`.
@@ -623,22 +629,32 @@ def reduce(self: C) -> C:
623629
})
624630

625631
def compute(self) -> dict[str, jnp.ndarray]:
626-
"""Returns a dictionary mapping metric field name to `Metric.compute()`."""
627-
_check_reduction_counter_ndim(self._reduction_counter)
628-
return {
629-
metric_name: metric.compute()
630-
for metric_name, metric in vars(self).items()
631-
if metric_name != "_reduction_counter"
632-
}
632+
"""Returns a dictionary mapping metrics to their computed values."""
633+
metric_results = {}
634+
for metric_name, metric in vars(self).items():
635+
if metric_name != "_reduction_counter":
636+
metric_result = metric.compute()
637+
if isinstance(metric_result, dict):
638+
metric_results.update(
639+
{f"{metric_name}/{k}": v for k, v in metric_result.items()}
640+
)
641+
else:
642+
metric_results[metric_name] = metric_result
643+
return metric_results
633644

634645
def compute_values(self) -> dict[str, clu.values.Value]:
635-
"""Computes metrics and returns them as clu.values.Value."""
636-
_check_reduction_counter_ndim(self._reduction_counter)
637-
return {
638-
metric_name: metric.compute_value()
639-
for metric_name, metric in vars(self).items()
640-
if metric_name != "_reduction_counter"
641-
}
646+
"""Computes metrics and returns them as clu_values.Value."""
647+
metric_results = {}
648+
for metric_name, metric in vars(self).items():
649+
if metric_name != "_reduction_counter":
650+
metric_result = metric.compute_value()
651+
if isinstance(metric_result, dict):
652+
metric_results.update(
653+
{f"{metric_name}/{k}": v for k, v in metric_result.items()}
654+
)
655+
else:
656+
metric_results[metric_name] = metric_result
657+
return metric_results
642658

643659
def unreplicate(self: C) -> C:
644660
"""Short-hand for `flax.jax_utils.unreplicate(self)`.

0 commit comments

Comments
 (0)