More "OpenAI Blog Post" Training | Depth 32 | Heads 8 | LR 5e-4 #86

Closed
@afiaka87

Edit: Moved to discussions: #106

Hey, all. Some of you might know I'm practicing and learning about machine learning with dalle-pytorch and a dataset consisting of the images OpenAI presented in the DALLE blog post. I honestly don't have the money to train on this whole dataset.

Edit: this is no longer true. Using the 1024 VQGAN from the "Taming Transformers" research, it's now quite possible to train on a full dataset of 1,000,000 image-text pairs, and I'm doing just that. I hope to have it finished in about a week. I assume someone else will release a dalle-pytorch model trained properly on COCO and other image sets before then, but if they don't, check here for updates.

Anyway, it ran for ~36000 steps. As you can see, it... still really likes mannequins. I'm considering removing them from the dataset. But you'll also notice that the network has actually got a decent idea of the general colors that belong in different types of prompts.

Some Samples from Near the End of Training

[Image: results]

Every Text-Image Reconstruction

https://wandb.ai/afiaka87/dalle_pytorch_live_training/reports/dalle-pytorch-Test-Run-2--Vmlldzo1MzM5MjQ

Deliverables (my train_dalle.py)

https://gist.github.com/afiaka87/850fb3cc48edde8a7ed4cb1ce53b6bd2

This has some code in it that manages to deal with truncated images via try/except. Apparently detecting a corrupted PNG is harder than P vs NP. PIL's Image.verify() method doesn't catch all of them. Python's built-in imghdr library doesn't catch all of them either. So you just catch OSError when loading an image and return an item further along in the dataset instead. Works well enough.
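For anyone who doesn't want to dig through the gist, here is a minimal sketch of the "catch OSError and return an item further along" idea. It is not the code from the gist: the class name `SkipCorruptDataset` and the `loader` argument are hypothetical names made up for illustration, and the default loader assumes Pillow is installed and that a full decode via `convert("RGB")` is what surfaces the truncation error.

```python
def _pil_loader(path):
    # Imported lazily so the skip logic can be tried without Pillow installed.
    from PIL import Image

    with Image.open(path) as img:
        # convert() forces a full decode; truncated or unreadable files
        # raise OSError (or a subclass of it) at this point.
        return img.convert("RGB")


class SkipCorruptDataset:
    """Hypothetical map-style dataset that skips unreadable images."""

    def __init__(self, paths, loader=_pil_loader):
        self.paths = list(paths)
        self.loader = loader

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Try the requested item first, then each later item (wrapping
        # around), so a bad file never crashes the training loop.
        for offset in range(len(self.paths)):
            candidate = (index + offset) % len(self.paths)
            try:
                return self.loader(self.paths[candidate])
            except OSError:
                continue
        raise RuntimeError("no readable images in dataset")
```

One side effect of this approach is that the neighbours of corrupt files get sampled slightly more often. With a dataset this size and shuffling turned on, that probably doesn't matter much.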

Parameters

SHUFFLE = True
EPOCHS = 28 # This wound up being less than a single epoch, of course. 
BATCH_SIZE = 16
LEARNING_RATE = 0.0005 # I found this learning rate to be more suitable than 0.0003 in my hyperparameter sweep post
GRAD_CLIP_NORM = 0.5
DEPTH = 32
HEADS = 8
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DIM_HEAD = 64
REVERSIBLE = True
ATTN_TYPES = ('full',)
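
As a quick sanity check on the "less than a single epoch" comment, the arithmetic can be sketched like this. It uses the ~36000 steps and batch size of 16 from this post plus the rough figure of 1.1 million pairs for the dataset, so treat the result as approximate; the variable names are just for illustration.

```python
steps = 36_000            # approximate number of training steps in this run
batch_size = 16           # BATCH_SIZE above
dataset_size = 1_100_000  # rough number of image-text pairs in the dataset

# Each step consumes one batch, so this is the number of samples seen.
samples_seen = steps * batch_size

# Fraction of one full pass over the dataset that the run covered.
epochs_completed = samples_seen / dataset_size

# Steps that a single full epoch would take at this batch size.
steps_per_epoch = dataset_size / batch_size

print(f"samples seen: {samples_seen}")
print(f"epochs completed: {epochs_completed:.2f}")
print(f"steps per epoch: {steps_per_epoch:.0f}")
```

So the run covered a bit over half of one epoch, and EPOCHS = 28 was never going to be reached.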

Dataset Description

#61 (comment)

Just for more info on the dataset itself, it is roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI's DALL-E. They presented roughly 30k unique text prompts, for each of which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total number of images winds up being about 1.1 million. If you look at many of the examples on that page, you'll see that DALL-E (in that form at least) can and will make mistakes. These mistakes are also in this dataset. Anyway, I'm just messing around and having fun training. This is definitely not going to produce a good model or anything.

There are also a large number of images in the dataset which are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch though. Anyway, that can't be helping much.
