Skip to content

Commit b79803f

Browse files
Allow remote code repo names to contain "." (#11652)
* allow loading from repo with dot in name * put new arg at the end to avoid breaking compatibility * add test for loading repo with dot in name --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent b0f7036 commit b79803f

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,30 @@ def check_imports(filename):
154154
return get_relative_imports(filename)
155155

156156

157-
def get_class_in_module(class_name, module_path):
157+
def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
158158
"""
159159
Import a module on the cache directory for modules and extract a class from it.
160160
"""
161161
module_path = module_path.replace(os.path.sep, ".")
162-
module = importlib.import_module(module_path)
162+
try:
163+
module = importlib.import_module(module_path)
164+
except ModuleNotFoundError as e:
165+
# This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
166+
# separator. We do a bit of monkey patching to detect and fix this case.
167+
if not (
168+
pretrained_model_name_or_path is not None
169+
and "." in pretrained_model_name_or_path
170+
and module_path.startswith("diffusers_modules")
171+
and pretrained_model_name_or_path.replace("/", "--") in module_path
172+
):
173+
raise e # We can't figure this one out, just reraise the original error
174+
175+
corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
176+
corrected_path = corrected_path.replace(
177+
pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
178+
pretrained_model_name_or_path.replace("/", "--"),
179+
)
180+
module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
163181

164182
if class_name is None:
165183
return find_pipeline_class(module)
@@ -454,4 +472,4 @@ def get_class_from_dynamic_module(
454472
revision=revision,
455473
local_files_only=local_files_only,
456474
)
457-
return get_class_in_module(class_name, final_module.replace(".py", ""))
475+
return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)

tests/pipelines/test_pipelines.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,21 @@ def test_remote_auto_custom_pipe(self):
11051105

11061106
assert images.shape == (1, 64, 64, 3)
11071107

1108+
def test_remote_custom_pipe_with_dot_in_name(self):
1109+
# make sure that trust remote code has to be passed
1110+
with self.assertRaises(ValueError):
1111+
pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name")
1112+
1113+
pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True)
1114+
1115+
assert pipeline.__class__.__name__ == "CustomPipeline"
1116+
1117+
pipeline = pipeline.to(torch_device)
1118+
images, output_str = pipeline(num_inference_steps=2, output_type="np")
1119+
1120+
assert images[0].shape == (1, 32, 32, 3)
1121+
assert output_str == "This is a test"
1122+
11081123
def test_local_custom_pipeline_repo(self):
11091124
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
11101125
pipeline = DiffusionPipeline.from_pretrained(

0 commit comments

Comments
 (0)