
Commit f0dba33

[training] show how metadata stuff should be incorporated in training scripts. (#11707)
* show how metadata stuff should be incorporated in training scripts.
* typing
* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent d1db4f8 commit f0dba33

File tree: 3 files changed (+70, -5 lines)

examples/dreambooth/test_dreambooth_lora_flux.py
examples/dreambooth/train_dreambooth_lora_flux.py
src/diffusers/training_utils.py

examples/dreambooth/test_dreambooth_lora_flux.py

Lines changed: 45 additions & 0 deletions
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os
 import sys
 import tempfile
 
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -234,3 +237,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult
             run_command(self._launch_args + resume_run_args)
 
             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
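
To inspect the serialized metadata outside the test harness, here is a minimal standalone sketch. It assumes a pytorch_lora_weights.safetensors file produced by the training script; read_lora_adapter_metadata is a hypothetical helper written for this illustration, not a diffusers API. The flattened keys it returns (e.g. "transformer.r", "transformer.lora_alpha") are the same ones the test above asserts on.

# Minimal sketch of the check the new test performs, usable outside the test harness.
import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


def read_lora_adapter_metadata(path: str) -> dict:
    """Return the LoRA adapter metadata stored in a safetensors header, or {}."""
    with safe_open(path, framework="pt", device="cpu") as f:
        header_metadata = f.metadata() or {}
    raw = header_metadata.get(LORA_ADAPTER_METADATA_KEY)
    # The adapter configs are serialized as a single JSON string under this key.
    return json.loads(raw) if raw else {}


if __name__ == "__main__":
    # Path is illustrative; point it at the output_dir used for training.
    info = read_lora_adapter_metadata("output/pytorch_lora_weights.safetensors")
    print(info.get("transformer.r"), info.get("transformer.lora_alpha"))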

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 17 additions & 5 deletions
@@ -27,7 +27,6 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
     parser.add_argument(
@@ -1238,7 +1243,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1889,23 +1897,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None
 
         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
         # Final inference
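
The new --lora_alpha flag matters because the LoRA update is scaled by lora_alpha / r, so decoupling alpha from rank changes the effective update strength without changing the number of trainable parameters. Below is a self-contained sketch of the flag wiring and the resulting scale; the two-flag parser is illustrative and only mirrors the parse_args() additions above, not the full script.

# Illustrative two-flag parser mirroring the parse_args() additions above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=4, help="The dimension of the LoRA update matrices.")
parser.add_argument(
    "--lora_alpha",
    type=int,
    default=4,
    help="LoRA alpha to be used for additional scaling.",
)

# Same values the new test uses: alpha deliberately different from rank.
args = parser.parse_args(["--rank=4", "--lora_alpha=8"])
print(args.lora_alpha / args.rank)  # 2.0 -> LoRA updates are scaled by alpha / r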

src/diffusers/training_utils.py

Lines changed: 8 additions & 0 deletions
@@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
 
 
+def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
+    metadatas = {}
+    for module_name, module in modules_to_save.items():
+        if module is not None:
+            metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+    return metadatas
+
+
 def compute_density_for_timestep_sampling(
     weighting_scheme: str,
     batch_size: int,
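
A small usage sketch of the new helper, assuming peft is installed. The toy module below is a stand-in for the PEFT-wrapped transformer or text encoder that the training script collects into modules_to_save; the returned "*_lora_adapter_metadata" keys are what gets splatted into FluxPipeline.save_lora_weights in the hooks above.

# Toy demonstration of _collate_lora_metadata; the model is a stand-in, not Flux.
import torch
from peft import LoraConfig, get_peft_model

from diffusers.training_utils import _collate_lora_metadata


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


# Wrap the toy module with a LoRA adapter whose alpha differs from its rank,
# mirroring the decoupled --rank / --lora_alpha flags added in this commit.
model = get_peft_model(ToyTransformer(), LoraConfig(r=4, lora_alpha=8, target_modules=["proj"]))

metadata = _collate_lora_metadata({"transformer": model, "text_encoder": None})
# Entries whose module is None are skipped; each kept entry is the adapter config as a dict.
config = metadata["transformer_lora_adapter_metadata"]
print(config["r"], config["lora_alpha"])  # 4 8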
