From c30f91229cd57f73b84b0b8b30269ff6c1a27da9 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 24 Jun 2025 14:48:03 +0200
Subject: [PATCH 1/9] Support for text-to-image

---
 .../test_tasks_image_classification.py        |  2 +-
 .../ut_tasks/test_tasks_image_text_to_text.py |  2 +-
 .../ut_tasks/test_tasks_object_detection.py   |  2 +-
 .../ut_tasks/test_tasks_text_to_image.py      | 35 +++++++
 ...st_tasks_zero_shot_image_classification.py |  2 +-
 _unittests/ut_tasks/try_tasks.py              | 21 +++++
 onnx_diagnostic/helpers/config_helper.py      |  5 +-
 onnx_diagnostic/tasks/__init__.py             |  2 +
 onnx_diagnostic/tasks/text_to_image.py        | 91 +++++++++++++++++++
 .../hghub/hub_data_cached_configs.py          | 28 ++++++
 .../torch_models/hghub/model_inputs.py        | 22 ++++-
 onnx_diagnostic/torch_models/validate.py      | 10 +-
 12 files changed, 210 insertions(+), 12 deletions(-)
 create mode 100644 _unittests/ut_tasks/test_tasks_text_to_image.py
 create mode 100644 onnx_diagnostic/tasks/text_to_image.py

diff --git a/_unittests/ut_tasks/test_tasks_image_classification.py b/_unittests/ut_tasks/test_tasks_image_classification.py
index e9856d10..0bf6b97c 100644
--- a/_unittests/ut_tasks/test_tasks_image_classification.py
+++ b/_unittests/ut_tasks/test_tasks_image_classification.py
@@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
-class TestTasks(ExtTestCase):
+class TestTasksImageClassification(ExtTestCase):
     @hide_stdout()
     def test_image_classification(self):
         mid = "hf-internal-testing/tiny-random-BeitForImageClassification"
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index 173d628c..b51db78f 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -11,7 +11,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
-class TestTasks(ExtTestCase):
+class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
diff --git a/_unittests/ut_tasks/test_tasks_object_detection.py b/_unittests/ut_tasks/test_tasks_object_detection.py
index 2429d5a9..7e21f7df 100644
--- a/_unittests/ut_tasks/test_tasks_object_detection.py
+++ b/_unittests/ut_tasks/test_tasks_object_detection.py
@@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
-class TestTasks(ExtTestCase):
+class TestTasksObjectDetection(ExtTestCase):
     @hide_stdout()
     def test_object_detection(self):
         mid = "hustvl/yolos-tiny"
diff --git a/_unittests/ut_tasks/test_tasks_text_to_image.py b/_unittests/ut_tasks/test_tasks_text_to_image.py
new file mode 100644
index 00000000..57d6609b
--- /dev/null
+++ b/_unittests/ut_tasks/test_tasks_text_to_image.py
@@ -0,0 +1,35 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasksTextToImage(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.52")
+    @requires_torch("2.7.99")
+    def test_text_to_image(self):
+        mid = "diffusers/tiny-torch-full-checker"
+        data = get_untrained_model_with_inputs(
+            mid, verbose=1, add_second_input=True, subfolder="unet"
+        )
+        self.assertEqual(data["task"], "text-to-image")
+        
self.assertIn((data["size"], data["n_weights"]), [(5708048, 1427012)]) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + model(**inputs) + model(**data["inputs2"]) + with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1): + torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py index bbfb34ab..7381eac2 100644 --- a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py @@ -6,7 +6,7 @@ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -class TestTasks(ExtTestCase): +class TestTasksZeroShotImageClassification(ExtTestCase): @requires_torch("2.7.99") @hide_stdout() def test_zero_shot_image_classification(self): diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 6d6df11f..69522267 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -569,6 +569,27 @@ def test_object_detection(self): f"{round(score.item(), 3)} at location {box}" ) + @never_test() + def test_text_to_image(self): + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k test_text_to_image + import torch + from diffusers import StableDiffusionPipeline + + model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to( + "cuda" + ) + + prompt = "a photo of an astronaut riding a horse on mars and on jupyter" + print() + with steal_forward(pipe.unet, with_min_max=True): + image = pipe(prompt).images[0] + print("-- output", self.string_type(image, with_shape=True, with_min_max=True)) + # stolen forward for class UNet2DConditionModel -- iteration 44 + # sample=T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184] + # time_step=T7s=101 + # encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257] + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py index c22340ab..938fbccc 100644 --- a/onnx_diagnostic/helpers/config_helper.py +++ b/onnx_diagnostic/helpers/config_helper.py @@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]): else: update_config(getattr(config, k), v) continue - setattr(config, k, v) + if type(config) is dict: + config[k] = v + else: + setattr(config, k, v) def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None): diff --git a/onnx_diagnostic/tasks/__init__.py b/onnx_diagnostic/tasks/__init__.py index 405b9c80..1f14cc9e 100644 --- a/onnx_diagnostic/tasks/__init__.py +++ b/onnx_diagnostic/tasks/__init__.py @@ -11,6 +11,7 @@ summarization, text_classification, text_generation, + text_to_image, text2text_generation, zero_shot_image_classification, ) @@ -27,6 +28,7 @@ summarization, text_classification, text_generation, + text_to_image, text2text_generation, zero_shot_image_classification, ] diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py new file mode 100644 index 00000000..791de235 --- /dev/null +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -0,0 +1,91 @@ +from typing import Any, Callable, Dict, Optional, Tuple +import 
torch +from ..helpers.config_helper import update_config, check_hasattr + +__TASK__ = "text-to-image" + + +def reduce_model_config(config: Any) -> Dict[str, Any]: + """Reduces a model size.""" + check_hasattr(config, "sample_size", "cross_attention_dim") + kwargs = dict( + sample_size=min(config["sample_size"], 32), + cross_attention_dim=min(config["cross_attention_dim"], 64), + ) + update_config(config, kwargs) + return kwargs + + +def get_inputs( + model: torch.nn.Module, + config: Optional[Any], + batch_size: int, + sequence_length: int, + cache_length: int, + in_channels: int, + sample_size: int, + cross_attention_dim: int, + add_second_input: bool = False, + **kwargs, # unused +): + """ + Generates inputs for task ``text-to-image``. + Example: + + :: + + sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184] + timestep:T7s=101 + encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257] + """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." + batch = torch.export.Dim("batch", min=1, max=1024) + shapes = { + "sample": {0: batch}, + "timestep": {}, + "encoder_hidden_states": {0: batch, 1: "encoder_length"}, + } + inputs = dict( + sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to( + torch.float32 + ), + timestep=torch.tensor([101], dtype=torch.int64), + encoder_hidden_states=torch.randn( + (batch_size, sequence_length, cross_attention_dim) + ).to(torch.float32), + ) + res = dict(inputs=inputs, dynamic_shapes=shapes) + if add_second_input: + res["inputs2"] = get_inputs( + model=model, + config=config, + batch_size=batch_size + 1, + sequence_length=sequence_length, + cache_length=cache_length + 1, + in_channels=in_channels, + sample_size=sample_size, + cross_attention_dim=cross_attention_dim, + **kwargs, + )["inputs"] + return res + + +def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: + """ + Inputs kwargs. + + If the configuration is None, the function selects typical dimensions. 
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=config["in_channels"],
+        cache_length=77,
+        in_channels=config["in_channels"],
+        sample_size=config["sample_size"],
+        cross_attention_dim=config["cross_attention_dim"],
+    )
+    return kwargs, get_inputs
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
index ee826e82..d671c508 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_diffusers_tiny_torch_full_checker_unet():
+    "diffusers/tiny-torch-full-checker/unet"
+    return {
+        "_class_name": "UNet2DConditionModel",
+        "_diffusers_version": "0.8.0",
+        "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json",
+        "act_fn": "silu",
+        "attention_head_dim": 8,
+        "block_out_channels": [32, 64],
+        "center_input_sample": False,
+        "cross_attention_dim": 32,
+        "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"],
+        "downsample_padding": 1,
+        "dual_cross_attention": False,
+        "flip_sin_to_cos": True,
+        "freq_shift": 0,
+        "in_channels": 4,
+        "layers_per_block": 2,
+        "mid_block_scale_factor": 1,
+        "norm_eps": 1e-05,
+        "norm_num_groups": 32,
+        "out_channels": 4,
+        "sample_size": 32,
+        "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
+        "use_linear_projection": False,
+    }
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index 1f8d89ed..33426788 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -145,12 +145,19 @@ def get_untrained_model_with_inputs(
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
 
+    if type(config) is dict and "_diffusers_version" in config:
+        import diffusers
+
+        package_source = diffusers
+    else:
+        package_source = transformers
+
     if use_pretrained:
         model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
     else:
         if archs is not None:
             try:
-                model = getattr(transformers, archs[0])(config)
+                cls_model = getattr(package_source, archs[0])
             except AttributeError as e:
                 # The code of the models is not in transformers but in the
                 # repository of the model. We need to download it.
@@ -174,10 +181,12 @@ def get_untrained_model_with_inputs(
                         f"[get_untrained_model_with_inputs] from folder "
                         f"{os.path.split(pyfiles[0])[0]!r}"
                     )
-                cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
-                    cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
+                cls_model = (
+                    transformers.dynamic_module_utils.get_class_from_dynamic_module(
+                        cls_name,
+                        pretrained_model_name_or_path=os.path.split(pyfiles[0])[0],
+                    )
                 )
-                model = cls(config)
             else:
                 raise AttributeError(
                     f"Unable to find class 'tranformers.{archs[0]}'. "
@@ -191,6 +200,11 @@ def get_untrained_model_with_inputs(
                     f"and use_pretrained=True."
) + if type(config) is dict: + model = cls_model(**config) + else: + model = cls_model(config) + # input kwargs kwargs, fct = random_input_kwargs(config, task) if verbose: diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index c27f744f..4fbcf2f2 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -538,9 +538,13 @@ def validate_model( if summary["model_module"] in sys.modules: summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index] summary["model_config_class"] = data["configuration"].__class__.__name__ - summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace( - " ", "" - ) + summary["model_config"] = str( + shrink_config( + data["configuration"] + if type(data["configuration"]) + else data["configuration"].to_dict() + ) + ).replace(" ", "") summary["model_id"] = model_id if verbose: From 4c6cebd87b5e920c3c3fabbafcef9099a911bb68 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 24 Jun 2025 15:03:02 +0200 Subject: [PATCH 2/9] doc --- CHANGELOGS.rst | 1 + _doc/api/tasks/index.rst | 1 + _doc/api/tasks/text_to_image.rst | 7 +++++++ onnx_diagnostic/tasks/text_to_image.py | 2 +- onnx_diagnostic/torch_models/validate.py | 2 +- 5 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 _doc/api/tasks/text_to_image.rst diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 718aebb3..72605848 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.2 +++++ +* :pr:`165`: support for task text-to-image * :pr:`162`: improves graphs rendering for historical data 0.7.1 diff --git a/_doc/api/tasks/index.rst b/_doc/api/tasks/index.rst index fb606d1a..de27f4cd 100644 --- a/_doc/api/tasks/index.rst +++ b/_doc/api/tasks/index.rst @@ -46,6 +46,7 @@ Or: summarization text_classification text_generation + text_to_image text2text_generation zero_shot_image_classification diff --git a/_doc/api/tasks/text_to_image.rst b/_doc/api/tasks/text_to_image.rst new file mode 100644 index 00000000..dcc1936e --- /dev/null +++ b/_doc/api/tasks/text_to_image.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.tasks.text_to_image +=================================== + +.. automodule:: onnx_diagnostic.tasks.text_to_image + :members: + :no-undoc-members: diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py index 791de235..916500e9 100644 --- a/onnx_diagnostic/tasks/text_to_image.py +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -41,7 +41,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." 
- batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" shapes = { "sample": {0: batch}, "timestep": {}, diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 4fbcf2f2..03011e4b 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -541,7 +541,7 @@ def validate_model( summary["model_config"] = str( shrink_config( data["configuration"] - if type(data["configuration"]) + if type(data["configuration"]) is dict else data["configuration"].to_dict() ) ).replace(" ", "") From b23031b1f3547128885c2338984304e646a0884b Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 08:16:07 +0200 Subject: [PATCH 3/9] add pick --- onnx_diagnostic/helpers/config_helper.py | 12 ++++++++++++ onnx_diagnostic/tasks/text_to_image.py | 10 +++++----- onnx_diagnostic/torch_models/hghub/model_inputs.py | 13 +++++++++---- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py index 938fbccc..3db0675d 100644 --- a/onnx_diagnostic/helpers/config_helper.py +++ b/onnx_diagnostic/helpers/config_helper.py @@ -69,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None): raise AssertionError(f"Unable to find any of these {atts!r} in {config}") +def pick(config, name: str, default_value: Any) -> Any: + """ + Returns the vlaue of a attribute if config has it + otherwise the default value. + """ + if not config: + return default_value + if type(config) is dict: + return config.get(name, default_value) + return getattr(config, name, default_value) + + @functools.cache def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]: """ diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py index 916500e9..983d9bec 100644 --- a/onnx_diagnostic/tasks/text_to_image.py +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch -from ..helpers.config_helper import update_config, check_hasattr +from ..helpers.config_helper import update_config, check_hasattr, pick __TASK__ = "text-to-image" @@ -82,10 +82,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels") kwargs = dict( batch_size=2, - sequence_length=config["in_channels"], + sequence_length=pick(config, "in_channels", 4), cache_length=77, - in_channels=config["in_channels"], - sample_size=config["sample_size"], - cross_attention_dim=config["cross_attention_dim"], + in_channels=pick(config, "in_channels", 4), + sample_size=pick(config, "sample_size", 32), + cross_attention_dim=pick(config, "cross_attention_dim", 64), ) return kwargs, get_inputs diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 33426788..6db64040 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -200,10 +200,15 @@ def get_untrained_model_with_inputs( f"and use_pretrained=True." 
) - if type(config) is dict: - model = cls_model(**config) - else: - model = cls_model(config) + try: + if type(config) is dict: + model = cls_model(**config) + else: + model = cls_model(config) + except RuntimeError as e: + raise RuntimeError( + f"Unable to instantiate class {cls_model.__name__} with\n{config}" + ) from e # input kwargs kwargs, fct = random_input_kwargs(config, task) From 87fa333b12871adfed3a6229f39331fa880f5bc8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 10:03:01 +0200 Subject: [PATCH 4/9] fix issues --- _doc/conf.py | 2 + .../onnx_export_serialization.py | 189 ++++++++++++------ 2 files changed, 126 insertions(+), 65 deletions(-) diff --git a/_doc/conf.py b/_doc/conf.py index e8f010d5..8bc36bf0 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -90,6 +90,7 @@ def linkcode_resolve(domain, info): "https://sdpython.github.io/doc/experimental-experiment/dev/", None, ), + "diffusers": ("https://huggingface.co/docs/diffusers/index", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable", None), "onnx": ("https://onnx.ai/onnx/", None), @@ -104,6 +105,7 @@ def linkcode_resolve(domain, info): "sklearn": ("https://scikit-learn.org/stable/", None), "skl2onnx": ("https://onnx.ai/sklearn-onnx/", None), "torch": ("https://pytorch.org/docs/main/", None), + "transformers": ("https://huggingface.co/docs/transformers/index", None), } # Check intersphinx reference targets exist diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index ebbed95f..34341378 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -12,11 +12,34 @@ StaticCache, ) from transformers.modeling_outputs import BaseModelOutput + +try: + from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput + from diffusers.models.unets.unet_1d import UNet1DOutput + from diffusers.models.unets.unet_2d import UNet2DOutput + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput +except ImportError as e: + try: + import diffusers + except ImportError: + diffusers = None + DecoderOutput, EncoderOutput = None, None + UNet1DOutput, UNet2DOutput = None, None + UNet2DConditionOutput, UNet3DConditionOutput = None, None + if diffusers: + raise e + from ..helpers import string_type from ..helpers.cache_helper import make_static_cache PATCH_OF_PATCHES: Set[Any] = set() +WRONG_REGISTRATIONS: Dict[str, str] = { + DynamicCache: "4.50", + BaseModelOutput: None, + UNet2DConditionOutput: None, +} def register_class_serialization( @@ -40,10 +63,12 @@ def register_class_serialization( :return: registered or not """ if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES: + if verbose and cls is not None: + print(f"[register_class_serialization] already registered {cls.__name__}") return False if verbose: - print(f"[register_cache_serialization] register {cls}") + print(f"[register_class_serialization] ---------- register {cls.__name__}") torch.utils._pytree.register_pytree_node( cls, f_flatten, @@ -54,8 +79,8 @@ def register_class_serialization( if pv.Version(torch.__version__) < pv.Version("2.7"): if verbose: print( - f"[register_cache_serialization] " - f"register {cls} for torch=={torch.__version__}" + f"[register_class_serialization] " + f"---------- register {cls.__name__} 
for torch=={torch.__version__}" ) torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0]) @@ -77,6 +102,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: Registers many classes with :func:`register_class_serialization`. Returns information needed to undo the registration. """ + registration_functions = serialization_functions(verbose=verbose) + # DynamicCache serialization is different in transformers and does not # play way with torch.export.export. # see test test_export_dynamic_cache_cat with NOBYPASS=1 @@ -85,63 +112,44 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # torch.fx._pytree.register_pytree_flatten_spec( # DynamicCache, _flatten_dynamic_cache_for_fx) # so we remove it anyway - if ( - DynamicCache in torch.utils._pytree.SUPPORTED_NODES - and DynamicCache not in PATCH_OF_PATCHES - # and pv.Version(torch.__version__) < pv.Version("2.7") - and pv.Version(transformers.__version__) >= pv.Version("4.50") - ): - if verbose: - print( - f"[_fix_registration] DynamicCache is unregistered and " - f"registered first for transformers=={transformers.__version__}" - ) - unregister(DynamicCache, verbose=verbose) - register_class_serialization( - DynamicCache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - flatten_with_keys_dynamic_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ) - if verbose: - print("[_fix_registration] DynamicCache done.") - # To avoid doing it multiple times. - PATCH_OF_PATCHES.add(DynamicCache) - # BaseModelOutput serialization is incomplete. # It does not include dynamic shapes mapping. - if ( - BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES - and BaseModelOutput not in PATCH_OF_PATCHES - ): - if verbose: - print( - f"[_fix_registration] BaseModelOutput is unregistered and " - f"registered first for transformers=={transformers.__version__}" + for cls, version in WRONG_REGISTRATIONS.items(): + if ( + cls in torch.utils._pytree.SUPPORTED_NODES + and cls not in PATCH_OF_PATCHES + # and pv.Version(torch.__version__) < pv.Version("2.7") + and ( + version is None or pv.Version(transformers.__version__) >= pv.Version(version) ) - unregister(BaseModelOutput, verbose=verbose) - register_class_serialization( - BaseModelOutput, - flatten_base_model_output, - unflatten_base_model_output, - flatten_with_keys_base_model_output, - verbose=verbose, - ) - if verbose: - print("[_fix_registration] BaseModelOutput done.") - - # To avoid doing it multiple times. - PATCH_OF_PATCHES.add(BaseModelOutput) - - return serialization_functions(verbose=verbose) - - -def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]: + ): + assert cls in registration_functions, ( + f"{cls} has no registration functions mapped to it, " + f"available {sorted(registration_functions)}" + ) + if verbose: + print( + f"[_fix_registration] {cls.__name__} is unregistered and " + f"registered first" + ) + unregister_class_serialization(cls, verbose=verbose) + registration_functions[cls](verbose=verbose) + if verbose: + print(f"[_fix_registration] {cls.__name__} done.") + # To avoid doing it multiple times. + PATCH_OF_PATCHES.add(cls) + + # classes with no registration at all. 
+ done = {} + for k, v in registration_functions.items(): + done[k] = v(verbose=verbose) + return done + + +def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[], bool], int]]: """Returns the list of serialization functions.""" - return dict( - DynamicCache=register_class_serialization( + transformers_classes = { + DynamicCache: lambda verbose=verbose: register_class_serialization( DynamicCache, flatten_dynamic_cache, unflatten_dynamic_cache, @@ -149,45 +157,57 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]] # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), verbose=verbose, ), - MambaCache=register_class_serialization( + MambaCache: lambda verbose=verbose: register_class_serialization( MambaCache, flatten_mamba_cache, unflatten_mamba_cache, flatten_with_keys_mamba_cache, verbose=verbose, ), - EncoderDecoderCache=register_class_serialization( + EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( EncoderDecoderCache, flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, verbose=verbose, ), - BaseModelOutput=register_class_serialization( + BaseModelOutput: lambda verbose=verbose: register_class_serialization( BaseModelOutput, flatten_base_model_output, unflatten_base_model_output, flatten_with_keys_base_model_output, verbose=verbose, ), - SlidingWindowCache=register_class_serialization( + SlidingWindowCache: lambda verbose=verbose: register_class_serialization( SlidingWindowCache, flatten_sliding_window_cache, unflatten_sliding_window_cache, flatten_with_keys_sliding_window_cache, verbose=verbose, ), - StaticCache=register_class_serialization( + StaticCache: lambda verbose=verbose: register_class_serialization( StaticCache, flatten_static_cache, unflatten_static_cache, flatten_with_keys_static_cache, verbose=verbose, ), - ) + } + if UNet2DConditionOutput: + diffusers_classes = { + UNet2DConditionOutput: lambda verbose=verbose: register_class_serialization( + UNet2DConditionOutput, + flatten_unet_2d_condition_output, + unflatten_unet_2d_condition_output, + flatten_with_keys_unet_2d_condition_output, + verbose=verbose, + ) + } + transformers_classes.update(diffusers_classes) + return transformers_classes -def unregister(cls: type, verbose: int = 0): +def unregister_class_serialization(cls: type, verbose: int = 0): """Undo the registration.""" # torch.utils._pytree._deregister_pytree_flatten_spec(cls) if cls in torch.fx._pytree.SUPPORTED_NODES: @@ -217,9 +237,10 @@ def unregister(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): """Undo all registrations.""" - for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]: + cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo) + for cls in cls_ensemble: if undo.get(cls.__name__, False): - unregister(cls, verbose) + unregister_class_serialization(cls, verbose) ############ @@ -478,3 +499,41 @@ def unflatten_base_model_output( from python objects. """ return BaseModelOutput(**dict(zip(context, values))) + + +####################### +# UNet2DConditionOutput +####################### + + +def flatten_unet_2d_condition_output( + obj: UNet2DConditionOutput, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` + with python objects. 
+ """ + return list(obj.values()), list(obj.keys()) + + +def flatten_with_keys_unet_2d_condition_output( + obj: UNet2DConditionOutput, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` + with python objects. + """ + values, context = flatten_unet_2d_condition_output(obj) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_unet_2d_condition_output( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> UNet2DConditionOutput: + """ + Restores a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` + from python objects. + """ + return UNet2DConditionOutput(**dict(zip(context, values))) From 7be71c150e18ed6eb9987f189b969ca7aea68ed1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 10:39:38 +0200 Subject: [PATCH 5/9] fix issues --- _doc/conf.py | 7 ++- _doc/patches.rst | 2 +- _doc/status/patches_coverage.rst | 2 +- .../test_patch_serialization.py | 57 ++++++++++++++++++- onnx_diagnostic/helpers/config_helper.py | 2 +- .../onnx_export_serialization.py | 6 +- 6 files changed, 67 insertions(+), 9 deletions(-) diff --git a/_doc/conf.py b/_doc/conf.py index 8bc36bf0..c4b312c1 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -90,7 +90,8 @@ def linkcode_resolve(domain, info): "https://sdpython.github.io/doc/experimental-experiment/dev/", None, ), - "diffusers": ("https://huggingface.co/docs/diffusers/index", None), + # Not a sphinx documentation + # "diffusers": ("https://huggingface.co/docs/diffusers/index", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable", None), "onnx": ("https://onnx.ai/onnx/", None), @@ -105,7 +106,8 @@ def linkcode_resolve(domain, info): "sklearn": ("https://scikit-learn.org/stable/", None), "skl2onnx": ("https://onnx.ai/sklearn-onnx/", None), "torch": ("https://pytorch.org/docs/main/", None), - "transformers": ("https://huggingface.co/docs/transformers/index", None), + # Not a sphinx documentation + # "transformers": ("https://huggingface.co/docs/transformers/index", None), } # Check intersphinx reference targets exist @@ -118,6 +120,7 @@ def linkcode_resolve(domain, info): ("py:class", "True"), ("py:class", "Argument"), ("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"), + ("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"), ("py:class", "ModelProto"), ("py:class", "Model"), ("py:class", "Module"), diff --git a/_doc/patches.rst b/_doc/patches.rst index 7b6cf622..bb589431 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -113,7 +113,7 @@ Here is the list of supported caches: import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print("\n".join(sorted(p.serialization_functions()))) + print("\n".join(sorted(t.__name__ for t in p.serialization_functions()))) .. _l-control-flow-rewriting: diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst index f35c066b..1a3ac8b7 100644 --- a/_doc/status/patches_coverage.rst +++ b/_doc/status/patches_coverage.rst @@ -14,7 +14,7 @@ The following code shows the list of serialized classes in transformers. 
     import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
 
-    print('\n'.join(sorted(p.serialization_functions())))
+    print('\n'.join(sorted(t.__name__ for t in p.serialization_functions())))
 
 Patched Classes
 ===============
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py
index 851627a5..4bdb1eca 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -1,7 +1,12 @@
 import unittest
 import torch
 from transformers.modeling_outputs import BaseModelOutput
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    ignore_warnings,
+    requires_torch,
+    requires_diffusers,
+)
 from onnx_diagnostic.helpers.cache_helper import (
     make_encoder_decoder_cache,
     make_dynamic_cache,
@@ -212,6 +217,56 @@ def test_sliding_window_cache_flatten(self):
             self.string_type(cache2, with_shape=True, with_min_max=True),
         )
 
+    @ignore_warnings(UserWarning)
+    @requires_diffusers("0.30")
+    def test_unet_2d_condition_output(self):
+        from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+
+        bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
+        self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
+        with torch_export_patches():
+            # internal function
+            bo2 = torch_deepcopy([bo])
+            self.assertIsInstance(bo2, list)
+            self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
+            self.assertEqualAny([bo], bo2)
+            self.assertEqual(
+                "UNet2DConditionOutput(sample:T1s4x4x4)",
+                self.string_type(bo, with_shape=True),
+            )
+
+            # serialization
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#1[T1s4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
+            # flatten_unflatten
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
+            self.assertIsInstance(unflat, dict)
+            self.assertEqual(list(unflat), ["sample"])
+
+        # export
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.sample[0]
+
+        bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
+        model = Model()
+        model(bo)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [{0: DYN}]
+
+        with torch_export_patches():
+            torch.export.export(model, (bo,), dynamic_shapes=(ds,))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py
index 3db0675d..bd4f987d 100644
--- a/onnx_diagnostic/helpers/config_helper.py
+++ b/onnx_diagnostic/helpers/config_helper.py
@@ -71,7 +71,7 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
 
 def pick(config, name: str, default_value: Any) -> Any:
     """
-    Returns the vlaue of a attribute
+    Returns the value of an attribute
     otherwise the default value.
""" if not config: diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 34341378..7a4e8272 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -35,7 +35,7 @@ PATCH_OF_PATCHES: Set[Any] = set() -WRONG_REGISTRATIONS: Dict[str, str] = { +WRONG_REGISTRATIONS: Dict[str, Optional[str]] = { DynamicCache: "4.50", BaseModelOutput: None, UNet2DConditionOutput: None, @@ -125,7 +125,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: ): assert cls in registration_functions, ( f"{cls} has no registration functions mapped to it, " - f"available {sorted(registration_functions)}" + f"available options are {list(registration_functions)}" ) if verbose: print( @@ -146,7 +146,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: return done -def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[], bool], int]]: +def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[int], bool], int]]: """Returns the list of serialization functions.""" transformers_classes = { DynamicCache: lambda verbose=verbose: register_class_serialization( From 118f7d067a3d0a19084fcfcb670531be26edb0b7 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 11:26:55 +0200 Subject: [PATCH 6/9] type --- .../torch_export_patches/onnx_export_serialization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 7a4e8272..0c7c1f98 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -1,5 +1,5 @@ import pprint -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import packaging.version as pv import optree import torch @@ -133,7 +133,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: f"registered first" ) unregister_class_serialization(cls, verbose=verbose) - registration_functions[cls](verbose=verbose) + registration_functions[cls](verbose=verbose) # type: ignore[arg-type] if verbose: print(f"[_fix_registration] {cls.__name__} done.") # To avoid doing it multiple times. @@ -142,11 +142,11 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # classes with no registration at all. 
done = {} for k, v in registration_functions.items(): - done[k] = v(verbose=verbose) + done[k] = v(verbose=verbose) # type: ignore[arg-type] return done -def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[int], bool], int]]: +def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool]]: """Returns the list of serialization functions.""" transformers_classes = { DynamicCache: lambda verbose=verbose: register_class_serialization( From 58b8f640c59e4b3e5689c0a9a89ab6cd76345d25 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 11:33:10 +0200 Subject: [PATCH 7/9] mypy --- .../torch_export_patches/onnx_export_serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 0c7c1f98..6f0f14a2 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -133,7 +133,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: f"registered first" ) unregister_class_serialization(cls, verbose=verbose) - registration_functions[cls](verbose=verbose) # type: ignore[arg-type] + registration_functions[cls](verbose=verbose) # type: ignore[arg-type, call-arg] if verbose: print(f"[_fix_registration] {cls.__name__} done.") # To avoid doing it multiple times. @@ -142,7 +142,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # classes with no registration at all. done = {} for k, v in registration_functions.items(): - done[k] = v(verbose=verbose) # type: ignore[arg-type] + done[k] = v(verbose=verbose) # type: ignore[arg-type, call-arg] return done From b37ec29a63089f472f7a59e318fd026148c34bed Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 25 Jun 2025 12:21:09 +0200 Subject: [PATCH 8/9] refactor --- _doc/api/torch_export_patches/index.rst | 1 + .../onnx_export_serialization_impl.rst | 7 + .../test_patch_serialization.py | 7 + .../onnx_export_serialization.py | 348 ++---------------- .../onnx_export_serialization_impl.py | 314 ++++++++++++++++ 5 files changed, 363 insertions(+), 314 deletions(-) create mode 100644 _doc/api/torch_export_patches/onnx_export_serialization_impl.rst create mode 100644 onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index 4493ffec..23f136a5 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -8,6 +8,7 @@ onnx_diagnostic.torch_export_patches eval/index onnx_export_errors onnx_export_serialization + onnx_export_serialization_impl patches/index patch_expressions patch_inputs diff --git a/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst b/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst new file mode 100644 index 00000000..22a94d27 --- /dev/null +++ b/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl +=================================================================== + +.. 
automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl + :members: + :no-undoc-members: diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py index 4bdb1eca..fc7beaa4 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -224,6 +224,13 @@ def test_unet_2d_condition_output(self): bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput") + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual( + "UNet2DConditionOutput(sample:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + with torch_export_patches(): # internal function bo2 = torch_deepcopy([bo]) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 6f0f14a2..f71a0429 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -1,5 +1,5 @@ import pprint -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, Optional, Set import packaging.version as pv import optree import torch @@ -31,7 +31,6 @@ raise e from ..helpers import string_type -from ..helpers.cache_helper import make_static_cache PATCH_OF_PATCHES: Set[Any] = set() @@ -148,6 +147,27 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool]]: """Returns the list of serialization functions.""" + from .onnx_export_serialization_impl import ( + SUPPORTED_DATACLASSES, + _lower_name_with_, + __dict__ as all_functions, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + flatten_encoder_decoder_cache, + unflatten_encoder_decoder_cache, + flatten_with_keys_encoder_decoder_cache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + flatten_static_cache, + unflatten_static_cache, + flatten_with_keys_static_cache, + ) + transformers_classes = { DynamicCache: lambda verbose=verbose: register_class_serialization( DynamicCache, @@ -171,13 +191,6 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool flatten_with_keys_encoder_decoder_cache, verbose=verbose, ), - BaseModelOutput: lambda verbose=verbose: register_class_serialization( - BaseModelOutput, - flatten_base_model_output, - unflatten_base_model_output, - flatten_with_keys_base_model_output, - verbose=verbose, - ), SlidingWindowCache: lambda verbose=verbose: register_class_serialization( SlidingWindowCache, flatten_sliding_window_cache, @@ -193,17 +206,20 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool verbose=verbose, ), } - if UNet2DConditionOutput: - diffusers_classes = { - UNet2DConditionOutput: lambda verbose=verbose: register_class_serialization( - UNet2DConditionOutput, - flatten_unet_2d_condition_output, - unflatten_unet_2d_condition_output, - flatten_with_keys_unet_2d_condition_output, + for cls in SUPPORTED_DATACLASSES: + lname = _lower_name_with_(cls.__name__) + assert ( + f"flatten_{lname}" in all_functions + ), f"Unable to find function 
'flatten_{lname}' in {sorted(all_functions)}" + transformers_classes[cls] = ( + lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( + cls, + _al[f"flatten_{_ln}"], + _al[f"unflatten_{_ln}"], + _al[f"flatten_with_keys_{_ln}"], verbose=verbose, ) - } - transformers_classes.update(diffusers_classes) + ) return transformers_classes @@ -241,299 +257,3 @@ def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): for cls in cls_ensemble: if undo.get(cls.__name__, False): unregister_class_serialization(cls, verbose) - - -############ -# MambaCache -############ - - -def flatten_mamba_cache( - mamba_cache: MambaCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - flat = [ - ("conv_states", mamba_cache.conv_states), - ("ssm_states", mamba_cache.ssm_states), - ] - return [f[1] for f in flat], [f[0] for f in flat] - - -def unflatten_mamba_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> MambaCache: - """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" - conv_states, ssm_states = values - - class _config: - def __init__(self): - if isinstance(conv_states, list): - self.intermediate_size = conv_states[0].shape[1] - self.state_size = ssm_states[0].shape[2] - self.conv_kernel = conv_states[0].shape[2] - self.num_hidden_layers = len(conv_states) - else: - self.intermediate_size = conv_states.shape[2] - self.state_size = ssm_states.shape[3] - self.conv_kernel = conv_states.shape[3] - self.num_hidden_layers = conv_states.shape[0] - - cache = MambaCache( - _config(), - max_batch_size=1, - dtype=values[-1][0].dtype, - device="cpu" if values[-1][0].get_device() < 0 else "cuda", - ) - values = dict(zip(context, values)) - for k, v in values.items(): - setattr(cache, k, v) - return cache - - -def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], - torch.utils._pytree.Context, -]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - values, context = flatten_mamba_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -############## -# DynamicCache -############## - - -def flatten_dynamic_cache( - dynamic_cache: DynamicCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): - return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) - flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_dynamic_cache( - dynamic_cache: DynamicCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"): - return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache) - values, context = flatten_dynamic_cache(dynamic_cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_dynamic_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> DynamicCache: - """Restores a 
:class:`transformers.cache_utils.DynamicCache` from python objects.""" - if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"): - assert output_type is None, f"output_type={output_type} not supported" - return transformers.cache_utils._unflatten_dynamic_cache(values, context) - - cache = transformers.cache_utils.DynamicCache() - values = dict(zip(context, values)) - for k, v in values.items(): - setattr(cache, k, v) - return cache - - -############# -# StaticCache -############# - - -def flatten_static_cache( - cache: StaticCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" - flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_static_cache( - cache: StaticCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" - values, context = flatten_static_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_static_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> StaticCache: - """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" - return make_static_cache(list(zip(values[0], values[1]))) - - -#################### -# SlidingWindowCache -#################### - - -def flatten_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. - """ - flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. - """ - values, context = flatten_sliding_window_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_sliding_window_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> SlidingWindowCache: - """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects.""" - key_cache, value_cache = values - - class _config: - def __init__(self): - self.head_dim = key_cache[0].shape[-1] - self.num_attention_heads = key_cache[0].shape[1] - self.num_hidden_layers = len(key_cache) - self.sliding_window = key_cache[0].shape[2] - - cache = SlidingWindowCache( - _config(), - max_batch_size=key_cache[0].shape[0], - max_cache_len=key_cache[0].shape[2], # sligding window - device=key_cache[0].device, - dtype=key_cache[0].dtype, - ) - - values = dict(zip(context, values)) - for k, v in values.items(): - setattr(cache, k, v) - return cache - - -##################### -# EncoderDecoderCache -##################### - - -def flatten_encoder_decoder_cache( - ec_cache: EncoderDecoderCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` - with python objects. 
- """ - dictionary = { - "self_attention_cache": ec_cache.self_attention_cache, - "cross_attention_cache": ec_cache.cross_attention_cache, - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], - torch.utils._pytree.Context, -]: - """ - Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` - with python objects. - """ - dictionary = { - "self_attention_cache": ec_cache.self_attention_cache, - "cross_attention_cache": ec_cache.cross_attention_cache, - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def unflatten_encoder_decoder_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> EncoderDecoderCache: - """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects.""" - dictionary = torch.utils._pytree._dict_unflatten(values, context) - return EncoderDecoderCache(**dictionary) - - -################# -# BaseModelOutput -################# - - -def flatten_base_model_output( - bo: BaseModelOutput, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` - with python objects. - """ - return list(bo.values()), list(bo.keys()) - - -def flatten_with_keys_base_model_output( - bo: BaseModelOutput, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` - with python objects. - """ - values, context = flatten_base_model_output(bo) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_base_model_output( - values: List[Any], - context: torch.utils._pytree.Context, - output_type=None, -) -> BaseModelOutput: - """ - Restores a :class:`transformers.modeling_outputs.BaseModelOutput` - from python objects. - """ - return BaseModelOutput(**dict(zip(context, values))) - - -####################### -# UNet2DConditionOutput -####################### - - -def flatten_unet_2d_condition_output( - obj: UNet2DConditionOutput, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` - with python objects. - """ - return list(obj.values()), list(obj.keys()) - - -def flatten_with_keys_unet_2d_condition_output( - obj: UNet2DConditionOutput, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` - with python objects. - """ - values, context = flatten_unet_2d_condition_output(obj) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_unet_2d_condition_output( - values: List[Any], - context: torch.utils._pytree.Context, - output_type=None, -) -> UNet2DConditionOutput: - """ - Restores a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput` - from python objects. 
- """ - return UNet2DConditionOutput(**dict(zip(context, values))) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py new file mode 100644 index 00000000..4f7de62a --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py @@ -0,0 +1,314 @@ +import re +from typing import Any, Callable, List, Tuple +import torch +import transformers +from transformers.cache_utils import ( + DynamicCache, + MambaCache, + EncoderDecoderCache, + SlidingWindowCache, + StaticCache, +) +from transformers.modeling_outputs import BaseModelOutput + +try: + from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput + from diffusers.models.unets.unet_1d import UNet1DOutput + from diffusers.models.unets.unet_2d import UNet2DOutput + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput +except ImportError as e: + try: + import diffusers + except ImportError: + diffusers = None + DecoderOutput, EncoderOutput = None, None + UNet1DOutput, UNet2DOutput = None, None + UNet2DConditionOutput, UNet3DConditionOutput = None, None + if diffusers: + raise e + +from ..helpers.cache_helper import make_static_cache + + +SUPPORTED_DATACLASSES = set() + +############ +# MambaCache +############ + + +def flatten_mamba_cache( + mamba_cache: MambaCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + flat = [ + ("conv_states", mamba_cache.conv_states), + ("ssm_states", mamba_cache.ssm_states), + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + values, context = flatten_mamba_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +############## +# DynamicCache +############## + + +def flatten_dynamic_cache( + dynamic_cache: DynamicCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): + return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) + flat = [("key_cache", dynamic_cache.key_cache), 
("value_cache", dynamic_cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_dynamic_cache( + dynamic_cache: DynamicCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"): + return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache) + values, context = flatten_dynamic_cache(dynamic_cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_dynamic_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> DynamicCache: + """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" + if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"): + assert output_type is None, f"output_type={output_type} not supported" + return transformers.cache_utils._unflatten_dynamic_cache(values, context) + + cache = transformers.cache_utils.DynamicCache() + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +############# +# StaticCache +############# + + +def flatten_static_cache( + cache: StaticCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_static_cache( + cache: StaticCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + values, context = flatten_static_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_static_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> StaticCache: + """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" + return make_static_cache(list(zip(values[0], values[1]))) + + +#################### +# SlidingWindowCache +#################### + + +def flatten_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. 
+    """
+    values, context = flatten_sliding_window_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_sliding_window_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> SlidingWindowCache:
+    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
+    key_cache, value_cache = values
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_cache[0].shape[-1]
+            self.num_attention_heads = key_cache[0].shape[1]
+            self.num_hidden_layers = len(key_cache)
+            self.sliding_window = key_cache[0].shape[2]
+
+    cache = SlidingWindowCache(
+        _config(),
+        max_batch_size=key_cache[0].shape[0],
+        max_cache_len=key_cache[0].shape[2],  # the sliding window length
+        device=key_cache[0].device,
+        dtype=key_cache[0].dtype,
+    )
+
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#####################
+# EncoderDecoderCache
+#####################
+
+
+def flatten_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def unflatten_encoder_decoder_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
+    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    return EncoderDecoderCache(**dictionary)
+
+
+#############
+# dataclasses
+#############
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(cls) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization functions for a class decorated with
+    ``dataclasses.dataclass`` that also exposes ``keys()`` and ``values()``,
+    e.g. :class:`transformers.utils.ModelOutput` subclasses.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects, keeping the mapping keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    SUPPORTED_DATACLASSES.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
+
+
+(
+    flatten_base_model_output,
+    flatten_with_keys_base_model_output,
+    unflatten_base_model_output,
+) = make_serialization_function_for_dataclass(BaseModelOutput)
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput)

From 2821a35c858d4ca2e0be460904173e12db7b9163 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 25 Jun 2025 12:27:25 +0200
Subject: [PATCH 9/9] refactor

---
 .../onnx_export_serialization.py      | 29 ++++-------------------------
 .../onnx_export_serialization_impl.py | 15 ++++++++++++++-
 2 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
index f71a0429..4f216367 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -11,34 +11,11 @@
     SlidingWindowCache,
     StaticCache,
 )
-from transformers.modeling_outputs import BaseModelOutput
-
-try:
-    from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput
-    from diffusers.models.unets.unet_1d import UNet1DOutput
-    from diffusers.models.unets.unet_2d import UNet2DOutput
-    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
-    from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput
-except ImportError as e:
-    try:
-        import diffusers
-    except ImportError:
-        diffusers = None
-    DecoderOutput, EncoderOutput = None, None
-    UNet1DOutput, UNet2DOutput = None, None
-    UNet2DConditionOutput, UNet3DConditionOutput = None, None
-    if diffusers:
-        raise e
 
 from ..helpers import string_type
 
 PATCH_OF_PATCHES: Set[Any] = set()
-WRONG_REGISTRATIONS: Dict[str, Optional[str]] = {
-    DynamicCache: "4.50",
-    BaseModelOutput: None,
-    UNet2DConditionOutput: None,
-}
 
 
 def register_class_serialization(
@@ -101,6 +78,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     Registers many classes with :func:`register_class_serialization`.
     Returns information needed to undo the registration.
""" + from .onnx_export_serialization_impl import WRONG_REGISTRATIONS + registration_functions = serialization_functions(verbose=verbose) # DynamicCache serialization is different in transformers and does not @@ -212,7 +191,7 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool f"flatten_{lname}" in all_functions ), f"Unable to find function 'flatten_{lname}' in {sorted(all_functions)}" transformers_classes[cls] = ( - lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( + lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501 cls, _al[f"flatten_{_ln}"], _al[f"unflatten_{_ln}"], @@ -253,7 +232,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): """Undo all registrations.""" - cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo) + cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo) for cls in cls_ensemble: if undo.get(cls.__name__, False): unregister_class_serialization(cls, verbose) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py index 4f7de62a..32a7415c 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py @@ -1,5 +1,5 @@ import re -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import transformers from transformers.cache_utils import ( @@ -31,7 +31,20 @@ from ..helpers.cache_helper import make_static_cache +def _make_wrong_registrations() -> Dict[str, Optional[str]]: + res = { + DynamicCache: "4.50", + BaseModelOutput: None, + } + for c in [UNet2DConditionOutput]: + if c is not None: + res[c] = None + return res + + SUPPORTED_DATACLASSES = set() +WRONG_REGISTRATIONS = _make_wrong_registrations() + ############ # MambaCache