
Commit cae3ca6

add image caching, fumo prompts
1 parent e738172 commit cae3ca6

File tree: 2 files changed (+116, -26 lines)

.vscode/launch.json

Lines changed: 2 additions & 1 deletion
@@ -46,8 +46,9 @@
       "--initializer_token",
       "plush",
       "--train_data_dir",
-      "--only_save_embeds"
       "/Users/birch/plush/512_ti",
+      "--only_save_embeds",
+      "--cache_images"
     ],
     "console": "integratedTerminal",
     "justMyCode": true,

scripts/ti_train.py

Lines changed: 114 additions & 25 deletions
@@ -1,11 +1,13 @@
 import argparse
+from dataclasses import dataclass
 import itertools
 import math
 import os
 import random
 from pathlib import Path
 from typing import Optional, Dict, NamedTuple, List
 from argparse import Namespace
+from random import sample, random

 import numpy as np
 import torch
@@ -15,6 +17,7 @@
 from torch.utils.data import Dataset

 import PIL
+from PIL.Image import Image as Img
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -29,6 +32,7 @@
 from packaging import version
 from PIL import Image
 from torchvision import transforms
+from torchvision.transforms.functional import hflip
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer

@@ -120,6 +124,9 @@ def parse_args():
     parser.add_argument(
         "--initialize_rest_random", action="store_true", help="Initialize rest of the placeholder tokens with random."
     )
+    parser.add_argument(
+        "--cache_images", action="store_true", help="Cache tensors of every image we load. You should only do this if your training set is small."
+    )
     parser.add_argument(
         "--save_steps",
         type=int,
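The new flag is only declared in this hunk; where it is consumed is not shown, but presumably it is threaded into the dataset constructor below. A sketch of that wiring, assuming ti_train.py builds the dataset the way the upstream textual-inversion training script does (every keyword except cache_enabled/cache_images is an assumption about code outside this diff):

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
        cache_enabled=args.cache_images,  # new CLI flag -> new constructor argument
    )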
@@ -341,20 +348,26 @@ def parse_args():
     "a large painting in the style of {}",
 ]

+@dataclass
+class Variations:
+    original: Tensor
+    flipped: Tensor

 class TextualInversionDataset(Dataset):
+    cache: Dict[str, Variations]
     def __init__(
         self,
         data_root,
         tokenizer,
         learnable_property="object",  # [object, style]
         size=512,
         repeats=100,
-        interpolation="bicubic",
+        interpolation="lanczos",
         flip_p=0.5,
         set="train",
         placeholder_token="*",
         center_crop=False,
+        cache_enabled=False,
     ):
         self.data_root = data_root
         self.tokenizer = tokenizer
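A minimal standalone sketch of what the new Variations container is for: compute the horizontally flipped copy once with torchvision's hflip, keep both tensors, and choose between them per sample instead of re-running a RandomHorizontalFlip transform on every fetch (the dummy tensor and flip_p value here are illustrative):

    from dataclasses import dataclass
    import torch
    from torch import Tensor
    from torchvision.transforms.functional import hflip

    @dataclass
    class Variations:
        original: Tensor
        flipped: Tensor

    img = torch.rand(3, 512, 512)                           # stand-in for one preprocessed training image
    cached = Variations(original=img, flipped=hflip(img))   # flip computed once, then reused

    flip_p = 0.5
    pixel_values = cached.flipped if torch.rand(1) < flip_p else cached.original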
@@ -364,7 +377,10 @@ def __init__(
         self.center_crop = center_crop
         self.flip_p = flip_p

-        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+        self.image_paths = [
+            os.path.join(self.data_root, file_path)
+            for file_path in os.listdir(self.data_root) if file_path.endswith('.png') or file_path.endswith('.jpg')
+        ]

         self.num_images = len(self.image_paths)
         self._length = self.num_images
@@ -380,20 +396,74 @@ def __init__(
         }[interpolation]

         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
-        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+        # we have so few images and so much VRAM that we should prefer to retain tensors rather than redo work
+        self.cache = {}
+        self.cache_enabled = cache_enabled

     def __len__(self):
         return self._length

     def __getitem__(self, i):
         example = {}
-        image = Image.open(self.image_paths[i % self.num_images])
-
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        image_path: str = self.image_paths[i % self.num_images]
+        stem: str = Path(image_path).stem

         placeholder_string = self.placeholder_token
-        text = random.choice(self.templates).format(placeholder_string)
+        # text = random.choice(self.templates).format(placeholder_string)
+        def describe_placeholder() -> str:
+            if random() < 0.3:
+                return self.placeholder_token
+            return placeholder_string
+
+        def describe_subject(character: str) -> str:
+            placeholder: str = describe_placeholder()
+            if random() < 0.3:
+                return f"photo of {placeholder}"
+            return f"photo of {character} {placeholder}"
+
+        def make_prompt(character: str, general_labels: List[str], sitting=True, on_floor=True) -> str:
+            even_more_labels = [*general_labels, '1girl']
+            if sitting:
+                even_more_labels.append('sitting')
+            if on_floor:
+                even_more_labels.append('on floor')
+            subject: str = describe_subject(character)
+            # we can use this for dropout but I think dropout is undesirable
+            # label_count = randrange(0, len(even_more_labels))
+            label_count = len(even_more_labels)
+            if label_count == 0:
+                return subject
+            labels = sample(even_more_labels, label_count)
+            joined = ', '.join(labels)
+            return f"{subject} with {joined}"
+
+        match stem:
+            case 'koishi':
+                text = make_prompt('komeiji koishi', ['green hair', 'black footwear', 'medium hair', 'blue eyes', 'yellow jacket', 'green skirt', 'hat', 'black headwear', 'smile', 'touhou project'])
+            case 'flandre':
+                text = make_prompt('flandre scarlet', ['fang', 'red footwear', 'slit pupils', 'medium hair', 'blonde hair', 'red eyes', 'red dress', 'mob cap', 'smile', 'short sleeves', 'yellow ascot', 'touhou project'])
+            case 'sanae':
+                text = make_prompt('kochiya sanae', ['green hair', 'blue footwear', 'long hair', 'green eyes', 'white dress', 'blue skirt', 'frog hair ornament', 'snake hair ornament', 'smile', 'standing', 'touhou project'])
+            case 'sanaestand':
+                text = make_prompt('kochiya sanae', ['green hair', 'blue footwear', 'long hair', 'green eyes', 'white dress', 'blue skirt', 'frog hair ornament', 'snake hair ornament', 'smile', 'touhou project'], sitting=False)
+            case 'tenshi':
+                text = make_prompt('hinanawi tenshi', ['blue hair', 'brown footwear', 'slit pupils', 'very long hair', 'red eyes', 'white dress', 'blue skirt', 'hat', 'black headwear', 'smile', 'touhou project'])
+            case 'youmu':
+                text = make_prompt('konpaku youmu', ['silver hair', 'black footwear', 'medium hair', 'slit pupils', 'green eyes', 'green dress', 'sleeveless dress', 'white sleeves', 'black ribbon', 'hair ribbon', 'unhappy', 'touhou project'])
+            case 'yuyuko':
+                text = make_prompt('saigyouji yuyuko', ['pink hair', 'black footwear', 'medium hair', 'pink eyes', 'wide sleeves', 'long sleeves', 'blue dress', 'mob cap', 'touhou project'])
+            case 'nagisa':
+                text = make_prompt('furukawa nagisa', ['brown hair', 'brown footwear', 'medium hair', 'brown eyes', 'smile', 'school briefcase', 'blue skirt', 'yellow jacket', 'antenna hair', 'dango', 'clannad'])
+            case 'teto':
+                text = make_prompt('kasane teto', ['pink hair', 'red footwear', 'red eyes', 'medium hair', 'detached sleeves', 'twin drills', 'drill hair', 'grey dress', 'smile', 'vocaloid'])
+            case 'korone':
+                text = make_prompt('inugami korone', ['yellow jacket', 'blue footwear', 'long hair', 'white dress', 'brown hair', 'brown eyes', 'on chair', 'hairclip', 'uwu', 'hololive'], on_floor=False)
+            case 'kudo':
+                text = make_prompt('kudryavka noumi', ['fang', 'black footwear', 'very long hair', 'white hat', 'white cape', 'silver hair', 'grey skirt', 'blue eyes', 'smile', 'little busters!'])
+            case 'patchouli':
+                text = make_prompt('patchouli knowledge', ['mob cap', 'pink footwear', 'long hair', 'slit pupils', 'striped dress', 'pink dress', 'purple hair', 'ribbons in hair', 'unhappy', 'touhou project'])
+            case _:
+                text = f"photo of {placeholder_string}"

         example["input_ids"] = self.tokenizer(
             text,
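To make the effect of the prompt builder concrete, here is a standalone reproduction that prints one caption. describe_placeholder is folded away because both of its branches return the placeholder token; the placeholder string, seed, and label list below are illustrative, not values from the training set:

    from random import sample, random, seed

    placeholder_string = "<fumo-plush>"  # illustrative; the real value is the learned placeholder token

    def describe_subject(character: str) -> str:
        # 30% of the time, drop the character name and keep only the placeholder
        if random() < 0.3:
            return f"photo of {placeholder_string}"
        return f"photo of {character} {placeholder_string}"

    def make_prompt(character: str, general_labels, sitting=True, on_floor=True) -> str:
        even_more_labels = [*general_labels, '1girl']
        if sitting:
            even_more_labels.append('sitting')
        if on_floor:
            even_more_labels.append('on floor')
        subject = describe_subject(character)
        labels = sample(even_more_labels, len(even_more_labels))  # all labels, shuffled order
        return f"{subject} with {', '.join(labels)}"

    seed(0)
    print(make_prompt('kasane teto', ['pink hair', 'red eyes', 'drill hair']))
    # prints something like:
    # photo of kasane teto <fumo-plush> with red eyes, on floor, 1girl, sitting, drill hair, pink hair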
@@ -403,25 +473,44 @@ def __getitem__(self, i):
             return_tensors="pt",
         ).input_ids[0]

-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = (
-                img.shape[0],
-                img.shape[1],
+        if stem not in self.cache:
+            image = Image.open(image_path)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+
+            # default to score-sde preprocessing
+            img = np.array(image).astype(np.uint8)
+
+            if self.center_crop:
+                crop = min(img.shape[0], img.shape[1])
+                h, w, = (
+                    img.shape[0],
+                    img.shape[1],
+                )
+                img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+            image: Img = Image.fromarray(img)
+            image: Img = image.resize((self.size, self.size), resample=self.interpolation)
+
+            flipped: Img = hflip(image)
+
+            def pil_to_latents(image: Img) -> Tensor:
+                image = np.array(image).astype(np.uint8)
+                image = (image / 127.5 - 1.0).astype(np.float32)
+                latents: Tensor = torch.from_numpy(image).permute(2, 0, 1)
+                return latents
+
+            image, flipped = (pil_to_latents(variation) for variation in (image, flipped))
+
+            self.cache[stem] = Variations(
+                original=image,
+                flipped=flipped,
             )
-            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size), resample=self.interpolation)
-
-        image = self.flip_transform(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
+        variations = self.cache[stem]
+        flip = torch.rand(1) < self.flip_p
+        image = variations.flipped if flip else variations.original

-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+        example["pixel_values"] = image
         return example

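A quick back-of-the-envelope on what --cache_images costs per image, assuming the default size=512: each cached Variations holds two float32 tensors of shape (3, size, size), roughly 6 MiB per training image, which is why the new help text recommends the flag only for small training sets.

    size = 512
    bytes_per_image = 2 * 3 * size * size * 4   # two float32 variations, 3 channels each
    print(bytes_per_image / 2**20)              # 6.0 (MiB per cached image)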