Support for text-to-image #165

Merged
merged 9 commits on Jun 25, 2025
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.2
+++++

* :pr:`165`: support for task text-to-image
* :pr:`162`: improves graphs rendering for historical data

0.7.1
1 change: 1 addition & 0 deletions _doc/api/tasks/index.rst
@@ -46,6 +46,7 @@ Or:
summarization
text_classification
text_generation
text_to_image
text2text_generation
zero_shot_image_classification

7 changes: 7 additions & 0 deletions _doc/api/tasks/text_to_image.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.tasks.text_to_image
===================================

.. automodule:: onnx_diagnostic.tasks.text_to_image
:members:
:no-undoc-members:
1 change: 1 addition & 0 deletions _doc/api/torch_export_patches/index.rst
@@ -8,6 +8,7 @@ onnx_diagnostic.torch_export_patches
eval/index
onnx_export_errors
onnx_export_serialization
onnx_export_serialization_impl
patches/index
patch_expressions
patch_inputs
7 changes: 7 additions & 0 deletions _doc/api/torch_export_patches/onnx_export_serialization_impl.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
===================================================================

.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
:members:
:no-undoc-members:
5 changes: 5 additions & 0 deletions _doc/conf.py
@@ -90,6 +90,8 @@ def linkcode_resolve(domain, info):
"https://sdpython.github.io/doc/experimental-experiment/dev/",
None,
),
# Not a sphinx documentation
# "diffusers": ("https://huggingface.co/docs/diffusers/index", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable", None),
"onnx": ("https://onnx.ai/onnx/", None),
@@ -104,6 +106,8 @@ def linkcode_resolve(domain, info):
"sklearn": ("https://scikit-learn.org/stable/", None),
"skl2onnx": ("https://onnx.ai/sklearn-onnx/", None),
"torch": ("https://pytorch.org/docs/main/", None),
# Not a sphinx documentation
# "transformers": ("https://huggingface.co/docs/transformers/index", None),
}

# Check intersphinx reference targets exist
@@ -116,6 +120,7 @@ def linkcode_resolve(domain, info):
("py:class", "True"),
("py:class", "Argument"),
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"),
("py:class", "ModelProto"),
("py:class", "Model"),
("py:class", "Module"),
2 changes: 1 addition & 1 deletion _doc/patches.rst
@@ -113,7 +113,7 @@ Here is the list of supported caches:

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print("\n".join(sorted(p.serialization_functions())))
print("\n".join(sorted(t.__name__ for t in p.serialization_functions())))

.. _l-control-flow-rewriting:

2 changes: 1 addition & 1 deletion _doc/status/patches_coverage.rst
@@ -14,7 +14,7 @@ The following code shows the list of serialized classes in transformers.

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print('\n'.join(sorted(p.serialization_functions())))
print('\n'.join(sorted(t.__name__ for t in p.serialization_functions())))

Patched Classes
===============
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks_image_classification.py
@@ -6,7 +6,7 @@
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasks(ExtTestCase):
class TestTasksImageClassification(ExtTestCase):
@hide_stdout()
def test_image_classification(self):
mid = "hf-internal-testing/tiny-random-BeitForImageClassification"
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -11,7 +11,7 @@
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasks(ExtTestCase):
class TestTasksImageTextToText(ExtTestCase):
@hide_stdout()
@requires_transformers("4.52")
@requires_torch("2.7.99")
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks_object_detection.py
@@ -6,7 +6,7 @@
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasks(ExtTestCase):
class TestTasksObjectDetection(ExtTestCase):
@hide_stdout()
def test_object_detection(self):
mid = "hustvl/yolos-tiny"
35 changes: 35 additions & 0 deletions _unittests/ut_tasks/test_tasks_text_to_image.py
@@ -0,0 +1,35 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
hide_stdout,
requires_transformers,
requires_torch,
)
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksTextToImage(ExtTestCase):
@hide_stdout()
@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_text_to_image(self):
mid = "diffusers/tiny-torch-full-checker"
data = get_untrained_model_with_inputs(
mid, verbose=1, add_second_input=True, subfolder="unet"
)
self.assertEqual(data["task"], "text-to-image")
self.assertIn((data["size"], data["n_weights"]), [(5708048, 1427012)])
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
@@ -6,7 +6,7 @@
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasks(ExtTestCase):
class TestTasksZeroShotImageClassification(ExtTestCase):
@requires_torch("2.7.99")
@hide_stdout()
def test_zero_shot_image_classification(self):
21 changes: 21 additions & 0 deletions _unittests/ut_tasks/try_tasks.py
@@ -569,6 +569,27 @@ def test_object_detection(self):
f"{round(score.item(), 3)} at location {box}"
)

@never_test()
def test_text_to_image(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k test_text_to_image
import torch
from diffusers import StableDiffusionPipeline

model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
"cuda"
)

prompt = "a photo of an astronaut riding a horse on mars and on jupyter"
print()
with steal_forward(pipe.unet, with_min_max=True):
image = pipe(prompt).images[0]
print("-- output", self.string_type(image, with_shape=True, with_min_max=True))
# stolen forward for class UNet2DConditionModel -- iteration 44
# sample=T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
# time_step=T7s=101
# encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]


if __name__ == "__main__":
unittest.main(verbosity=2)
64 changes: 63 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -1,7 +1,12 @@
import unittest
import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
ignore_warnings,
requires_torch,
requires_diffusers,
)
from onnx_diagnostic.helpers.cache_helper import (
make_encoder_decoder_cache,
make_dynamic_cache,
@@ -212,6 +217,63 @@ def test_sliding_window_cache_flatten(self):
self.string_type(cache2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
@requires_diffusers("0.30")
def test_unet_2d_condition_output(self):
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

with torch_export_patches():
# internal function
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
self.assertEqualAny([bo], bo2)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

# serialization
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#1[T1s4x4x4]",
self.string_type(flat, with_shape=True),
)
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(bo, with_shape=True, with_min_max=True),
self.string_type(bo2, with_shape=True, with_min_max=True),
)

# flatten_unflatten
flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, dict)
self.assertEqual(list(unflat), ["sample"])

# export
class Model(torch.nn.Module):
def forward(self, cache):
return cache.sample[0]

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
model = Model()
model(bo)
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]

with torch_export_patches():
torch.export.export(model, (bo,), dynamic_shapes=(ds,))


if __name__ == "__main__":
unittest.main(verbosity=2)
17 changes: 16 additions & 1 deletion onnx_diagnostic/helpers/config_helper.py
@@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
else:
update_config(getattr(config, k), v)
continue
setattr(config, k, v)
if type(config) is dict:
config[k] = v
else:
setattr(config, k, v)


def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
@@ -66,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
raise AssertionError(f"Unable to find any of these {atts!r} in {config}")


def pick(config, name: str, default_value: Any) -> Any:
"""
Returns the value of an attribute if config has it,
otherwise the default value.
"""
if not config:
return default_value
if type(config) is dict:
return config.get(name, default_value)
return getattr(config, name, default_value)


@functools.cache
def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
"""
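For context, a minimal usage sketch of the new pick helper added in this hunk; it only assumes onnx_diagnostic is importable, and the configuration keys below ("sample_size", "in_channels") are invented for illustration.

from onnx_diagnostic.helpers.config_helper import pick

# A model configuration may be a plain dict (e.g. a diffusers sub-config)
# or an object exposing attributes (e.g. a transformers PretrainedConfig).
dict_config = {"sample_size": 96, "in_channels": 4}


class ObjConfig:
    sample_size = 96


# pick reads from either kind of config and falls back to a default value.
assert pick(dict_config, "sample_size", 64) == 96
assert pick(dict_config, "missing_key", 64) == 64
assert pick(ObjConfig(), "sample_size", 64) == 96
assert pick(None, "sample_size", 64) == 64  # falsy config returns the default

The dict branch matters here presumably because some diffusers configurations are plain dictionaries rather than attribute-based config objects, which is also why update_config above now assigns into dicts.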
2 changes: 2 additions & 0 deletions onnx_diagnostic/tasks/__init__.py
@@ -11,6 +11,7 @@
summarization,
text_classification,
text_generation,
text_to_image,
text2text_generation,
zero_shot_image_classification,
)
@@ -27,6 +28,7 @@
summarization,
text_classification,
text_generation,
text_to_image,
text2text_generation,
zero_shot_image_classification,
]