Skip to content

Commit 574d4c9

Browse files
jpuigcervercopybara-github
authored andcommitted
Allow paramter_overview to work with JAX ShapeDtypeStruct.
PiperOrigin-RevId: 595317017
1 parent 8a01f21 commit 574d4c9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

clu/parameter_overview.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _count_parameters(params: _ParamsContainer) -> int:
8383
def _parameters_size(params: _ParamsContainer) -> int:
8484
"""Returns total size (bytes) for the module or parameter dictionary."""
8585
params = flatten_dict(params)
86-
return sum(v.nbytes for v in params.values())
86+
return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values())
8787

8888

8989
def count_parameters(params: _ParamsContainer) -> int:

0 commit comments

Comments
 (0)