@@ -27,7 +27,6 @@

import numpy as np
import torch
-import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
+    _collate_lora_metadata,
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

    parser.add_argument(
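Note (not part of the diff): in PEFT-style LoRA the low-rank update is scaled by `lora_alpha / r`, so exposing `--lora_alpha` separately from `--rank` lets users tune the adapter's strength without changing its capacity. The default of `4` matches the default `--rank`, so the old behavior (`lora_alpha = args.rank`) is preserved unless the flag is set explicitly. A minimal sketch of that relationship, with a hypothetical helper name:

```python
# Illustrative sketch only (not from the patch): how lora_alpha interacts with rank.
# A PEFT LoRA layer applies W + (lora_alpha / r) * (B @ A) on top of the frozen weight W.
def effective_lora_scale(rank: int, lora_alpha: int) -> float:
    """Return the alpha/r factor that scales the low-rank update."""
    return lora_alpha / rank


# Script defaults (--rank 4, --lora_alpha 4) keep the scale at 1.0, matching the
# previous hard-coded lora_alpha=args.rank behavior.
assert effective_lora_scale(4, 4) == 1.0
# Raising alpha at a fixed rank strengthens the adapter's contribution.
assert effective_lora_scale(4, 8) == 2.0
```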
@@ -1238,7 +1243,7 @@ def main(args):
    # now we will add new LoRA weights to the transformer layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
            )

    def load_model_hook(models, input_dir):
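Note (not part of the diff): `_collate_lora_metadata` turns the collected `modules_to_save` dict into extra keyword arguments for `save_lora_weights`, so the adapter configuration (rank, alpha, dropout, target modules) is stored alongside the weights. The sketch below is a hypothetical approximation of that idea, not the helper shipped in `diffusers.training_utils`:

```python
# Hypothetical approximation only: NOT the diffusers implementation.
# Idea: read each PEFT-wrapped module's adapter config and expose it as
# per-component metadata kwargs (e.g. a "transformer_..._metadata" entry).
from typing import Any, Dict

import torch


def collate_lora_metadata_sketch(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
    metadata = {}
    for name, module in modules_to_save.items():
        if module is None:
            continue
        # Modules wrapped by peft/diffusers expose their adapter configs via `peft_config`.
        peft_config = getattr(module, "peft_config", {}).get("default")
        if peft_config is not None:
            metadata[f"{name}_lora_adapter_metadata"] = peft_config.to_dict()
    return metadata
```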
@@ -1889,23 +1897,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
+        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
        else:
            text_encoder_lora_layers = None

        FluxPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
        )

    # Final inference
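Note (not part of the diff): once training finishes, the LoRA weights (now with the extra metadata) live in `--output_dir`. A short inference sketch for loading them back; the base model id is an assumption, use whatever was passed as `--pretrained_model_name_or_path`:

```python
# Sketch: load the trained LoRA back into a Flux pipeline for inference.
# "black-forest-labs/FLUX.1-dev" and the prompt are placeholders/assumptions.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # the directory used as --output_dir

image = pipe(
    "a photo of sks dog in a bucket",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_lora_sample.png")
```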