Skip to content

Missing __init__.py and relative import #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only

from taming.data.utils import custom_collate


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)

from taming.util import get_obj_from_str, instantiate_from_config

def get_parser(**parser_kwargs):
def str2bool(v):
Expand Down Expand Up @@ -113,10 +105,7 @@ def nondefault_trainer_args(opt):
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))



class WrappedDataset(Dataset):
Expand Down
4 changes: 3 additions & 1 deletion scripts/make_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from main import DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange

from taming.util import instantiate_from_config


def save_image(x, path):
c,h,w = x.shape
Expand Down
4 changes: 3 additions & 1 deletion scripts/sample_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import streamlit as st
from streamlit import caching
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from main import DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

from taming.util import instantiate_from_config


rescale = lambda x: (x + 1.) / 2.

Expand Down
3 changes: 1 addition & 2 deletions scripts/sample_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from tqdm import tqdm, trange
from einops import repeat

from main import instantiate_from_config
from taming.modules.transformer.mingpt import sample_with_past

from taming.util import instantiate_from_config

rescale = lambda x: (x + 1.) / 2.

Expand Down
Empty file added taming/__init__.py
Empty file.
Empty file added taming/data/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion taming/data/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PIL import Image
from torch.utils.data import Dataset

from taming.data.sflckr import SegmentationBase # for examples included in repo
from .sflckr import SegmentationBase # for examples included in repo


class Examples(SegmentationBase):
Expand Down
4 changes: 2 additions & 2 deletions taming/data/annotated_objects_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from tqdm import tqdm

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, ImageDescription, Category
from .annotated_objects_dataset import AnnotatedObjectsDataset
from .helper_types import Annotation, ImageDescription, Category

COCO_PATH_STRUCTURE = {
'train': {
Expand Down
10 changes: 5 additions & 5 deletions taming/data/annotated_objects_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from torch.utils.data import Dataset
from torchvision import transforms

from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from .conditional_builder.utils import load_object_from_string
from .helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from .image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor


Expand Down
5 changes: 3 additions & 2 deletions taming/data/annotated_objects_open_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Dict, List, Any
import warnings

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, Category
from tqdm import tqdm

from .annotated_objects_dataset import AnnotatedObjectsDataset
from .helper_types import Annotation, Category

OPEN_IMAGES_STRUCTURE = {
'train': {
'top_level': '',
Expand Down
2 changes: 1 addition & 1 deletion taming/data/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
from torch.utils.data import Dataset

from taming.data.sflckr import SegmentationBase # for examples included in repo
from .sflckr import SegmentationBase # for examples included in repo


class Examples(SegmentationBase):
Expand Down
Empty file.
8 changes: 4 additions & 4 deletions taming/data/conditional_builder/objects_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor

from taming.data.helper_types import BoundingBox, Annotation
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
from ..image_transforms import convert_pil_to_tensor
from ..helper_types import BoundingBox, Annotation
from ..conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from ..conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
pad_list, get_plot_font_size, absolute_bbox


Expand Down
6 changes: 3 additions & 3 deletions taming/data/conditional_builder/objects_center_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
from ..conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
absolute_bbox, rescale_annotations
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.image_transforms import convert_pil_to_tensor
from ..helper_types import BoundingBox, Annotation
from ..image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor


Expand Down
2 changes: 1 addition & 1 deletion taming/data/conditional_builder/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
from typing import List, Any, Tuple, Optional

from taming.data.helper_types import BoundingBox, Annotation
from ..helper_types import BoundingBox, Annotation

# source: seaborn, color palette tab10
COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
Expand Down
2 changes: 1 addition & 1 deletion taming/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
from .base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class CustomBase(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion taming/data/faceshq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
from .base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class FacesBase(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion taming/data/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
from torchvision.transforms.functional import _get_image_size as get_image_size

from taming.data.helper_types import BoundingBox, Image
from .helper_types import BoundingBox, Image

pil_to_tensor = PILToTensor()

Expand Down
6 changes: 3 additions & 3 deletions taming/data/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu
from .base import ImagePaths
from ..util import download, retrieve
import .utils as bdu


def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
Expand Down
3 changes: 2 additions & 1 deletion taming/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import numpy as np
import torch
from taming.data.helper_types import Annotation
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from tqdm import tqdm

from .helper_types import Annotation


def unpack(path):
if path.endswith("tar.gz"):
Expand Down
Empty file added taming/models/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions taming/models/cond_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config
from taming.modules.util import SOSProvider
from ..util import instantiate_from_config
from ..modules.util import SOSProvider


def disabled_train(self, mode=True):
Expand Down
11 changes: 5 additions & 6 deletions taming/models/vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer
from ..util import instantiate_from_config
from ..modules.diffusionmodules.model import Encoder, Decoder
from ..modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ..modules.vqvae.quantize import GumbelQuantize
from ..modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
def __init__(self,
Expand Down
Empty file added taming/modules/__init__.py
Empty file.
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion taming/modules/discriminator/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn


from taming.modules.util import ActNorm
from ...modules.util import ActNorm


def weights_init(m):
Expand Down
2 changes: 1 addition & 1 deletion taming/modules/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from taming.modules.losses.vqperceptual import DummyLoss
from .vqperceptual import DummyLoss

2 changes: 1 addition & 1 deletion taming/modules/losses/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision import models
from collections import namedtuple

from taming.util import get_ckpt_path
from ...util import get_ckpt_path


class LPIPS(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions taming/modules/losses/vqperceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from ..losses.lpips import LPIPS
from ..discriminator.model import NLayerDiscriminator, weights_init


class DummyLoss(nn.Module):
Expand Down
Empty file added taming/modules/misc/__init__.py
Empty file.
Empty file.
3 changes: 2 additions & 1 deletion taming/modules/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import importlib

import torch
import torch.nn as nn


def count_params(model):
total_params = sum(p.numel() for p in model.parameters())
return total_params
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions taming/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, hashlib
import importlib
import requests
from tqdm import tqdm

Expand All @@ -14,6 +15,17 @@
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))

def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)

def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
Expand Down