import argparse
+ from dataclasses import dataclass
import itertools
import math
import os
import random
from pathlib import Path
from typing import Optional, Dict, NamedTuple, List
from argparse import Namespace
+ from random import sample, random

import numpy as np
import torch
+ from torch import Tensor  # needed for the Tensor annotations introduced below
from torch.utils.data import Dataset

import PIL
+ from PIL.Image import Image as Img
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from packaging import version
from PIL import Image
from torchvision import transforms
+ from torchvision.transforms.functional import hflip
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer

@@ -120,6 +124,9 @@ def parse_args():
    parser.add_argument(
        "--initialize_rest_random", action="store_true", help="Initialize rest of the placeholder tokens with random."
    )
+    parser.add_argument(
+        "--cache_images", action="store_true", help="Cache tensors of every image we load. You should only do this if your training set is small."
+    )
    parser.add_argument(
        "--save_steps",
        type=int,
@@ -341,20 +348,26 @@ def parse_args():
    "a large painting in the style of {}",
]
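+ # per-image cache entry: the preprocessed pixel tensor and its horizontally flipped variant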
+ @dataclass
+ class Variations:
+     original: Tensor
+     flipped: Tensor

class TextualInversionDataset(Dataset):
+     cache: Dict[str, Variations]
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
-        interpolation="bicubic",
+        interpolation="lanczos",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
+        cache_enabled=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
@@ -364,7 +377,10 @@ def __init__(
        self.center_crop = center_crop
        self.flip_p = flip_p

-        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
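+        # only pick up .png / .jpg files so stray non-image files in the folder are skipped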
+        self.image_paths = [
+            os.path.join(self.data_root, file_path)
+            for file_path in os.listdir(self.data_root) if file_path.endswith('.png') or file_path.endswith('.jpg')
+        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images
@@ -380,20 +396,74 @@ def __init__(
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
-        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+        # we have so few images and so much VRAM that we should prefer to retain tensors rather than redo work
+        self.cache = {}
+        self.cache_enabled = cache_enabled

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
-        image = Image.open(self.image_paths[i % self.num_images])
-
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        image_path: str = self.image_paths[i % self.num_images]
+        stem: str = Path(image_path).stem

        placeholder_string = self.placeholder_token
-        text = random.choice(self.templates).format(placeholder_string)
+        # text = random.choice(self.templates).format(placeholder_string)
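+        # build the caption from hand-written tag lists keyed by filename, rather than the stock templates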
+        def describe_placeholder() -> str:
+            if random() < 0.3:
+                return self.placeholder_token
+            return placeholder_string
+
+        def describe_subject(character: str) -> str:
+            placeholder: str = describe_placeholder()
+            if random() < 0.3:
+                return f"photo of {placeholder}"
+            return f"photo of {character} {placeholder}"
+
+        def make_prompt(character: str, general_labels: List[str], sitting=True, on_floor=True) -> str:
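+            # caption shape: "photo of <character> <placeholder> with tag1, tag2, ..."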
+            even_more_labels = [*general_labels, '1girl']
+            if sitting:
+                even_more_labels.append('sitting')
+            if on_floor:
+                even_more_labels.append('on floor')
+            subject: str = describe_subject(character)
+            # we can use this for dropout but I think dropout is undesirable
+            # label_count = randrange(0, len(even_more_labels))
+            label_count = len(even_more_labels)
+            if label_count == 0:
+                return subject
+            labels = sample(even_more_labels, label_count)
+            joined = ', '.join(labels)
+            return f"{subject} with {joined}"
+
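+        # pick the tag set from the image's filename stem; unknown stems fall back to a plain caption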
+        match stem:
+            case 'koishi':
+                text = make_prompt('komeiji koishi', ['green hair', 'black footwear', 'medium hair', 'blue eyes', 'yellow jacket', 'green skirt', 'hat', 'black headwear', 'smile', 'touhou project'])
+            case 'flandre':
+                text = make_prompt('flandre scarlet', ['fang', 'red footwear', 'slit pupils', 'medium hair', 'blonde hair', 'red eyes', 'red dress', 'mob cap', 'smile', 'short sleeves', 'yellow ascot', 'touhou project'])
+            case 'sanae':
+                text = make_prompt('kochiya sanae', ['green hair', 'blue footwear', 'long hair', 'green eyes', 'white dress', 'blue skirt', 'frog hair ornament', 'snake hair ornament', 'smile', 'standing', 'touhou project'])
+            case 'sanaestand':
+                text = make_prompt('kochiya sanae', ['green hair', 'blue footwear', 'long hair', 'green eyes', 'white dress', 'blue skirt', 'frog hair ornament', 'snake hair ornament', 'smile', 'touhou project'], sitting=False)
+            case 'tenshi':
+                text = make_prompt('hinanawi tenshi', ['blue hair', 'brown footwear', 'slit pupils', 'very long hair', 'red eyes', 'white dress', 'blue skirt', 'hat', 'black headwear', 'smile', 'touhou project'])
+            case 'youmu':
+                text = make_prompt('konpaku youmu', ['silver hair', 'black footwear', 'medium hair', 'slit pupils', 'green eyes', 'green dress', 'sleeveless dress', 'white sleeves', 'black ribbon', 'hair ribbon', 'unhappy', 'touhou project'])
+            case 'yuyuko':
+                text = make_prompt('saigyouji yuyuko', ['pink hair', 'black footwear', 'medium hair', 'pink eyes', 'wide sleeves', 'long sleeves', 'blue dress', 'mob cap', 'touhou project'])
+            case 'nagisa':
+                text = make_prompt('furukawa nagisa', ['brown hair', 'brown footwear', 'medium hair', 'brown eyes', 'smile', 'school briefcase', 'blue skirt', 'yellow jacket', 'antenna hair', 'dango', 'clannad'])
+            case 'teto':
+                text = make_prompt('kasane teto', ['pink hair', 'red footwear', 'red eyes', 'medium hair', 'detached sleeves', 'twin drills', 'drill hair', 'grey dress', 'smile', 'vocaloid'])
+            case 'korone':
+                text = make_prompt('inugami korone', ['yellow jacket', 'blue footwear', 'long hair', 'white dress', 'brown hair', 'brown eyes', 'on chair', 'hairclip', 'uwu', 'hololive'], on_floor=False)
+            case 'kudo':
+                text = make_prompt('kudryavka noumi', ['fang', 'black footwear', 'very long hair', 'white hat', 'white cape', 'silver hair', 'grey skirt', 'blue eyes', 'smile', 'little busters!'])
+            case 'patchouli':
+                text = make_prompt('patchouli knowledge', ['mob cap', 'pink footwear', 'long hair', 'slit pupils', 'striped dress', 'pink dress', 'purple hair', 'ribbons in hair', 'unhappy', 'touhou project'])
+            case _:
+                text = f"photo of {placeholder_string}"

        example["input_ids"] = self.tokenizer(
            text,
@@ -403,25 +473,44 @@ def __getitem__(self, i):
            return_tensors="pt",
        ).input_ids[0]

-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = (
-                img.shape[0],
-                img.shape[1],
+        if stem not in self.cache:
+            image = Image.open(image_path)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+
+            # default to score-sde preprocessing
+            img = np.array(image).astype(np.uint8)
+
+            if self.center_crop:
+                crop = min(img.shape[0], img.shape[1])
+                h, w, = (
+                    img.shape[0],
+                    img.shape[1],
+                )
+                img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+            image: Img = Image.fromarray(img)
+            image: Img = image.resize((self.size, self.size), resample=self.interpolation)
+
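+            # pre-compute the flipped copy once, so the flip augmentation costs nothing at sample time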
+            flipped: Img = hflip(image)
+
+            def pil_to_latents(image: Img) -> Tensor:
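+                # scale uint8 pixels from [0, 255] to [-1, 1] and reorder to channels-first (C, H, W)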
+                image = np.array(image).astype(np.uint8)
+                image = (image / 127.5 - 1.0).astype(np.float32)
+                latents: Tensor = torch.from_numpy(image).permute(2, 0, 1)
+                return latents
+
+            image, flipped = (pil_to_latents(variation) for variation in (image, flipped))
+
+            variations = Variations(
+                original=image,
+                flipped=flipped,
            )
+            # only retain the tensors when --cache_images was passed; otherwise recompute per sample
+            if self.cache_enabled:
+                self.cache[stem] = variations
-            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size), resample=self.interpolation)
-
-        image = self.flip_transform(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
+        else:
+            variations = self.cache[stem]
+        flip = torch.rand(1) < self.flip_p
+        image = variations.flipped if flip else variations.original

-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+        example["pixel_values"] = image
        return example