Comments (17)

awilson9 avatar awilson9 commented on May 18, 2024 3

This is still happening for me as well on the pretrained VAE on 0.2.2

from dalle-pytorch.

edend10 avatar edend10 commented on May 18, 2024 1

Thanks for the response @lucidrains !
Ohh interesting, I'll check out the changes and try it out. Will look out for more updates!

afiaka87 avatar afiaka87 commented on May 18, 2024 1

Hm, does this work? @tommy19970714 ?
https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

lucidrains avatar lucidrains commented on May 18, 2024

@edend10 Hi Eden! Thanks for trying out the repository! I may have found a bug with the pretrained VAE wrapper, fixed in the latest commit https://github.com/lucidrains/DALLE-pytorch/blob/0.2.2/dalle_pytorch/vae.py#L82 πŸ™ I'll be training this myself this week, and ironing out any remaining issues (other than data and scale of course)

CDitzel avatar CDitzel commented on May 18, 2024

What are those two mapping functions for, anyway?

Are they just for transforming the pixel value range of the input data they use over at OpenAI?
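For context, the two helpers in OpenAI's released dall_e package look roughly like this (a sketch from memory of their utils, not a verbatim copy; the real versions operate on torch tensors):

```python
# Sketch of OpenAI's pixel-mapping helpers: they squeeze [0, 1] pixels
# into [eps, 1 - eps] so the logit-Laplace reconstruction loss used by
# their dVAE stays well defined at the boundaries.
logit_laplace_eps = 0.1

def map_pixels(x):
    # [0, 1] -> [0.1, 0.9]
    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x):
    # inverse of map_pixels, clamped back into [0, 1]
    return min(max((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0.0), 1.0)
```

So they are not a dataset normalization in the torchvision sense; they are part of the VAE's expected input/output range.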

AlexanderRayCarlson avatar AlexanderRayCarlson commented on May 18, 2024

Hello! Thank you for this excellent work. I seem to be getting something similar - abstract sorts of blue squares when training in the colab notebook. It looks like the package (0.2.2) is updated with the latest fix - is there anything else needed to do at the moment?

afiaka87 avatar afiaka87 commented on May 18, 2024

This is an early output (2 epochs) from the new code that removes the normalization from train_dalle.py. Was that the necessary fix @lucidrains ?

DEPTH = 6
BATCH_SIZE = 8

[generated image: media_images_image_1600_82d6d0f7]

"a female mannequin"
[generated image: mannequin]

Much more cohesive and a much stronger start now. No strange blueness, at the very least.
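For anyone reproducing the fix, the failure mode is double normalization: if the dataloader applies a torchvision-style Normalize(0.5, 0.5) shift into [-1, 1] and the pretrained VAE then applies its own map_pixels (which expects [0, 1]), the values land far outside the range the decoder was trained on. A minimal numeric sketch (the 0.8/0.1 constants assumed from OpenAI's release):

```python
def normalize(x):
    # torchvision-style Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]
    return (x - 0.5) / 0.5

def map_pixels(x):
    # OpenAI-style logit-Laplace mapping; expects input in [0, 1]
    return 0.8 * x + 0.1

# Correct: feed raw [0, 1] pixels straight to map_pixels
assert map_pixels(0.0) == 0.1 and map_pixels(1.0) == 0.9

# Buggy: normalizing first pushes values outside [0.1, 0.9]
bad = map_pixels(normalize(0.0))  # roughly -0.7, far outside [0.1, 0.9]
```

Which would explain the washed-out blue squares: the decoder never saw inputs in that range.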

liuqk3 avatar liuqk3 commented on May 18, 2024

Hi @afiaka87, Amazing results! Can you share more details about your configurations? such as the dataset, learning rate, lr scheduler, number of text and image (8192, right?) tokens? Thanks.

afiaka87 avatar afiaka87 commented on May 18, 2024

> Hi @afiaka87, Amazing results! Can you share more details about your configurations? such as the dataset, learning rate, lr scheduler, number of text and image (8192, right?) tokens? Thanks.

I should mention the dataset I'm using includes images released by OpenAI with their DALL-E. The mannequin image is not being generated from text alone, it's from an image text pair. Anyway, my point is that my dataset is bad and I'm mostly just messing around. It's probably the case that using images generated from DALL-E itself is bound to converge much quicker than usual.

I'm using the defaults in train_dalle.py except for the BATCH_SIZE and DEPTH. Pretrained OpenAI VAE, top_k=0.9, and reversible=True. I tried mixing attention layers but it adds memory. (edit: I don't think it does, actually. Training with all four attn_types currently.)

I'm working on creating a hyperparameter sweep with wandb currently. I think a learning rate of 2e-4 might be better for depth greater than 12 or so.

I still can't get a stable learning rate with 64 depth.

afiaka87 avatar afiaka87 commented on May 18, 2024

Edit: You can find the whole training session here: https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1
I'm starting over because there have been updates to the main branch.

Original post:

"a professional high quality emoji of a spider starfish chimera . a spider imitating a starfish . a spider made of starfish . a professional emoji ."

[generated image: starfish_spider_chimera]

Left it running at 16 depth, 8 heads, batch size of 12, learning_rate=2e-4. The loss is going down at a steady, consistent rate. (edit: just kidding! It seems to get stuck at around ~6.0 on this run, which seems high?)

DEPTH: 16
HEADS: 8
TOP_K: 0.85
EPOCHS: 27
SHUFFLE: True
DIM_HEAD: 64
MODEL_DIM: 512
BATCH_SIZE: 12
REVERSIBLE: true
TEXT_SEQ_LEN: 256
LEARNING_RATE: 0.0002
GRAD_CLIP_NORM: 0.5
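For anyone mapping these settings onto code: assuming the dalle_pytorch API of that era, they correspond roughly to the DALLE constructor arguments like so (a hypothetical sketch, not my exact invocation; argument names taken from the project README):

```python
# Hypothetical sketch: mapping the run settings above onto
# dalle_pytorch.DALLE constructor kwargs (names assumed from the README).
dalle_params = dict(
    dim=512,           # MODEL_DIM
    depth=16,          # DEPTH
    heads=8,           # HEADS
    dim_head=64,       # DIM_HEAD
    reversible=True,   # REVERSIBLE
    text_seq_len=256,  # TEXT_SEQ_LEN
)
# dalle = DALLE(vae=vae, **dalle_params)  # plus num_text_tokens etc.
```

TOP_K, LEARNING_RATE, GRAD_CLIP_NORM, and BATCH_SIZE belong to the sampling/optimizer side of train_dalle.py rather than the model constructor.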

afiaka87 avatar afiaka87 commented on May 18, 2024

Edit: Here, I used Weights & Biases to create a report. This link has all the images generated (every 100th iteration) over 27,831 total iterations. This one should work, I think:
https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

tommy19970714 avatar tommy19970714 commented on May 18, 2024

@afiaka87 Thank you for sharing your Weights & Biases report!
But I can't see it because its project is private.
Can you allow us to see it?
[screenshot: 2021-03-11 17 55 43]

afiaka87 avatar afiaka87 commented on May 18, 2024

> Hi @afiaka87, Amazing results! Can you share more details about your configurations? such as the dataset, learning rate, lr scheduler, number of text and image (8192, right?) tokens? Thanks.

Just for more info on the dataset itself, it is roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI's DALL-E. They presented roughly ~30k unique text prompts of which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total number of images winds up being about 1.1 million. If you look at many of the examples on that page, you'll see that DALL-E (in that form at least), can and will make mistakes. These mistakes are also in this dataset. Anyway I'm just messing around having fun training and what not. This is definitely not going to produce a good model or anything.

There are also a large number of images in the dataset which are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch though. Anyway, that can't be helping much.

One valuable thing I've taken from this is that it seems to take at least ~2000 iterations with a batch size of 4 to approach any sort of coherent reproductions. This number probably varies, but in terms of "knowing when to start over", I would say roughly 3000 steps might be a good soft target.

tommy19970714 avatar tommy19970714 commented on May 18, 2024

Thank you for sharing your results!
I will refer to your parameters.

afiaka87 avatar afiaka87 commented on May 18, 2024

@tommy19970714

I did a hyperparameter sweep with Weights & Biases: forty-eight 1200-iteration runs of dalle-pytorch, varying learning rate, depth, and heads (minimizing the total loss at the end of each run).

#84 (comment)

afiaka87 avatar afiaka87 commented on May 18, 2024

The most important thing to note here is that the learning rate actually needs to go up to about 0.0005 when dealing with ~26-32 depth.

afiaka87 avatar afiaka87 commented on May 18, 2024

I've done a much longer training session on that same dataset here:

#86
