From 128b0afbb66ecb67849b0632257ce0f8e9d44d90 Mon Sep 17 00:00:00 2001
From: KernelA <17554646+KernelA@users.noreply.github.com>
Date: Thu, 13 Oct 2022 12:58:18 +0500
Subject: [PATCH 1/3] Add init

---
 taming/__init__.py                          | 0
 taming/data/__init__.py                     | 0
 taming/data/conditional_builder/__init__.py | 0
 taming/models/__init__.py                   | 0
 taming/modules/__init__.py                  | 0
 taming/modules/diffusionmodules/__init__.py | 0
 taming/modules/discriminator/__init__.py    | 0
 taming/modules/misc/__init__.py             | 0
 taming/modules/transformer/__init__.py      | 0
 taming/modules/vqvae/__init__.py            | 0
 10 files changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 taming/__init__.py
 create mode 100644 taming/data/__init__.py
 create mode 100644 taming/data/conditional_builder/__init__.py
 create mode 100644 taming/models/__init__.py
 create mode 100644 taming/modules/__init__.py
 create mode 100644 taming/modules/diffusionmodules/__init__.py
 create mode 100644 taming/modules/discriminator/__init__.py
 create mode 100644 taming/modules/misc/__init__.py
 create mode 100644 taming/modules/transformer/__init__.py
 create mode 100644 taming/modules/vqvae/__init__.py

diff --git a/taming/__init__.py b/taming/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/data/__init__.py b/taming/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/data/conditional_builder/__init__.py b/taming/data/conditional_builder/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/models/__init__.py b/taming/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/__init__.py b/taming/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/diffusionmodules/__init__.py b/taming/modules/diffusionmodules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/discriminator/__init__.py b/taming/modules/discriminator/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/misc/__init__.py b/taming/modules/misc/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/transformer/__init__.py b/taming/modules/transformer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/taming/modules/vqvae/__init__.py b/taming/modules/vqvae/__init__.py
new file mode 100644
index 00000000..e69de29b

From d8464046ffba4a03273df8fe93788ccf6a915993 Mon Sep 17 00:00:00 2001
From: KernelA <17554646+KernelA@users.noreply.github.com>
Date: Thu, 13 Oct 2022 13:26:39 +0500
Subject: [PATCH 2/3] Add __init__

---
 taming/data/ade20k.py                                  |  2 +-
 taming/data/annotated_objects_coco.py                  |  4 ++--
 taming/data/annotated_objects_dataset.py               | 10 +++++-----
 taming/data/annotated_objects_open_images.py           |  4 ++--
 taming/data/coco.py                                    |  2 +-
 taming/data/conditional_builder/objects_bbox.py        |  8 ++++----
 .../data/conditional_builder/objects_center_points.py  |  6 +++---
 taming/data/conditional_builder/utils.py               |  2 +-
 taming/data/custom.py                                  |  2 +-
 taming/data/faceshq.py                                 |  2 +-
 taming/data/image_transforms.py                        |  2 +-
 taming/data/imagenet.py                                |  6 +++---
 taming/data/utils.py                                   |  3 ++-
 taming/models/cond_transformer.py                      |  2 +-
 taming/models/vqgan.py                                 |  8 ++++----
 taming/modules/discriminator/model.py                  |  2 +-
 taming/modules/losses/__init__.py                      |  2 +-
 taming/modules/losses/lpips.py                         |  2 +-
 taming/modules/losses/vqperceptual.py                  |  4 ++--
 19 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/taming/data/ade20k.py b/taming/data/ade20k.py
index 366dae97..badf3a32 100644
--- a/taming/data/ade20k.py
+++ b/taming/data/ade20k.py
@@ -5,7 +5,7 @@
 from PIL import Image
 from torch.utils.data import Dataset
 
-from taming.data.sflckr import SegmentationBase # for examples included in repo
+from .sflckr import SegmentationBase # for examples included in repo
 
 
 class Examples(SegmentationBase):
diff --git a/taming/data/annotated_objects_coco.py b/taming/data/annotated_objects_coco.py
index af000ecd..63de8727 100644
--- a/taming/data/annotated_objects_coco.py
+++ b/taming/data/annotated_objects_coco.py
@@ -6,8 +6,8 @@
 from tqdm import tqdm
 
-from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
-from taming.data.helper_types import Annotation, ImageDescription, Category
+from .annotated_objects_dataset import AnnotatedObjectsDataset
+from .helper_types import Annotation, ImageDescription, Category
 
 COCO_PATH_STRUCTURE = {
     'train': {
diff --git a/taming/data/annotated_objects_dataset.py b/taming/data/annotated_objects_dataset.py
index 53cc346a..57729a4b 100644
--- a/taming/data/annotated_objects_dataset.py
+++ b/taming/data/annotated_objects_dataset.py
@@ -7,11 +7,11 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 
-from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
-from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
-from taming.data.conditional_builder.utils import load_object_from_string
-from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
-from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from .conditional_builder.utils import load_object_from_string
+from .helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from .image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
     Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
diff --git a/taming/data/annotated_objects_open_images.py b/taming/data/annotated_objects_open_images.py
index aede6803..63e718fb 100644
--- a/taming/data/annotated_objects_open_images.py
+++ b/taming/data/annotated_objects_open_images.py
@@ -4,8 +4,8 @@
 from typing import Dict, List, Any
 import warnings
 
-from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
-from taming.data.helper_types import Annotation, Category
+from .annotated_objects_dataset import AnnotatedObjectsDataset
+from .helper_types import Annotation, Category
 from tqdm import tqdm
 
 OPEN_IMAGES_STRUCTURE = {
     'train': {
diff --git a/taming/data/coco.py b/taming/data/coco.py
index 2b2f7838..76f0a501 100644
--- a/taming/data/coco.py
+++ b/taming/data/coco.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 from torch.utils.data import Dataset
 
-from taming.data.sflckr import SegmentationBase # for examples included in repo
+from .sflckr import SegmentationBase # for examples included in repo
 
 
 class Examples(SegmentationBase):
diff --git a/taming/data/conditional_builder/objects_bbox.py b/taming/data/conditional_builder/objects_bbox.py
index 15881e76..db37289a 100644
--- a/taming/data/conditional_builder/objects_bbox.py
+++ b/taming/data/conditional_builder/objects_bbox.py
@@ -3,12 +3,12 @@
 from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
 from more_itertools.recipes import grouper
-from taming.data.image_transforms import convert_pil_to_tensor
 from torch import LongTensor, Tensor
-from taming.data.helper_types import BoundingBox, Annotation
-from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
-from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+from ..image_transforms import convert_pil_to_tensor
+from ..helper_types import BoundingBox, Annotation
+from ..conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from ..conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
     pad_list, get_plot_font_size, absolute_bbox
diff --git a/taming/data/conditional_builder/objects_center_points.py b/taming/data/conditional_builder/objects_center_points.py
index 9a480329..5cf8ec34 100644
--- a/taming/data/conditional_builder/objects_center_points.py
+++ b/taming/data/conditional_builder/objects_center_points.py
@@ -6,11 +6,11 @@
 from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
 from more_itertools.recipes import grouper
-from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+from ..conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
     additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
     absolute_bbox, rescale_annotations
-from taming.data.helper_types import BoundingBox, Annotation
-from taming.data.image_transforms import convert_pil_to_tensor
+from ..helper_types import BoundingBox, Annotation
+from ..image_transforms import convert_pil_to_tensor
 from torch import LongTensor, Tensor
diff --git a/taming/data/conditional_builder/utils.py b/taming/data/conditional_builder/utils.py
index d0ee175f..3a5c195b 100644
--- a/taming/data/conditional_builder/utils.py
+++ b/taming/data/conditional_builder/utils.py
@@ -1,7 +1,7 @@
 import importlib
 from typing import List, Any, Tuple, Optional
 
-from taming.data.helper_types import BoundingBox, Annotation
+from ...data.helper_types import BoundingBox, Annotation
 
 # source: seaborn, color palette tab10
 COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
diff --git a/taming/data/custom.py b/taming/data/custom.py
index 33f302a4..de212eeb 100644
--- a/taming/data/custom.py
+++ b/taming/data/custom.py
@@ -3,7 +3,7 @@
 import albumentations
 from torch.utils.data import Dataset
 
-from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+from .base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
 
 
 class CustomBase(Dataset):
diff --git a/taming/data/faceshq.py b/taming/data/faceshq.py
index 6912d04b..40d55b37 100644
--- a/taming/data/faceshq.py
+++ b/taming/data/faceshq.py
@@ -3,7 +3,7 @@
 import albumentations
 from torch.utils.data import Dataset
 
-from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+from .base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
 
 
 class FacesBase(Dataset):
diff --git a/taming/data/image_transforms.py b/taming/data/image_transforms.py
index 657ac332..285f260b 100644
--- a/taming/data/image_transforms.py
+++ b/taming/data/image_transforms.py
@@ -7,7 +7,7 @@
 from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
 from torchvision.transforms.functional import _get_image_size as get_image_size
-from taming.data.helper_types import BoundingBox, Image
+from .helper_types import BoundingBox, Image
 
 pil_to_tensor = PILToTensor()
diff --git a/taming/data/imagenet.py b/taming/data/imagenet.py
index 9a02ec44..acf1e000 100644
--- a/taming/data/imagenet.py
+++ b/taming/data/imagenet.py
@@ -7,9 +7,9 @@
 from omegaconf import OmegaConf
 from torch.utils.data import Dataset
 
-from taming.data.base import ImagePaths
-from taming.util import download, retrieve
-import taming.data.utils as bdu
+from .base import ImagePaths
+from ..util import download, retrieve
+from . import utils as bdu
 
 
 def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
diff --git a/taming/data/utils.py b/taming/data/utils.py
index 2b3c3d53..119a8892 100644
--- a/taming/data/utils.py
+++ b/taming/data/utils.py
@@ -7,11 +7,12 @@
 import numpy as np
 import torch
-from taming.data.helper_types import Annotation
 from torch._six import string_classes
 from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
 from tqdm import tqdm
 
+from .helper_types import Annotation
+
 
 def unpack(path):
     if path.endswith("tar.gz"):
diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py
index e4c63730..a00171ca 100644
--- a/taming/models/cond_transformer.py
+++ b/taming/models/cond_transformer.py
@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 
 from main import instantiate_from_config
-from taming.modules.util import SOSProvider
+from ..modules.util import SOSProvider
 
 
 def disabled_train(self, mode=True):
diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py
index a6950baa..ac8a0039 100644
--- a/taming/models/vqgan.py
+++ b/taming/models/vqgan.py
@@ -4,10 +4,10 @@
 from main import instantiate_from_config
 
-from taming.modules.diffusionmodules.model import Encoder, Decoder
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-from taming.modules.vqvae.quantize import GumbelQuantize
-from taming.modules.vqvae.quantize import EMAVectorQuantizer
+from ..modules.diffusionmodules.model import Encoder, Decoder
+from ..modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from ..modules.vqvae.quantize import GumbelQuantize
+from ..modules.vqvae.quantize import EMAVectorQuantizer
 
 class VQModel(pl.LightningModule):
     def __init__(self,
diff --git a/taming/modules/discriminator/model.py b/taming/modules/discriminator/model.py
index 2aaa3110..1e8afd48 100644
--- a/taming/modules/discriminator/model.py
+++ b/taming/modules/discriminator/model.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 
-from taming.modules.util import ActNorm
+from ...modules.util import ActNorm
 
 
 def weights_init(m):
diff --git a/taming/modules/losses/__init__.py b/taming/modules/losses/__init__.py
index d09caf9e..128588ad 100644
--- a/taming/modules/losses/__init__.py
+++ b/taming/modules/losses/__init__.py
@@ -1,2 +1,2 @@
-from taming.modules.losses.vqperceptual import DummyLoss
+from .vqperceptual import DummyLoss
diff --git a/taming/modules/losses/lpips.py b/taming/modules/losses/lpips.py
index a7280447..5c2325be 100644
--- a/taming/modules/losses/lpips.py
+++ b/taming/modules/losses/lpips.py
@@ -5,7 +5,7 @@
 from torchvision import models
 from collections import namedtuple
 
-from taming.util import get_ckpt_path
+from ...util import get_ckpt_path
 
 
 class LPIPS(nn.Module):
diff --git a/taming/modules/losses/vqperceptual.py b/taming/modules/losses/vqperceptual.py
index c2febd44..e488ca07 100644
--- a/taming/modules/losses/vqperceptual.py
+++ b/taming/modules/losses/vqperceptual.py
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from taming.modules.losses.lpips import LPIPS
-from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from ..losses.lpips import LPIPS
+from ..discriminator.model import NLayerDiscriminator, weights_init
 
 
 class DummyLoss(nn.Module):

From b50d80b2a0a38f803b1e2c8c2e7d984597cbaa9b Mon Sep 17 00:00:00 2001
From: KernelA <17554646+KernelA@users.noreply.github.com>
Date: Mon, 17 Oct 2022 12:35:50 +0500
Subject: [PATCH 3/3] Move some methods from main

---
 main.py                                      | 15 ++-------------
 scripts/make_samples.py                      |  4 +++-
 scripts/sample_conditional.py                |  4 +++-
 scripts/sample_fast.py                       |  3 +--
 taming/data/annotated_objects_open_images.py |  3 ++-
 taming/data/conditional_builder/utils.py     |  2 +-
 taming/models/cond_transformer.py            |  2 +-
 taming/models/vqgan.py                       |  3 +--
 taming/modules/util.py                       |  3 ++-
 taming/util.py                               | 12 ++++++++++++
 10 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/main.py b/main.py
index 3d83cb21..43fa9755 100644
--- a/main.py
+++ b/main.py
@@ -12,15 +12,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 from taming.data.utils import custom_collate
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
+from taming.util import get_obj_from_str, instantiate_from_config
 
 def get_parser(**parser_kwargs):
     def str2bool(v):
@@ -113,10 +105,7 @@ def nondefault_trainer_args(opt):
     return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
 
 
-def instantiate_from_config(config):
-    if not "target" in config:
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
 
 
 class WrappedDataset(Dataset):
diff --git a/scripts/make_samples.py b/scripts/make_samples.py
index 5e4d6995..5e445761 100644
--- a/scripts/make_samples.py
+++ b/scripts/make_samples.py
@@ -3,11 +3,13 @@
 import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from main import DataModuleFromConfig
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 from tqdm import trange
 
+from taming.util import instantiate_from_config
+
 
 def save_image(x, path):
     c,h,w = x.shape
diff --git a/scripts/sample_conditional.py b/scripts/sample_conditional.py
index 174cf2af..edb367cd 100644
--- a/scripts/sample_conditional.py
+++ b/scripts/sample_conditional.py
@@ -5,10 +5,12 @@
 import streamlit as st
 from streamlit import caching
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from main import DataModuleFromConfig
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 
+from taming.util import instantiate_from_config
+
 
 rescale = lambda x: (x + 1.) / 2.
diff --git a/scripts/sample_fast.py b/scripts/sample_fast.py
index ff546c7d..6eabfb42 100644
--- a/scripts/sample_fast.py
+++ b/scripts/sample_fast.py
@@ -7,9 +7,8 @@
 from tqdm import tqdm, trange
 from einops import repeat
 
-from main import instantiate_from_config
 from taming.modules.transformer.mingpt import sample_with_past
-
+from taming.util import instantiate_from_config
 
 rescale = lambda x: (x + 1.) / 2.
diff --git a/taming/data/annotated_objects_open_images.py b/taming/data/annotated_objects_open_images.py
index 63e718fb..ffc272df 100644
--- a/taming/data/annotated_objects_open_images.py
+++ b/taming/data/annotated_objects_open_images.py
@@ -4,9 +4,10 @@
 from typing import Dict, List, Any
 import warnings
 
+from tqdm import tqdm
+
 from .annotated_objects_dataset import AnnotatedObjectsDataset
 from .helper_types import Annotation, Category
-from tqdm import tqdm
 
 OPEN_IMAGES_STRUCTURE = {
     'train': {
diff --git a/taming/data/conditional_builder/utils.py b/taming/data/conditional_builder/utils.py
index 3a5c195b..bb2085d2 100644
--- a/taming/data/conditional_builder/utils.py
+++ b/taming/data/conditional_builder/utils.py
@@ -1,7 +1,7 @@
 import importlib
 from typing import List, Any, Tuple, Optional
 
-from ...data.helper_types import BoundingBox, Annotation
+from ..helper_types import BoundingBox, Annotation
 
 # source: seaborn, color palette tab10
 COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py
index a00171ca..9e12cde5 100644
--- a/taming/models/cond_transformer.py
+++ b/taming/models/cond_transformer.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
+from ..util import instantiate_from_config
 from ..modules.util import SOSProvider
diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py
index ac8a0039..96c80049 100644
--- a/taming/models/vqgan.py
+++ b/taming/models/vqgan.py
@@ -2,8 +2,7 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
-
+from ..util import instantiate_from_config
 from ..modules.diffusionmodules.model import Encoder, Decoder
 from ..modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from ..modules.vqvae.quantize import GumbelQuantize
diff --git a/taming/modules/util.py b/taming/modules/util.py
index 9ee16385..550cdd5c 100644
--- a/taming/modules/util.py
+++ b/taming/modules/util.py
@@ -1,7 +1,8 @@
+import importlib
+
 import torch
 import torch.nn as nn
 
-
 def count_params(model):
     total_params = sum(p.numel() for p in model.parameters())
     return total_params
diff --git a/taming/util.py b/taming/util.py
index 06053e5d..4bdecee0 100644
--- a/taming/util.py
+++ b/taming/util.py
@@ -1,4 +1,5 @@
 import os, hashlib
+import importlib
 import requests
 from tqdm import tqdm
 
@@ -14,6 +15,17 @@
     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
 }
 
+def instantiate_from_config(config):
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
 
 def download(url, local_path, chunk_size=1024):
     os.makedirs(os.path.split(local_path)[0], exist_ok=True)