diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 718aebb3..72605848 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.2 +++++ +* :pr:`165`: support for task text-to-image * :pr:`162`: improves graphs rendering for historical data 0.7.1 diff --git a/_doc/api/tasks/index.rst b/_doc/api/tasks/index.rst index fb606d1a..de27f4cd 100644 --- a/_doc/api/tasks/index.rst +++ b/_doc/api/tasks/index.rst @@ -46,6 +46,7 @@ Or: summarization text_classification text_generation + text_to_image text2text_generation zero_shot_image_classification diff --git a/_doc/api/tasks/text_to_image.rst b/_doc/api/tasks/text_to_image.rst new file mode 100644 index 00000000..dcc1936e --- /dev/null +++ b/_doc/api/tasks/text_to_image.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.tasks.text_to_image +=================================== + +.. automodule:: onnx_diagnostic.tasks.text_to_image + :members: + :no-undoc-members: diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index 4493ffec..23f136a5 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -8,6 +8,7 @@ onnx_diagnostic.torch_export_patches eval/index onnx_export_errors onnx_export_serialization + onnx_export_serialization_impl patches/index patch_expressions patch_inputs diff --git a/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst b/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst new file mode 100644 index 00000000..22a94d27 --- /dev/null +++ b/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl +=================================================================== + +.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl + :members: + :no-undoc-members: diff --git a/_doc/conf.py b/_doc/conf.py index e8f010d5..c4b312c1 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -90,6 +90,8 @@ def linkcode_resolve(domain, info): "https://sdpython.github.io/doc/experimental-experiment/dev/", None, ), + # Not a sphinx documentation + # "diffusers": ("https://huggingface.co/docs/diffusers/index", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable", None), "onnx": ("https://onnx.ai/onnx/", None), @@ -104,6 +106,8 @@ def linkcode_resolve(domain, info): "sklearn": ("https://scikit-learn.org/stable/", None), "skl2onnx": ("https://onnx.ai/sklearn-onnx/", None), "torch": ("https://pytorch.org/docs/main/", None), + # Not a sphinx documentation + # "transformers": ("https://huggingface.co/docs/transformers/index", None), } # Check intersphinx reference targets exist @@ -116,6 +120,7 @@ def linkcode_resolve(domain, info): ("py:class", "True"), ("py:class", "Argument"), ("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"), + ("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"), ("py:class", "ModelProto"), ("py:class", "Model"), ("py:class", "Module"), diff --git a/_doc/patches.rst b/_doc/patches.rst index 7b6cf622..bb589431 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -113,7 +113,7 @@ Here is the list of supported caches: import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print("\n".join(sorted(p.serialization_functions()))) + print("\n".join(sorted(t.__name__ for t in p.serialization_functions()))) .. 
_l-control-flow-rewriting: diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst index f35c066b..1a3ac8b7 100644 --- a/_doc/status/patches_coverage.rst +++ b/_doc/status/patches_coverage.rst @@ -14,7 +14,7 @@ The following code shows the list of serialized classes in transformers. import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print('\n'.join(sorted(p.serialization_functions()))) + print('\n'.join(sorted(t.__name__ for t in p.serialization_functions()))) Patched Classes =============== diff --git a/_unittests/ut_tasks/test_tasks_image_classification.py b/_unittests/ut_tasks/test_tasks_image_classification.py index e9856d10..0bf6b97c 100644 --- a/_unittests/ut_tasks/test_tasks_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_image_classification.py @@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -class TestTasks(ExtTestCase): +class TestTasksImageClassification(ExtTestCase): @hide_stdout() def test_image_classification(self): mid = "hf-internal-testing/tiny-random-BeitForImageClassification" diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index 173d628c..b51db78f 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -11,7 +11,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -class TestTasks(ExtTestCase): +class TestTasksImageTextToText(ExtTestCase): @hide_stdout() @requires_transformers("4.52") @requires_torch("2.7.99") diff --git a/_unittests/ut_tasks/test_tasks_object_detection.py b/_unittests/ut_tasks/test_tasks_object_detection.py index 2429d5a9..7e21f7df 100644 --- a/_unittests/ut_tasks/test_tasks_object_detection.py +++ b/_unittests/ut_tasks/test_tasks_object_detection.py @@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -class TestTasks(ExtTestCase): +class TestTasksObjectDetection(ExtTestCase): @hide_stdout() def test_object_detection(self): mid = "hustvl/yolos-tiny" diff --git a/_unittests/ut_tasks/test_tasks_text_to_image.py b/_unittests/ut_tasks/test_tasks_text_to_image.py new file mode 100644 index 00000000..57d6609b --- /dev/null +++ b/_unittests/ut_tasks/test_tasks_text_to_image.py @@ -0,0 +1,35 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + requires_transformers, + requires_torch, +) +from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str + + +class TestTasksTextToImage(ExtTestCase): + @hide_stdout() + @requires_transformers("4.52") + @requires_torch("2.7.99") + def test_text_to_image(self): + mid = "diffusers/tiny-torch-full-checker" + data = get_untrained_model_with_inputs( + mid, verbose=1, add_second_input=True, subfolder="unet" + ) + self.assertEqual(data["task"], "text-to-image") + self.assertIn((data["size"], data["n_weights"]), [(5708048, 1427012)]) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + model(**inputs) + model(**data["inputs2"]) + with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1): + torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False + ) + + +if __name__ == "__main__":
unittest.main(verbosity=2) diff --git a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py index bbfb34ab..7381eac2 100644 --- a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py @@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -class TestTasks(ExtTestCase): +class TestTasksZeroShotImageClassification(ExtTestCase): @requires_torch("2.7.99") @hide_stdout() def test_zero_shot_image_classification(self): diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 6d6df11f..69522267 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -569,6 +569,27 @@ def test_object_detection(self): f"{round(score.item(), 3)} at location {box}" ) + @never_test() + def test_text_to_image(self): + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k test_text_to_image + import torch + from diffusers import StableDiffusionPipeline + + model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to( + "cuda" + ) + + prompt = "a photo of an astronaut riding a horse on mars and on jupyter" + print() + with steal_forward(pipe.unet, with_min_max=True): + image = pipe(prompt).images[0] + print("-- output", self.string_type(image, with_shape=True, with_min_max=True)) + # stolen forward for class UNet2DConditionModel -- iteration 44 + # sample=T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184] + # time_step=T7s=101 + # encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257] + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py index 851627a5..fc7beaa4 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -1,7 +1,12 @@ import unittest import torch from transformers.modeling_outputs import BaseModelOutput -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + ignore_warnings, + requires_torch, + requires_diffusers, +) from onnx_diagnostic.helpers.cache_helper import ( make_encoder_decoder_cache, make_dynamic_cache, @@ -212,6 +217,63 @@ def test_sliding_window_cache_flatten(self): self.string_type(cache2, with_shape=True, with_min_max=True), ) + @ignore_warnings(UserWarning) + @requires_diffusers("0.30") + def test_unet_2d_condition_output(self): + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + + bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) + self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput") + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual( + "UNet2DConditionOutput(sample:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + + with torch_export_patches(): + # internal function + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput") + self.assertEqualAny([bo], bo2) + self.assertEqual( + "UNet2DConditionOutput(sample:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + + # serialization + flat, _spec = 
torch.utils._pytree.tree_flatten(bo) + self.assertEqual( + "#1[T1s4x4x4]", + self.string_type(flat, with_shape=True), + ) + bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(bo, with_shape=True, with_min_max=True), + self.string_type(bo2, with_shape=True, with_min_max=True), + ) + + # flatten_unflatten + flat, _spec = torch.utils._pytree.tree_flatten(bo) + unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) + self.assertIsInstance(unflat, dict) + self.assertEqual(list(unflat), ["sample"]) + + # export + class Model(torch.nn.Module): + def forward(self, cache): + return cache.sample[0] + + bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) + model = Model() + model(bo) + DYN = torch.export.Dim.DYNAMIC + ds = [{0: DYN}] + + with torch_export_patches(): + torch.export.export(model, (bo,), dynamic_shapes=(ds,)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py index c22340ab..bd4f987d 100644 --- a/onnx_diagnostic/helpers/config_helper.py +++ b/onnx_diagnostic/helpers/config_helper.py @@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]): else: update_config(getattr(config, k), v) continue - setattr(config, k, v) + if type(config) is dict: + config[k] = v + else: + setattr(config, k, v) def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None): @@ -66,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None): raise AssertionError(f"Unable to find any of these {atts!r} in {config}") +def pick(config, name: str, default_value: Any) -> Any: + """ + Returns the value of an attribute if config has it, + otherwise the default value. + """ + if not config: + return default_value + if type(config) is dict: + return config.get(name, default_value) + return getattr(config, name, default_value) + + @functools.cache def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]: """ diff --git a/onnx_diagnostic/tasks/__init__.py b/onnx_diagnostic/tasks/__init__.py index 405b9c80..1f14cc9e 100644 --- a/onnx_diagnostic/tasks/__init__.py +++ b/onnx_diagnostic/tasks/__init__.py @@ -11,6 +11,7 @@ summarization, text_classification, text_generation, + text_to_image, text2text_generation, zero_shot_image_classification, ) @@ -27,6 +28,7 @@ summarization, text_classification, text_generation, + text_to_image, text2text_generation, zero_shot_image_classification, ] diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py new file mode 100644 index 00000000..983d9bec --- /dev/null +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -0,0 +1,91 @@ +from typing import Any, Callable, Dict, Optional, Tuple +import torch +from ..helpers.config_helper import update_config, check_hasattr, pick + +__TASK__ = "text-to-image" + + +def reduce_model_config(config: Any) -> Dict[str, Any]: + """Reduces the model size.""" + check_hasattr(config, "sample_size", "cross_attention_dim") + kwargs = dict( + sample_size=min(config["sample_size"], 32), + cross_attention_dim=min(config["cross_attention_dim"], 64), + ) + update_config(config, kwargs) + return kwargs + + +def get_inputs( + model: torch.nn.Module, + config: Optional[Any], + batch_size: int, + sequence_length: int, + cache_length: int, + in_channels: int, + sample_size: int, + cross_attention_dim: int, + add_second_input: bool = False, + **kwargs, # unused +): + """ + Generates inputs for task
``text-to-image``. + Example: + + :: + + sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184] + timestep:T7s=101 + encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257] + """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." + batch = "batch" + shapes = { + "sample": {0: batch}, + "timestep": {}, + "encoder_hidden_states": {0: batch, 1: "encoder_length"}, + } + inputs = dict( + sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to( + torch.float32 + ), + timestep=torch.tensor([101], dtype=torch.int64), + # cache_length drives the dynamic dimension "encoder_length" and is + # incremented for the second set of inputs below. + encoder_hidden_states=torch.randn( + (batch_size, cache_length, cross_attention_dim) + ).to(torch.float32), + ) + res = dict(inputs=inputs, dynamic_shapes=shapes) + if add_second_input: + res["inputs2"] = get_inputs( + model=model, + config=config, + batch_size=batch_size + 1, + sequence_length=sequence_length, + cache_length=cache_length + 1, + in_channels=in_channels, + sample_size=sample_size, + cross_attention_dim=cross_attention_dim, + **kwargs, + )["inputs"] + return res + + +def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: + """ + Inputs kwargs. + + If the configuration is None, the function selects typical dimensions. + """ + if config is not None: + check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels") + kwargs = dict( + batch_size=2, + sequence_length=pick(config, "in_channels", 4), + cache_length=77, + in_channels=pick(config, "in_channels", 4), + sample_size=pick(config, "sample_size", 32), + cross_attention_dim=pick(config, "cross_attention_dim", 64), + ) + return kwargs, get_inputs diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index ebbed95f..4f216367 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -1,5 +1,5 @@ import pprint -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Optional, Set import packaging.version as pv import optree import torch @@ -11,9 +11,8 @@ SlidingWindowCache, StaticCache, ) -from transformers.modeling_outputs import BaseModelOutput + from ..helpers import string_type -from ..helpers.cache_helper import make_static_cache PATCH_OF_PATCHES: Set[Any] = set() @@ -40,10 +39,12 @@ def register_class_serialization( :return: registered or not """ if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES: + if verbose: + print(f"[register_class_serialization] already registered {cls.__name__}") return False if verbose: - print(f"[register_cache_serialization] register {cls}") + print(f"[register_class_serialization] ---------- register {cls.__name__}") torch.utils._pytree.register_pytree_node( cls, f_flatten, @@ -54,8 +55,8 @@ if pv.Version(torch.__version__) < pv.Version("2.7"): if verbose: print( - f"[register_cache_serialization] " - f"register {cls} for torch=={torch.__version__}" + f"[register_class_serialization] " + f"---------- register {cls.__name__} for torch=={torch.__version__}" ) torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0]) @@ -77,6 +78,10 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: """ Registers many classes with :func:`register_class_serialization`.
Returns information needed to undo the registration. """ + from .onnx_export_serialization_impl import WRONG_REGISTRATIONS + + registration_functions = serialization_functions(verbose=verbose) + # DynamicCache serialization is different in transformers and does not # play way with torch.export.export. # see test test_export_dynamic_cache_cat with NOBYPASS=1 @@ -85,63 +90,65 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # torch.fx._pytree.register_pytree_flatten_spec( # DynamicCache, _flatten_dynamic_cache_for_fx) # so we remove it anyway - if ( - DynamicCache in torch.utils._pytree.SUPPORTED_NODES - and DynamicCache not in PATCH_OF_PATCHES - # and pv.Version(torch.__version__) < pv.Version("2.7") - and pv.Version(transformers.__version__) >= pv.Version("4.50") - ): - if verbose: - print( - f"[_fix_registration] DynamicCache is unregistered and " - f"registered first for transformers=={transformers.__version__}" - ) - unregister(DynamicCache, verbose=verbose) - register_class_serialization( - DynamicCache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - flatten_with_keys_dynamic_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ) - if verbose: - print("[_fix_registration] DynamicCache done.") - # To avoid doing it multiple times. - PATCH_OF_PATCHES.add(DynamicCache) - # BaseModelOutput serialization is incomplete. # It does not include dynamic shapes mapping. - if ( - BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES - and BaseModelOutput not in PATCH_OF_PATCHES - ): - if verbose: - print( - f"[_fix_registration] BaseModelOutput is unregistered and " - f"registered first for transformers=={transformers.__version__}" + for cls, version in WRONG_REGISTRATIONS.items(): + if ( + cls in torch.utils._pytree.SUPPORTED_NODES + and cls not in PATCH_OF_PATCHES + # and pv.Version(torch.__version__) < pv.Version("2.7") + and ( + version is None or pv.Version(transformers.__version__) >= pv.Version(version) ) - unregister(BaseModelOutput, verbose=verbose) - register_class_serialization( - BaseModelOutput, - flatten_base_model_output, - unflatten_base_model_output, - flatten_with_keys_base_model_output, - verbose=verbose, - ) - if verbose: - print("[_fix_registration] BaseModelOutput done.") - - # To avoid doing it multiple times. - PATCH_OF_PATCHES.add(BaseModelOutput) - - return serialization_functions(verbose=verbose) - - -def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]: + ): + assert cls in registration_functions, ( + f"{cls} has no registration functions mapped to it, " + f"available options are {list(registration_functions)}" + ) + if verbose: + print( + f"[_fix_registration] {cls.__name__} is unregistered and " + f"registered first" + ) + unregister_class_serialization(cls, verbose=verbose) + registration_functions[cls](verbose=verbose) # type: ignore[arg-type, call-arg] + if verbose: + print(f"[_fix_registration] {cls.__name__} done.") + # To avoid doing it multiple times. + PATCH_OF_PATCHES.add(cls) + + # classes with no registration at all. 
+ done = {} + for k, v in registration_functions.items(): + done[k] = v(verbose=verbose) # type: ignore[arg-type, call-arg] + return done + + +def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool]]: """Returns the list of serialization functions.""" - return dict( - DynamicCache=register_class_serialization( + from .onnx_export_serialization_impl import ( + SUPPORTED_DATACLASSES, + _lower_name_with_, + __dict__ as all_functions, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + flatten_encoder_decoder_cache, + unflatten_encoder_decoder_cache, + flatten_with_keys_encoder_decoder_cache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + flatten_static_cache, + unflatten_static_cache, + flatten_with_keys_static_cache, + ) + + transformers_classes = { + DynamicCache: lambda verbose=verbose: register_class_serialization( DynamicCache, flatten_dynamic_cache, unflatten_dynamic_cache, @@ -149,45 +156,53 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]] # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), verbose=verbose, ), - MambaCache=register_class_serialization( + MambaCache: lambda verbose=verbose: register_class_serialization( MambaCache, flatten_mamba_cache, unflatten_mamba_cache, flatten_with_keys_mamba_cache, verbose=verbose, ), - EncoderDecoderCache=register_class_serialization( + EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( EncoderDecoderCache, flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, verbose=verbose, ), - BaseModelOutput=register_class_serialization( - BaseModelOutput, - flatten_base_model_output, - unflatten_base_model_output, - flatten_with_keys_base_model_output, - verbose=verbose, - ), - SlidingWindowCache=register_class_serialization( + SlidingWindowCache: lambda verbose=verbose: register_class_serialization( SlidingWindowCache, flatten_sliding_window_cache, unflatten_sliding_window_cache, flatten_with_keys_sliding_window_cache, verbose=verbose, ), - StaticCache=register_class_serialization( + StaticCache: lambda verbose=verbose: register_class_serialization( StaticCache, flatten_static_cache, unflatten_static_cache, flatten_with_keys_static_cache, verbose=verbose, ), - ) + } + for cls in SUPPORTED_DATACLASSES: + lname = _lower_name_with_(cls.__name__) + assert ( + f"flatten_{lname}" in all_functions + ), f"Unable to find function 'flatten_{lname}' in {sorted(all_functions)}" + transformers_classes[cls] = ( + lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501 + cls, + _al[f"flatten_{_ln}"], + _al[f"unflatten_{_ln}"], + _al[f"flatten_with_keys_{_ln}"], + verbose=verbose, + ) + ) + return transformers_classes -def unregister(cls: type, verbose: int = 0): +def unregister_class_serialization(cls: type, verbose: int = 0): """Undo the registration.""" # torch.utils._pytree._deregister_pytree_flatten_spec(cls) if cls in torch.fx._pytree.SUPPORTED_NODES: @@ -217,264 +232,7 @@ def unregister(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): """Undo all registrations.""" - for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]: + cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo) + 
for cls in cls_ensemble: -        if undo.get(cls.__name__, False): +        if undo.get(cls, False): -            unregister(cls, verbose) - - -############ -# MambaCache -############ - - -def flatten_mamba_cache( -    mamba_cache: MambaCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: -    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" -    flat = [ -        ("conv_states", mamba_cache.conv_states), -        ("ssm_states", mamba_cache.ssm_states), -    ] -    return [f[1] for f in flat], [f[0] for f in flat] - - -def unflatten_mamba_cache( -    values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> MambaCache: -    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" -    conv_states, ssm_states = values - -    class _config: -        def __init__(self): -            if isinstance(conv_states, list): -                self.intermediate_size = conv_states[0].shape[1] -                self.state_size = ssm_states[0].shape[2] -                self.conv_kernel = conv_states[0].shape[2] -                self.num_hidden_layers = len(conv_states) -            else: -                self.intermediate_size = conv_states.shape[2] -                self.state_size = ssm_states.shape[3] -                self.conv_kernel = conv_states.shape[3] -                self.num_hidden_layers = conv_states.shape[0] - -    cache = MambaCache( -        _config(), -        max_batch_size=1, -        dtype=values[-1][0].dtype, -        device="cpu" if values[-1][0].get_device() < 0 else "cuda", -    ) -    values = dict(zip(context, values)) -    for k, v in values.items(): -        setattr(cache, k, v) -    return cache - - -def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ -    List[Tuple[torch.utils._pytree.KeyEntry, Any]], -    torch.utils._pytree.Context, -]: -    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" -    values, context = flatten_mamba_cache(cache) -    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -############## -# DynamicCache -############## - - -def flatten_dynamic_cache( -    dynamic_cache: DynamicCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: -    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" -    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): -        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) -    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)] -    return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_dynamic_cache( -    dynamic_cache: DynamicCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: -    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" -    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"): -        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache) -    values, context = flatten_dynamic_cache(dynamic_cache) -    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_dynamic_cache( -    values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> DynamicCache: -    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" -    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"): -        assert output_type is None, f"output_type={output_type} not supported" -        return transformers.cache_utils._unflatten_dynamic_cache(values, context) - -    cache = transformers.cache_utils.DynamicCache() -    values = dict(zip(context, values)) -    for k, v in values.items(): -        setattr(cache, k, v) -    return cache - - -############# -# StaticCache -############# -
- -def flatten_static_cache( - cache: StaticCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" - flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_static_cache( - cache: StaticCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" - values, context = flatten_static_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_static_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> StaticCache: - """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" - return make_static_cache(list(zip(values[0], values[1]))) - - -#################### -# SlidingWindowCache -#################### - - -def flatten_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. - """ - flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. - """ - values, context = flatten_sliding_window_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_sliding_window_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> SlidingWindowCache: - """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects.""" - key_cache, value_cache = values - - class _config: - def __init__(self): - self.head_dim = key_cache[0].shape[-1] - self.num_attention_heads = key_cache[0].shape[1] - self.num_hidden_layers = len(key_cache) - self.sliding_window = key_cache[0].shape[2] - - cache = SlidingWindowCache( - _config(), - max_batch_size=key_cache[0].shape[0], - max_cache_len=key_cache[0].shape[2], # sligding window - device=key_cache[0].device, - dtype=key_cache[0].dtype, - ) - - values = dict(zip(context, values)) - for k, v in values.items(): - setattr(cache, k, v) - return cache - - -##################### -# EncoderDecoderCache -##################### - - -def flatten_encoder_decoder_cache( - ec_cache: EncoderDecoderCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` - with python objects. - """ - dictionary = { - "self_attention_cache": ec_cache.self_attention_cache, - "cross_attention_cache": ec_cache.cross_attention_cache, - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], - torch.utils._pytree.Context, -]: - """ - Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` - with python objects. 
- """ - dictionary = { - "self_attention_cache": ec_cache.self_attention_cache, - "cross_attention_cache": ec_cache.cross_attention_cache, - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def unflatten_encoder_decoder_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> EncoderDecoderCache: - """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects.""" - dictionary = torch.utils._pytree._dict_unflatten(values, context) - return EncoderDecoderCache(**dictionary) - - -################# -# BaseModelOutput -################# - - -def flatten_base_model_output( - bo: BaseModelOutput, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` - with python objects. - """ - return list(bo.values()), list(bo.keys()) - - -def flatten_with_keys_base_model_output( - bo: BaseModelOutput, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` - with python objects. - """ - values, context = flatten_base_model_output(bo) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_base_model_output( - values: List[Any], - context: torch.utils._pytree.Context, - output_type=None, -) -> BaseModelOutput: - """ - Restores a :class:`transformers.modeling_outputs.BaseModelOutput` - from python objects. - """ - return BaseModelOutput(**dict(zip(context, values))) + unregister_class_serialization(cls, verbose) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py new file mode 100644 index 00000000..32a7415c --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py @@ -0,0 +1,327 @@ +import re +from typing import Any, Callable, Dict, List, Optional, Tuple +import torch +import transformers +from transformers.cache_utils import ( + DynamicCache, + MambaCache, + EncoderDecoderCache, + SlidingWindowCache, + StaticCache, +) +from transformers.modeling_outputs import BaseModelOutput + +try: + from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput + from diffusers.models.unets.unet_1d import UNet1DOutput + from diffusers.models.unets.unet_2d import UNet2DOutput + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput +except ImportError as e: + try: + import diffusers + except ImportError: + diffusers = None + DecoderOutput, EncoderOutput = None, None + UNet1DOutput, UNet2DOutput = None, None + UNet2DConditionOutput, UNet3DConditionOutput = None, None + if diffusers: + raise e + +from ..helpers.cache_helper import make_static_cache + + +def _make_wrong_registrations() -> Dict[str, Optional[str]]: + res = { + DynamicCache: "4.50", + BaseModelOutput: None, + } + for c in [UNet2DConditionOutput]: + if c is not None: + res[c] = None + return res + + +SUPPORTED_DATACLASSES = set() +WRONG_REGISTRATIONS = _make_wrong_registrations() + + +############ +# MambaCache +############ + + +def flatten_mamba_cache( + mamba_cache: MambaCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + flat = [ + ("conv_states", mamba_cache.conv_states), + ("ssm_states", 
mamba_cache.ssm_states), + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + values, context = flatten_mamba_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +############## +# DynamicCache +############## + + +def flatten_dynamic_cache( + dynamic_cache: DynamicCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): + return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) + flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_dynamic_cache( + dynamic_cache: DynamicCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"): + return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache) + values, context = flatten_dynamic_cache(dynamic_cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_dynamic_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> DynamicCache: + """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" + if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"): + assert output_type is None, f"output_type={output_type} not supported" + return transformers.cache_utils._unflatten_dynamic_cache(values, context) + + cache = transformers.cache_utils.DynamicCache() + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +############# +# StaticCache +############# + + +def flatten_static_cache( + cache: StaticCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_static_cache( + cache: StaticCache, +) -> 
Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + values, context = flatten_static_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_static_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> StaticCache: + """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" + return make_static_cache(list(zip(values[0], values[1]))) + + +#################### +# SlidingWindowCache +#################### + + +def flatten_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + values, context = flatten_sliding_window_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_sliding_window_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> SlidingWindowCache: + """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects.""" + key_cache, value_cache = values + + class _config: + def __init__(self): + self.head_dim = key_cache[0].shape[-1] + self.num_attention_heads = key_cache[0].shape[1] + self.num_hidden_layers = len(key_cache) + self.sliding_window = key_cache[0].shape[2] + + cache = SlidingWindowCache( + _config(), + max_batch_size=key_cache[0].shape[0], + max_cache_len=key_cache[0].shape[2], # sligding window + device=key_cache[0].device, + dtype=key_cache[0].dtype, + ) + + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +##################### +# EncoderDecoderCache +##################### + + +def flatten_encoder_decoder_cache( + ec_cache: EncoderDecoderCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` + with python objects. + """ + dictionary = { + "self_attention_cache": ec_cache.self_attention_cache, + "cross_attention_cache": ec_cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """ + Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` + with python objects. 
+ """ + dictionary = { + "self_attention_cache": ec_cache.self_attention_cache, + "cross_attention_cache": ec_cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def unflatten_encoder_decoder_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> EncoderDecoderCache: + """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects.""" + dictionary = torch.utils._pytree._dict_unflatten(values, context) + return EncoderDecoderCache(**dictionary) + + +############# +# dataclasses +############# + + +def _lower_name_with_(name): + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def make_serialization_function_for_dataclass(cls) -> Tuple[Callable, Callable, Callable]: + """ + Automatically creates serialization function for a class decorated with + ``dataclasses.dataclass``. + """ + + def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a ``%s`` with python objects.""" + return list(obj.values()), list(obj.keys()) + + def flatten_with_keys_cls( + obj: cls, + ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a ``%s`` with python objects with keys.""" + values, context = list(obj.values()), list(obj.keys()) + return [ + (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values) + ], context + + def unflatten_cls( + values: List[Any], context: torch.utils._pytree.Context, output_type=None + ) -> cls: + """Restores an instance of ``%s`` from python objects.""" + return cls(**dict(zip(context, values))) + + name = _lower_name_with_(cls.__name__) + flatten_cls.__name__ = f"flatten_{name}" + flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}" + unflatten_cls.__name__ = f"unflatten_{name}" + flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__ + flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__ + unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__ + SUPPORTED_DATACLASSES.add(cls) + return flatten_cls, flatten_with_keys_cls, unflatten_cls + + +( + flatten_base_model_output, + flatten_with_keys_base_model_output, + unflatten_base_model_output, +) = make_serialization_function_for_dataclass(BaseModelOutput) + + +if UNet2DConditionOutput is not None: + ( + flatten_u_net2_d_condition_output, + flatten_with_keys_u_net2_d_condition_output, + unflatten_u_net2_d_condition_output, + ) = make_serialization_function_for_dataclass(UNet2DConditionOutput) diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index ee826e82..d671c508 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct(): "vocab_size": 32064, } ) + + +def _ccached_diffusers_tiny_torch_full_checker_unet(): + "diffusers/tiny-torch-full-checker/unet" + return { + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.8.0", + "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [32, 64], + "center_input_sample": false, + "cross_attention_dim": 32, + "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": 
False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 32, + "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"], + "use_linear_projection": False, + } diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 1f8d89ed..6db64040 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -145,12 +145,19 @@ def get_untrained_model_with_inputs( f"{config._attn_implementation!r}"  # type: ignore[union-attr] ) + if type(config) is dict and "_diffusers_version" in config: + import diffusers + + package_source = diffusers + else: + package_source = transformers + if use_pretrained: model = transformers.AutoModel.from_pretrained(model_id, **mkwargs) else: if archs is not None: try: - model = getattr(transformers, archs[0])(config) + cls_model = getattr(package_source, archs[0]) except AttributeError as e: # The code of the models is not in transformers but in the # repository of the model. We need to download it. @@ -174,10 +181,12 @@ def get_untrained_model_with_inputs( f"[get_untrained_model_with_inputs] from folder " f"{os.path.split(pyfiles[0])[0]!r}" ) - cls = transformers.dynamic_module_utils.get_class_from_dynamic_module( - cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0] + cls_model = ( + transformers.dynamic_module_utils.get_class_from_dynamic_module( + cls_name, + pretrained_model_name_or_path=os.path.split(pyfiles[0])[0], + ) ) - model = cls(config) else: raise AttributeError( f"Unable to find class 'tranformers.{archs[0]}'. " @@ -191,6 +200,16 @@ def get_untrained_model_with_inputs( f"and use_pretrained=True." ) + try: + if type(config) is dict: + model = cls_model(**config) + else: + model = cls_model(config) + except RuntimeError as e: + raise RuntimeError( + f"Unable to instantiate class {cls_model.__name__} with\n{config}" + ) from e + # input kwargs kwargs, fct = random_input_kwargs(config, task) if verbose: diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index c27f744f..03011e4b 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -538,9 +538,13 @@ def validate_model( if summary["model_module"] in sys.modules: summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index] summary["model_config_class"] = data["configuration"].__class__.__name__ - summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace( - " ", "" - ) + summary["model_config"] = str( + shrink_config( + data["configuration"] + if type(data["configuration"]) is dict + else data["configuration"].to_dict() + ) + ).replace(" ", "") summary["model_id"] = model_id if verbose:
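
For context, here is a condensed, end-to-end sketch of what the new serialization support enables. It is not part of the patch itself; it mirrors ``test_unet_2d_condition_output`` above, reuses only names introduced in this diff, and assumes ``diffusers`` is installed::

    import torch
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
    from onnx_diagnostic.torch_export_patches import torch_export_patches


    class Model(torch.nn.Module):
        def forward(self, cache):
            # UNet2DConditionOutput is a dataclass-like output with a single
            # tensor field named ``sample``.
            return cache.sample[0]


    bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
    ds = [{0: torch.export.Dim.DYNAMIC}]  # first dimension of ``sample`` is dynamic

    # torch_export_patches registers the diffusers class with
    # torch.utils._pytree so torch.export can flatten and unflatten it.
    with torch_export_patches():
        ep = torch.export.export(Model(), (bo,), dynamic_shapes=(ds,))
    print(ep)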