@@ -92,6 +92,8 @@ def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
92
92
93
93
94
94
M = TypeVar ("M" , bound = "Metric" )
95
+ R = TypeVar ("R" , jnp .ndarray , dict [str , jnp .ndarray ])
96
+ V = TypeVar ("V" , clu .values .Value , dict [str , clu .values .Value ])
95
97
96
98
97
99
class Metric :
@@ -160,7 +162,7 @@ def merge(self: M, other: M) -> M:
160
162
def _reduce_merge (self : M , other : M ) -> M :
161
163
return self .merge (other )
162
164
163
- def compute (self ) -> jnp . ndarray :
165
+ def compute (self ) -> R :
164
166
"""Computes final metrics from intermediate values."""
165
167
raise NotImplementedError ("Must override compute()" )
166
168
@@ -169,9 +171,13 @@ def empty(cls: type[M]) -> M:
169
171
"""Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
170
172
raise NotImplementedError ("Must override empty()" )
171
173
172
- def compute_value (self ) -> clu .values .Value :
173
- """Wraps compute() and returns a values.Value."""
174
- return clu .values .Scalar (self .compute ())
174
+ def compute_value (self ) -> V :
175
+ """Wraps compute() and returns a values.Value or dict of values.Value."""
176
+ result = self .compute ()
177
+ if isinstance (result , dict ):
178
+ return {k : clu .values .Scalar (v ) for k , v in result .items ()}
179
+ else :
180
+ return clu .values .Scalar (result )
175
181
176
182
def reduce (self : M ) -> M :
177
183
"""Reduces the metric along it first axis by calling `_reduce_merge()`.
@@ -623,22 +629,32 @@ def reduce(self: C) -> C:
623
629
})
624
630
625
631
def compute (self ) -> dict [str , jnp .ndarray ]:
626
- """Returns a dictionary mapping metric field name to `Metric.compute()`."""
627
- _check_reduction_counter_ndim (self ._reduction_counter )
628
- return {
629
- metric_name : metric .compute ()
630
- for metric_name , metric in vars (self ).items ()
631
- if metric_name != "_reduction_counter"
632
- }
632
+ """Returns a dictionary mapping metrics to their computed values."""
633
+ metric_results = {}
634
+ for metric_name , metric in vars (self ).items ():
635
+ if metric_name != "_reduction_counter" :
636
+ metric_result = metric .compute ()
637
+ if isinstance (metric_result , dict ):
638
+ metric_results .update (
639
+ {f"{ metric_name } /{ k } " : v for k , v in metric_result .items ()}
640
+ )
641
+ else :
642
+ metric_results [metric_name ] = metric_result
643
+ return metric_results
633
644
634
645
def compute_values (self ) -> dict [str , clu .values .Value ]:
635
- """Computes metrics and returns them as clu.values.Value."""
636
- _check_reduction_counter_ndim (self ._reduction_counter )
637
- return {
638
- metric_name : metric .compute_value ()
639
- for metric_name , metric in vars (self ).items ()
640
- if metric_name != "_reduction_counter"
641
- }
646
+ """Computes metrics and returns them as clu_values.Value."""
647
+ metric_results = {}
648
+ for metric_name , metric in vars (self ).items ():
649
+ if metric_name != "_reduction_counter" :
650
+ metric_result = metric .compute_value ()
651
+ if isinstance (metric_result , dict ):
652
+ metric_results .update (
653
+ {f"{ metric_name } /{ k } " : v for k , v in metric_result .items ()}
654
+ )
655
+ else :
656
+ metric_results [metric_name ] = metric_result
657
+ return metric_results
642
658
643
659
def unreplicate (self : C ) -> C :
644
660
"""Short-hand for `flax.jax_utils.unreplicate(self)`.
0 commit comments