Skip to content

Commit 1368e52

Browse files
lzlarrylicopybara-github
authored andcommitted
Unify the handling of tensor-valued metrics of Std with Average. In particular, this removes the ndim restriction on the Std metric.
- Average does not have such a restriction. This is not needed. - Inside pmap, if per device the output is per sample loss values (ndim=1), then the result of lax.all_gather is tensor with ndim=2 with the added dimension from the devices. So `Average` would work but `Std` would fail. PiperOrigin-RevId: 603275875
1 parent f30bc44 commit 1368e52

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+117
-71
lines changed

clu/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/asynclib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/asynclib_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/checkpoint_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/data/dataset_iterator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/data/dataset_iterator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/deterministic_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/deterministic_data_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/internal/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/internal/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/internal/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/async_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/async_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/logging_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/logging_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/multi_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/multi_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/summary_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/tf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/tf/summary_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/tf/summary_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/torch_tensorboard_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/torch_tensorboard_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metric_writers/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/metrics.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -742,6 +742,27 @@ def compute(self) -> Any:
742742
return self.value
743743

744744

745+
def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None):
746+
"""Checks and broadcasts mask for aggregating values."""
747+
if values.ndim == 0:
748+
values = values[None]
749+
if mask is None:
750+
mask = jnp.ones_like(values)
751+
# Leading dimensions of mask and values must match.
752+
if mask.shape[0] != values.shape[0]:
753+
raise ValueError(
754+
"Argument `mask` must have the same leading dimension as `values`. "
755+
f"Received mask of dimension {mask.shape} "
756+
f"and values of dimension {values.shape}."
757+
)
758+
# Broadcast mask to the same number of dimensions as values.
759+
if mask.ndim < values.ndim:
760+
mask = jnp.expand_dims(mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
761+
mask = mask.astype(bool)
762+
utils.check_param(mask, dtype=bool, ndim=values.ndim)
763+
return values, mask
764+
765+
745766
@flax.struct.dataclass
746767
class Average(Metric):
747768
"""Computes the average of a scalar or a batch of tensors.
@@ -769,26 +790,14 @@ def empty(cls) -> Average:
769790
def from_model_output(
770791
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
771792
) -> Average:
772-
if values.ndim == 0:
773-
values = values[None]
774-
if mask is None:
775-
mask = jnp.ones_like(values)
776-
# Leading dimensions of mask and values must match.
777-
if mask.shape[0] != values.shape[0]:
778-
raise ValueError(
779-
f"Argument `mask` must have the same leading dimension as `values`. "
780-
f"Received mask of dimension {mask.shape} "
781-
f"and values of dimension {values.shape}.")
782-
# Broadcast mask to the same number of dimensions as values.
783-
if mask.ndim < values.ndim:
784-
mask = jnp.expand_dims(
785-
mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
786-
mask = mask.astype(bool)
787-
utils.check_param(mask, dtype=bool, ndim=values.ndim)
793+
values, mask = _broadcast_masks(values, mask)
788794
return cls(
789795
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
790-
count=jnp.where(mask, jnp.ones_like(values, dtype=jnp.int32),
791-
jnp.zeros_like(values, dtype=jnp.int32)).sum(),
796+
count=jnp.where(
797+
mask,
798+
jnp.ones_like(values, dtype=jnp.int32),
799+
jnp.zeros_like(values, dtype=jnp.int32),
800+
).sum(),
792801
)
793802

794803
def merge(self, other: Average) -> Average:
@@ -804,9 +813,10 @@ def compute(self) -> Any:
804813

805814
@flax.struct.dataclass
806815
class Std(Metric):
807-
"""Computes the standard deviation of a scalar or a batch of scalars.
816+
"""Computes the standard deviation of a scalar or a batch of tensors.
808817
809-
See also documentation of `Metric`.
818+
The result is always a single scalar. See also the documentation of `Average`
819+
for the mask handling.
810820
"""
811821

812822
total: jnp.ndarray
@@ -824,17 +834,15 @@ def empty(cls) -> Std:
824834
def from_model_output(
825835
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
826836
) -> Std:
827-
if values.ndim == 0:
828-
values = values[None]
829-
utils.check_param(values, ndim=1)
830-
if mask is None:
831-
mask = jnp.ones(values.shape[0], dtype=jnp.int32)
837+
values, mask = _broadcast_masks(values, mask)
832838
return cls(
833839
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
834-
sum_of_squares=jnp.where(
835-
mask, values**2, jnp.zeros_like(values)
840+
sum_of_squares=jnp.where(mask, values**2, jnp.zeros_like(values)).sum(),
841+
count=jnp.where(
842+
mask,
843+
jnp.ones_like(values, dtype=jnp.int32),
844+
jnp.zeros_like(values, dtype=jnp.int32),
836845
).sum(),
837-
count=mask.sum(),
838846
)
839847

840848
def merge(self, other: Std) -> Std:

clu/metrics_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -526,6 +526,44 @@ def merge_collection(model_output, collection):
526526
# If it does have a weak type the second call will cause a re-trace
527527
collection = merge_collection(model_output, collection)
528528

529+
@parameterized.product(
530+
value_mask_pair=[
531+
(1, None),
532+
([1, 2, 3], None),
533+
([1, 2, 3], [True, True, False]),
534+
([[1, 2], [2, 3], [3, 4]], None),
535+
([[1, 2], [2, 3], [3, 4]], [False, True, True]),
536+
(
537+
[[1, 2], [2, 3], [3, 4]],
538+
[[False, True], [True, True], [True, True]],
539+
),
540+
([[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], None),
541+
(
542+
[[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]],
543+
[False, True, True],
544+
),
545+
],
546+
metric_np_equivalent_pair=[
547+
(metrics.Average, jnp.mean),
548+
(metrics.Std, jnp.std),
549+
],
550+
)
551+
def test_tensor_aggregation_metrics_with_masks(
552+
self, value_mask_pair, metric_np_equivalent_pair
553+
):
554+
values, mask = value_mask_pair
555+
metric, np_equivalent = metric_np_equivalent_pair
556+
values = jnp.asarray(values)
557+
masked = values
558+
if mask is not None:
559+
mask = jnp.asarray(mask)
560+
masked = values[mask]
561+
expected = np_equivalent(masked)
562+
563+
result = metric.from_model_output(values, mask=mask).compute()
564+
# The lower precision is needed for the lower precision jitted version.
565+
chex.assert_trees_all_close(result, expected, atol=1e-4, rtol=1e-4)
566+
529567

530568
if __name__ == "__main__":
531569
absltest.main()

clu/parameter_overview.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/parameter_overview_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/periodic_actions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/periodic_actions_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/platform/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/platform/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/platform/local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/preprocess_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/preprocess_spec_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/profiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

clu/values.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The CLU Authors.
1+
# Copyright 2024 The CLU Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)