
gill's Introduction

Generating Images with Multimodal Language Models

GILL chat animation

This repository hosts the code and model weights for the GILL model.

Paper | Project Webpage

HF paper page Open in Spaces

Model and Usage

GILL model architecture

GILL (Generating Images with Large Language Models) is capable of processing arbitrarily interleaved image-and-text inputs to generate text, retrieve images, and generate novel images.

Setup instructions

Environment

Set up a new virtualenv, and install required libraries:

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

Add the gill library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/gill/

Pretrained Checkpoints

The GILL model weights (linear layers and [IMG] embeddings) are small (around 96MB), and are included in this git repo. They will be in the checkpoints/gill_opt/ folder after cloning. The checkpoint and model config in checkpoints/gill_opt/ reproduce the main results reported in our paper.

Precomputed Embeddings For Image Retrieval

For image retrieval, we provide the precomputed visual embeddings for Conceptual Captions images with valid URLs. They are stored at this URL. These are used to enable the model to retrieve images. The embeddings take up around 3GB, and are compatible with both model configs we provide. Download the files and place cc3m_embeddings_urls.npy into the checkpoints/gill_opt/ directory.

Note that you can still run the model without these, but it will not produce retrieved images. It will always generate novel images!
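If you want to sanity-check the downloaded file before running retrieval, a minimal inspection sketch is below. The internal layout of cc3m_embeddings_urls.npy (a pickled object pairing image URLs with an embedding matrix) is an assumption here; load_gill() in gill/models.py is the authoritative consumer of the file.

import numpy as np

# Quick sanity check of the precomputed retrieval embeddings. The exact layout
# is an assumption (see the note above); adjust the keys to whatever you find.
raw = np.load('checkpoints/gill_opt/cc3m_embeddings_urls.npy', allow_pickle=True)
obj = raw.item() if raw.shape == () else raw  # unwrap a 0-d object array, if present
if isinstance(obj, dict):
    for key, value in obj.items():
        print(key, getattr(value, 'shape', None) or len(value))
else:
    print(type(obj), getattr(obj, 'shape', None))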

If you wish to precompute these embeddings for a different set of image URLs or for a different model, edit scripts/extract_img_embs.py with the list of image URLs and run it:

python scripts/extract_img_embs.py

Inference

Check out GILL_example_notebook.ipynb for examples of calling the model for inference. Several of the figures presented in the paper are reproduced in this notebook using greedy decoding. Note that there may be minor differences in image outputs due to CC3M images being lost over time.

The notebook also shows how to use the model for generating images and generating text.
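For reference, a minimal inference sketch (outside the notebook) might look like the following. It uses models.load_gill and generate_for_images_and_texts, which also appear in the issue tracebacks later on this page; the prompt contents and keyword argument values here are illustrative, not the exact settings used for the paper figures.

from PIL import Image
from gill import models

# Load the released checkpoint (linear layers + [IMG] embeddings) from this repo.
model = models.load_gill('checkpoints/gill_opt/')

# Prompts are an interleaved list of PIL images and strings. 'example.png' is a
# hypothetical local image; replace it with your own input.
prompts = [
    Image.open('example.png').convert('RGB'),
    'Show me a similar image, and describe it in one sentence.',
]
outputs = model.generate_for_images_and_texts(prompts, num_words=32, ret_scale_factor=1.3)
print(outputs)  # interleaved generated text plus retrieved and/or generated images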

Training

Preparing CC3M

Our model is trained on the Conceptual Captions (CC3M) dataset. After following the instructions on the website to download the captions and images, format them into .tsv files as follows:

caption image
A picture of a cat  cat.png
Mountains  mountain.png

where each line contains a caption followed by the filename of the corresponding image, separated by a tab. Save these .tsv files into the datasets/ folder (the default expected names are cc3m_train.tsv and cc3m_val.tsv). The repo contains two placeholder files with a few examples, and you will have to replace them with the appropriate data.
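For example, a .tsv file with this layout can be written with the standard csv module; the two rows below are the placeholder examples from above, to be replaced with your own CC3M data:

import csv

# Write datasets/cc3m_train.tsv in the expected two-column, tab-separated format.
rows = [
    ('A picture of a cat', 'cat.png'),
    ('Mountains', 'mountain.png'),
]
with open('datasets/cc3m_train.tsv', 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['caption', 'image'])  # header row matching the example above
    writer.writerows(rows)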

The corresponding image files should be saved in the data/ directory. The directory can be changed with the --image-dir runtime flag.

If you need help downloading CC3M for GILL, this repo contains helpful step-by-step tips.

Precomputing Text Embeddings

In addition to downloading the images, GILL also requires the embeddings from the text encoder of Stable Diffusion for training. We precompute these ahead of time to improve training throughput. To do so, run the following script:

python scripts/preprocess_sd_embeddings.py  datasets/cc3m_val.tsv data/cc3m/validation/clip_embs

This will precompute embeddings from the captions in cc3m_val.tsv, and save the results to data/cc3m/validation/clip_embs.
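Conceptually, this step encodes each caption with the Stable Diffusion text encoder and caches the hidden states to disk. The sketch below illustrates the idea only; the SD checkpoint name and the one-.npy-per-caption output layout are assumptions, and scripts/preprocess_sd_embeddings.py remains the authoritative implementation.

import os
import numpy as np
import pandas as pd
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Encode captions with the SD text encoder and cache the (77, 768) hidden states.
sd_repo = 'runwayml/stable-diffusion-v1-5'  # assumed SD version
tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(sd_repo, subfolder='text_encoder').eval()

out_dir = 'data/cc3m/validation/clip_embs'
os.makedirs(out_dir, exist_ok=True)
df = pd.read_csv('datasets/cc3m_val.tsv', sep='\t')
with torch.no_grad():
    for caption, image_name in zip(df['caption'], df['image']):
        tokens = tokenizer(caption, padding='max_length', max_length=77,
                           truncation=True, return_tensors='pt')
        emb = text_encoder(tokens.input_ids).last_hidden_state  # (1, 77, 768)
        np.save(os.path.join(out_dir, f'{image_name}.npy'), emb[0].float().numpy())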

Starting a Training Job

After preprocessing the data, we can start a training job with the following command:

randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --dataset=cc3m  --val-dataset=cc3m \
    --exp-name='gill_exp' --image-dir='data/'  --log-base-dir='runs/' \
    --precision='bf16'  --print-freq=100

The default hyperparameters in main.py should reproduce our main results in the paper. We train on 2 A6000 GPUs for 48 hours. For GPUs with less memory, you may need to reduce the batch size, enable gradient accumulation, or adjust other hyperparameters to get good performance. You may also have to disable NCCL P2P with export NCCL_P2P_DISABLE=1 if you run into issues.

You can also run a small job on CPU, for testing purposes:

python -u main.py \
    --dataset=cc3m  --val-dataset=cc3m \
    --opt-version='facebook/opt-125m' --visual-model='openai/clip-vit-base-patch16' \
    --exp-name='gill_exp'   --log-base-dir='runs/' \
    --batch-size=2  --val-batch-size=2  --precision='fp32'  --print-freq=1 \
    --epochs=2  --val_steps_per_epoch=2   --steps_per_epoch=2

Pruning the Checkpoint

As GILL only consists of a few pretrained linear layers and the [IMG] embeddings, we can discard most of the pretrained weights to save disk space. If you have trained a new model and wish to do so, you can use the scripts/prune_model_ckpt.py script to prune the model weights and format the checkpoint as required by gill/models.py:

python scripts/prune_model_ckpt.py  runs/gill_exp

We used the same script to create the weights in the checkpoints/ directory.
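For intuition, pruning amounts to keeping only the small set of trained parameters and discarding the frozen backbone weights. A rough sketch of the idea is below; the checkpoint filenames and key prefixes are assumptions (the prefixes are taken from state_dict keys quoted elsewhere on this page), so use scripts/prune_model_ckpt.py for the real logic expected by gill/models.py.

import torch

# Keep only the trained GILL parameters; drop the frozen LLM/CLIP weights.
ckpt = torch.load('runs/gill_exp/ckpt_best.pth.tar', map_location='cpu')  # hypothetical filename
keep_prefixes = ('model.input_embeddings', 'model.ret_text_hidden_fcs',
                 'model.gen_text_hidden_fcs')  # assumed set; add any other trained mappings
pruned = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()
          if k.replace('module.', '').startswith(keep_prefixes)}
torch.save({'state_dict': pruned}, 'runs/gill_exp/pretrained_ckpt.pth.tar')  # hypothetical filename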

Training a Decision Classifier

As described in the paper (Appendix F), we annotate PartiPrompts with per-example labels indicating whether to retrieve or generate. The annotations are provided in data/PartiPromptsAllDecisions.tsv. The format follows PartiPrompts, with an additional Decisions column that we introduce:

Prompt	Category	Challenge	Note	Decisions
bond	Abstract	Basic	Biology-inspired concepts with multiple meanings	ret,gen,gen,same,gen
element	Abstract	Basic	Biology-inspired concepts with multiple meanings	ret,ret,ret,ret,same

This column contains the annotations of 5 independent human evaluators. The decisions indicate whether the annotators prefer the retrieved image (ret), the Stable Diffusion generated image (gen), or find both to be about the same (same). The released annotations are for the query assessing which image is more relevant to the provided prompt. The annotations for the query on realism are also available at data/PartiPromptsAllDecisions_Realism.tsv, although we recommend using the text alignment annotations for training a decision classifier (as retrieved images are generally likely to be significantly more realistic than generated ones).

To train a decision classifier, first, preprocess the PartiPrompts annotations to keep only those with high interannotator agreement:

python scripts/process_p2_annotations.py
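As a rough illustration of the agreement filter, the sketch below keeps prompts where most of the 5 annotators chose the same label; the exact threshold used by scripts/process_p2_annotations.py is not documented here, so treat min_votes as an assumption.

import pandas as pd
from collections import Counter

def majority(decisions: str, min_votes: int = 4):
    # Return the majority label if at least min_votes annotators agree, else None.
    label, votes = Counter(decisions.split(',')).most_common(1)[0]
    return label if votes >= min_votes else None

df = pd.read_csv('data/PartiPromptsAllDecisions.tsv', sep='\t')
df['label'] = df['Decisions'].apply(majority)
high_agreement = df.dropna(subset=['label'])
print(f'{len(high_agreement)} / {len(df)} prompts kept')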

To train a decision model on these annotations, please follow the steps in TrainDecisionClassifier.ipynb. F1 scores of the model and human baselines are reported in the notebook. If you train a GILL model from scratch, you will need to train this classifier as well, as the one provided at checkpoints/gill_opt/decision_model.pth.tar is only compatible with our original model weights.

Evaluation

We provide code to reproduce the VIST (Table 1) and VisDial (Table 2) results presented in our paper.

VIST Evaluation

To run the VIST evaluation, first download the annotations from the val set of the official VIST dataset. We will need to download and process the image files for running the evaluations presented in the paper. This can be done by running python evals/download_vist_images.py. By default, images are saved to the sis/val_images/ directory. Downloading the images should take about 1 hour on a decent connection (as images are downloaded directly from the Flickr URLs).

After the image files are downloaded, we can run the VIST generation experiment described in Section 4.1 of our paper. First, we run GILL to generate the last image in the sequence, conditioned on the image and text inputs:

python evals/generate_vist_images.py  gill_vist_outputs

The generated images for each VIST example will be saved in gill_vist_outputs/. Then, to benchmark the models, we can compute the CLIP similarity scores:

python evals/compute_clip_similarity_vist.py
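The CLIP similarity score is the cosine similarity between the CLIP image embeddings of each generated image and its ground-truth counterpart. A standalone sketch is below; the CLIP variant and the example file pairing are assumptions, and evals/compute_clip_similarity_vist.py is the script used for the reported numbers.

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

clip = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').eval()  # assumed CLIP variant
processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')

def clip_image_similarity(path_a: str, path_b: str) -> float:
    # Cosine similarity between the CLIP image embeddings of two images.
    images = [Image.open(p).convert('RGB') for p in (path_a, path_b)]
    inputs = processor(images=images, return_tensors='pt')
    with torch.no_grad():
        feats = clip.get_image_features(**inputs)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return (feats[0] @ feats[1]).item()

# Hypothetical pairing of a generated image with its ground-truth VIST image.
print(clip_image_similarity('gill_vist_outputs/2591919314.png', 'sis/val_images/2591919314.jpg'))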

For the LPIPS metric, please refer to their official GitHub repo for installation instructions. Then, we can compute the results as follows:

python evals/lpips_2dirs.py -d0  sis/val_images/  -d1  gill_vist_outputs  -o results.txt --use_gpu

For LPIPS, you may have to resize the images to 256x256 to match the AlexNet model used. We have also uploaded our LPIPS eval script (gill/evals/lpips_2dirs.py) for reference.
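For instance, a minimal LPIPS computation with the 256x256 resize applied, using the pip-installable lpips package with its AlexNet backbone (the file pairing is illustrative), might look like:

import lpips
import torch
import torchvision.transforms as T
from PIL import Image

loss_fn = lpips.LPIPS(net='alex')  # AlexNet backbone, matching the note above
to_tensor = T.Compose([T.Resize((256, 256)), T.ToTensor()])

def lpips_distance(path_a: str, path_b: str) -> float:
    # LPIPS expects (N, 3, H, W) tensors scaled to [-1, 1].
    imgs = [to_tensor(Image.open(p).convert('RGB')) * 2 - 1 for p in (path_a, path_b)]
    with torch.no_grad():
        return loss_fn(imgs[0].unsqueeze(0), imgs[1].unsqueeze(0)).item()

print(lpips_distance('gill_vist_outputs/2591919314.png', 'sis/val_images/2591919314.jpg'))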

VisDial Evaluation

Similarly, for VisDial, download the VisDial validation annotations, the dense answer annotations, and the images. Extract everything to the VisualDialog folder.

We can run the VisDial generation experiment described in Section 4.1 of our paper. We run GILL to generate an image conditioned on the full text dialogue input:

python evals/generate_visdial_images.py  gill_visdial_outputs

The generated images for each VisDial example will be saved in gill_visdial_outputs/. Then, to benchmark the models, we can compute the CLIP similarity scores:

python evals/compute_clip_similarity_visdial.py

For LPIPS, please follow the VIST instructions above to compute scores using the official LPIPS GitHub repo.

Gradio Demo

You can launch your own version of the Gradio demo locally by running python demo/app_gradio.py, or duplicating the HuggingFace space.

TODOs

  • Add web demo.
  • Add evaluation scripts for reproducing the results in the paper.
  • Add training code and instructions for training a new GILL model on CC3M.

Citation

If you find this work or our code useful, please consider citing:

@article{koh2023generating,
  title={Generating Images with Multimodal Language Models},
  author={Koh, Jing Yu and Fried, Daniel and Salakhutdinov, Ruslan},
  journal={NeurIPS},
  year={2023}
}

gill's People

Contributors

kohjingyu, ray-ruisun, vishaal27


gill's Issues

How could this affect the performance?

Hi,
I'm training the model with "Llama2" as the frozen LLM. I was wondering how exactly this part of the code (shown in the screenshot in the original issue) affects model training and performance, and what its purpose is, i.e. going through the if branch versus the else branch?

Would really appreciate the help
Best

shape mismatch in the example "Multimodal Dialogue"

I copied the whole notebook code into test.py, and when I run the "Multimodal Dialogue" part, I get the following error:

Traceback (most recent call last):
  File "test.py", line 130, in <module>
    full_outputs = generate_dialogue(prompts, num_words=num_words, sf=sf, temperature=temperature, top_p=top_p)
  File "test.py", line 67, in generate_dialogue
    return_outputs = model.generate_for_images_and_texts(
  File "/home/zhcheng/CoMT/Gill/gill/models.py", line 719, in generate_for_images_and_texts
    gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs)  # (1, 77, 768)
  File "/home/zhcheng/anaconda3/envs/gill/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhcheng/CoMT/Gill/gill/layers.py", line 32, in forward
    x = x + input_embs
RuntimeError: The size of tensor a (0) must match the size of tensor b (8) at non-singleton dimension 1

Incorrect Loss Calculation in 'generation' Mode in validate.py

Hello there! 👋 Thank you very much for this fascinating work.
I think I noticed a small issue in the code (screenshot in the original issue). When model_mode is set to 'generation', the loss variable is calculated using "args.ret_loss_scale", which seems to be incorrect. Shouldn't it use "args.gen_loss_scale" instead? I see no other reason for it to use the retrieval loss scale. :)

About error when running Precomputing Text Embeddings and Train

Hi authors,

I have two problems. The first occurs when I try to run the following script:

python scripts/preprocess_sd_embeddings.py  datasets/cc3m_val.tsv data/cc3m/validation/clip_embs

I then get the error shown in the screenshot in the original issue.

The second occurs when I run the training script:

randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --dataset=cc3m  --val-dataset=cc3m \
    --exp-name='gill_exp' --image-dir='data/'  --log-base-dir='runs/' \
    --precision='bf16'  --print-freq=100

which also fails with an error (shown only as a screenshot in the original issue).

release the evaluation code

Hey, thank you so much for the great work! I was wondering when you will release the evaluation code for the VIST dataset and the calculation of the CLIP/LPIPS scores. Thanks!

Estimated time-line for code and weights?

Hey, thanks so much for your great work -- I enjoyed reading the paper! Do you have an estimated timeline for when you are planning to release the code and pre-trained weights for the model?

Normalization of cc3m features

I enjoyed this work, but I've come across some issues. Why isn't any normalization performed when extracting the visual features? I noticed that the CC3M embeddings also aren't normalized during testing, while the query features are normalized. Will this have an impact on the retrieval results?

img_tensor = utils.get_pixel_values_for_model(feature_extractor, img)
img_tensor = img_tensor[None, ...].to(device).bfloat16()
img_emb = model.model.get_visual_embs(img_tensor, mode="retrieval")
img_emb = img_emb[0, 0, :].cpu()

Multimodal generation in one pass

Hi, thanks for sharing this awesome work.

As I was trying your system more and more, a few questions popped up in my mind:

  1. In my experience, I am seeing instances where the LLM starts generating text but, instead of finishing the sentence, generates the [IMG0] token; because the remaining [IMG] tokens are then forcibly appended, the text output remains incomplete. Whatever text tokens the model generates after the first batch of [IMG] tokens are mostly garbage. How often have you encountered this in your experiments? What could be a possible solution?
  2. I noticed that you had sf=1.4 in your notebook for multimodal dialogue generation. How did you choose this hyperparameter to balance the generation between text and images? Was it the same for the 4 examples in Figure 5 of your paper? The same questions go for num_words and min_word_tokens.
  3. Did you evaluate GILL's text generation ability on any task, e.g. VQA, or plan to do so in the future?

Generated image quality

Hi @kohjingyu, thanks for your awesome work and for sharing the code, but I found that the generated image quality is not very good. Is there a problem with how I am running the code? For example, for the prompts "an astronaut riding a horse on mars", "an cat riding a horse on mars" and "an dog riding a horse on mars", the results are the images astronaut2, cat2, and dog2 attached to the original issue.

The number of astronauts is not correct, and the generated cat and dog are not correct either. Do you have any suggestions? Thanks very much!

[solved]

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:

  1. Downgrade the protobuf package to 3.20.x or lower.
  2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates


solved.

About the running log

Thanks for sharing this excellent work! Since many CC3M images have expired, I used the LLaVA-595K CC3M subset for training. Could you provide the training log from training on the complete CC3M? Thanks.

param.grad is None !

Hi! Thank you for your great work!

After preparing datasets and pretrained model, I trained the model using this command:

randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 --batch-size=256 \
    --dataset=cc3m --val-dataset=cc3m \
    --exp-name='gill_exp' --image-dir='/data/vol1/public-datasets/03-CC/cc3m/' --log-base-dir='runs/' \
    --precision='bf16' --print-freq=100 \
    --opt-version='/data/models/facebook/opt-6.7b' --visual-model='/data/pretrained_weights/openai/clip-vit-large-patch14' --workers=0

No code is modified except the data paths. However, I got this error:

Traceback (most recent call last):
  File "/root/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/data/vol1/zky/methods/gill/main.py", line 402, in main_worker
    train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler, args)
  File "/data/vol1/zky/methods/gill/main.py", line 586, in train
    assert param.grad.shape[0] == len(tokenizer)
AttributeError: 'NoneType' object has no attribute 'shape'

It seems like the param in model.module.model.input_embeddings.parameters() has no grad. Could you tell me how to solve this problem? Thank you!

Size mismatch error when loading the decision model

After training both the GILL model and the decision model, loading the model failed:

Traceback (most recent call last):
in <cell line: 2>:2

File /content/gill/gill/models.py:873, in load_gill
--> 873   model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
    874                load_sd=True, num_gen_images=1, decision_model_path=decision_model_path)
    875   model = model.eval()
    876   model = model.bfloat16()

File /content/gill/gill/models.py:560, in __init__
    557       nn.Linear(768, 2),
    558     ])
    559     mlp_checkpoint = torch.load(decision_model_path)
--> 560     self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=False)
    561     self.decision_model.eval()

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671, in load_state_dict
--> 1671    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    1672        self.__class__.__name__, "\n\t".join(error_msgs)))

RuntimeError: Error(s) in loading state_dict for Sequential:
        size mismatch for 1.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in
current model is torch.Size([2, 4096]).

Inference shape is not 8

Thank you for the good code. However, when running the inference code (screenshot in the original issue), the first dimension of the actual raw_emb tensor is 0, not 8.

How to get cc3m_embeddings

Hi Authors,

How did you get cc3m_embeddings.npy? Do you use the trained vision encoder and linear projection layer, or do you only use a vector database? (I don't know about vector databases at all.)

Thank you!

Problem in running the evaluation script

I have a problem running the evaluation.

After saving some images, I am getting errors in both the VIST and VisDial evaluations.

Saving to /home/mbzuaiser/gill/gill_vist_outputs/2591919314.png | 49/50 [00:02<00:00, 17.59it/s]
4% | 185/4990 [40:27<17:30:59, 13.12s/it]

Traceback (most recent call last):
  File "evals/generate_vist_images.py", line 74, in <module>
    return_outputs = model.generate_for_images_and_texts(
  File "/home/mbzuaiser/gill/gill/models.py", line 688, in generate_for_images_and_texts
    img = utils.get_image_from_url(self.path_array[img_idx])
  File "/home/mbzuaiser/gill/gill/utils.py", line 29, in get_image_from_url
    img = img.resize((224, 224))
  File "/home/mbzuaiser/gill/venv/lib/python3.8/site-packages/PIL/Image.py", line 2138, in resize
    self.load()
  File "/home/mbzuaiser/gill/venv/lib/python3.8/site-packages/PIL/ImageFile.py", line 266, in load
    raise OSError(msg)
OSError: image file is truncated (0 bytes not processed)

Can you please help me fix this?

RuntimeError: CUDA error: no kernel image is available for execution on the device

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Queries regarding the Precomputed Text Embeddings

Dear Authors,

I found your work very interesting. Can you please clarify how the text embeddings are generated, and which Stable Diffusion text encoder is used, e.g. "openai/clip-vit-base-patch32" or "openai/clip-vit-base-patch14" or something else?
Or could you please provide the code for it?

A few questions about the training pipeline

Hi,

I read your paper and it was a great work! Thanks for sharing your codebase with the community. As I was going through your codes, I came across a few places, where I would greatly appreciate your explanations/suggestions. Here are my questions -

  • What does the CE loss from here stand for, i.e. which of the 4 losses from the paper does it refer to?
  • Does the CE loss from here refer to the l_p loss in the paper? If not, which loss from the paper does it refer to?
  • From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100, to be ignored when calculating the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r={2,3,...,8})?

VisDial-related questions

(Screenshot in the original issue.) 1. It uses the training set, with split=val, and the image is not taken as input. Shouldn't it be dialogs + image -> image? 2. This is quite different from how VIST is used; VIST uses dialogs + image -> image.

GILL Image Retrieval Code on VIST

Hi! Great work.

I want to reproduce the GILL results in Table 5.

I'm wondering what code changes need to be done in eval_vist_retrieval.py from FROMAGe.

I've tried changing these lines (L135-L142 and L154-L159), but the results don't seem quite right...

"""L135-142"""
all_input_ids = []
for i, c in enumerate(captions):
    if i == len(captions) - 1:
        c += '[IMG]'     # c += '[RET]'  
    input_ids = model.model.tokenizer(c, add_special_tokens=True, return_tensors="pt").input_ids.to(emb_matrix.device)
    all_input_ids.append(input_ids)

input_embs = [model.model.input_embeddings(s)[0, ...] for s in all_input_ids]  # (N, T, D)

"""L154-159"""
output = model.model.lm(inputs_embeds=final_input_embs, labels=None, use_cache=False, output_hidden_states=True)
# last_hidden_state = model.model.text_hidden_fcs[0](output.hidden_states[-1])
last_hidden_state = model.model.ret_text_hidden_fcs[0](output.hidden_states[-1], None)
ret_emb = last_hidden_state[:, -1, :]

ret_emb = ret_emb / ret_emb.norm(dim=1, keepdim=True)
scores = ret_emb.squeeze() @ emb_matrix.squeeze().T

Looking forward to your reply.

Thanks.

Custom SD pipeline with hard-coded left truncation of text prompts

Hi,

Thanks a lot for sharing your codebase. As I was trying to run your code on my own data, I figured out that here, you have set truncate_side to be left. What is the rationale behind it? I can see that the official SD pipeline truncates text prompts from the right side. Truncating from the left leaves me with a series of <|endoftext|> (as expected). It would be great if you could take a look. Thanks.

Shape mismatch in example notebook

Thanks for your great work.
Trying your GILL_Inference_Examples.ipynb, I got a shape mismatch on a Linear layer when running the model.
Any ideas why? Since the sizes differ by 1, maybe it is related to the bias term?

Traceback (most recent call last):
in <cell line: 6>:13

File /content/gill/gill/models.py:700, in generate_for_images_and_texts
    697       decision_emb = raw_emb[:, 0, :]  # (1, 4096)
    698       assert decision_emb.shape[1] == 4096, decision_emb.shape
    699       max_ret_score = scores.max().reshape((1, 1)).clone().detach().to(device=de...
--> 700       decision_logits = self.decision_model(torch.cat([decision_emb, max_ret_sco...
    701       probs = decision_logits.softmax(dim=-1).cpu().float().numpy().tolist()
    702       image_outputs['decision'] = [self.idx2dec[decision_logits.argmax().item()]...

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1194, in _call_impl
--> 1194      return forward_call(*input, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:204, in forward
--> 204       input = module(input)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1194, in _call_impl
--> 1194      return forward_call(*input, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:114, in forward
--> 114       return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x4097 and 4096x2)

Issue with multiple image inputs. (Only last image input taken into consideration)

Hi,
Thank you so much for this amazing work! I have been evaluating the model on different tasks, so I tried to run GILL with the provided checkpoint on the Visual Storytelling task on the VIST dataset (the task described by Figure 1 of the linked paper). I observed that the model seems to only take into account the very last image rather than considering all the images.
A screenshot of the Hugging Face Spaces chat demonstrating this issue is attached to the original issue; it is a datapoint from the VIST dataset. Could you please help me understand why this happens?

Error while loading the gill model

Hi @kohjingyu, thank you so much for this amazing work!
I am trying to replicate the results and I have been using the "GILL_Inference_Examples.ipynb" you have provided. While running the code to load the gill model, I encounter the following error message:

Code:
model = models.load_gill(model_dir)

Error:
KeyError                                  Traceback (most recent call last)
Cell In[2], line 3
      1 # Download the model checkpoint and embeddings to checkpoints/gill_opt/
      2 model_dir = 'checkpoints/gill_opt/'
----> 3 model = models.load_gill(model_dir)

File /home/wiseyak/aafiya/gill/gill/models.py:883, in load_gill(model_dir, load_ret_embs)
    881 for k, v in checkpoint['state_dict'].items():
    882     state_dict[k.replace('module.', '')] = v
--> 883 img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()
    884 del state_dict['model.input_embeddings.weight']
    885 print("Extracted image token embeddings")

KeyError: 'model.input_embeddings.weight'

Could you please guide me on how to resolve this?

Question about PartiPrompts

Hey,
I was just going over the annotated PartiPrompts data you've uploaded, and had a hard time figuring out the exact structure.

Could you please describe what the different attributes here mean? I assume the prompt basically means the caption corresponding to the images shown.

Prompt	Category	Challenge	Note	Decisions
bond	Abstract	Basic	Biology-inspired concepts with multiple meanings	ret,gen,gen,same,gen
element	Abstract	Basic	Biology-inspired concepts with multiple meanings	ret,ret,ret,ret,same
molecule	Abstract	Basic	Biology-inspired concepts with multiple meanings	gen,gen,gen,gen,gen

Another related question is: Can we access the exact images that the human annotators saw while making the decisions? i.e. from the paper Fig. 9, it seems like you've collected text-image alignment and fidelity ratings from annotators for each pair of images, is the raw pairwise preference data public (with both annotations)?

instruction tuning on other datasets

Appreciate the provided code.
I have a question regarding the training process. While the LLM is in the process of learning to generate images, it generates the same text as the input. Nevertheless, during inference, the model is able to produce sensible responses for the input. Consequently, I'm curious if there are any other instruction datasets utilized for fine-tuning the model's ability to follow instructions. If such datasets are indeed employed, could the instruction fine-tuning resources be made publicly accessible?

Training settings for reproducing the paper

Hi. Thank you for sharing your great research!
Your work is very inspiring.
I have a question about the training settings in the code to reproduce your paper.

Here's what your paper says:
We trained the model with batch size 180 for 1 epoch (18000 iterations) on 1 A6000 GPU (24 clock hours).

and parameters in your code are defined by follows (main.py):
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--steps_per_epoch', default=2000, type=int, metavar='N',
                    help='number of training steps per epoch')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--val_steps_per_epoch', default=-1, type=int, metavar='N',
                    help='number of validation steps per epoch')
parser.add_argument('-b', '--batch-size', default=200, type=int, metavar='N',
                    help='mini-batch size (default: 200), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')

I think there is a difference in the training parameters (especially batch size and iterations) between the paper and the code. Can you clarify this?

Problem in running the evaluation script

I still have a problem running the evaluation script.

After a certain point in the iteration, the code gets stuck.

Saving to /home/mbzuaiser/gill/gill_vist_outputs/514809043.png | 49/50 [00:02<00:00, 17.78it/s] ... 100% | 50/50 [00:02<00:00, 17.43it/s]
Saving to /home/mbzuaiser/gill/gill_vist_outputs/514808431.png | 49/50 [00:02<00:00, 17.76it/s]
6% | 279/4990 [51:08<10:47:19, 8.24s/it]

This happens in both the VIST and VisDial cases.
As you suggested in the earlier solution, I have added an except clause for OSError, but I am still facing the same issue.

It would be great if you could help me with this.

Query regarding the preprocessing of the data

Hi,

Great work!!!!

After downloading the CC3M data and creating the .tsv file, I am getting the following error:

Error reading for /home/mbzuaiser/gill/data/training/01963207---stock-photo-an-brazil-road-pare-stop-sign-with-a-sky-background-and-copy-space-for-your-message-110146604.jpg with caption a road pare stop sign with a sky background and copy space for your message: Failed to interpret file <_io.BufferedReader name='/home/mbzuaiser/gill/data/training/01963207---stock-photo-an-brazil-road-pare-stop-sign-with-a-sky-background-and-copy-space-for-your-message-110146604.jpg.npy'> as a pickle

Can you please help me fix this error?

FID Evaluation on CC3M and VIST

Hi!

Congratulations on great work!

Could you please point me to the code to reproduce results in Table 3 and Table 4, particularly FID scores on CC3M and VIST dataset? What splits do you use and is there any implementation in the code that I missed?

Any help is greatly appreciated!

Thanks!

shape mismatch in the example notebook

Thanks for the great work! I came across the following issue when I tried to run the example notebook:

File [/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:712](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:712), in GILL.generate_for_images_and_texts(self, prompts, num_words, min_word_tokens, ret_scale_factor, gen_scale_factor, top_p, temperature, max_num_rets, generator, always_add_bos, guidance_scale, num_inference_steps)
    [710](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:710) print(raw_emb.shape)
    [711](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:711) print(gen_prefix_embs.shape)
--> [712](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:712) gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs)  # (1, 77, 768)
    [714](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:714) if gen_emb.shape[1] != 77:
    [715](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/models.py:715)   print(f"Padding {gen_emb.shape} with zeros")

File [~/anaconda3/envs/micl/lib/python3.8/site-packages/torch/nn/modules/module.py:1518](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/u/y/z/yzeng58/micl/trials/~/anaconda3/envs/micl/lib/python3.8/site-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/u/y/z/yzeng58/micl/trials/~/anaconda3/envs/micl/lib/python3.8/site-packages/torch/nn/modules/module.py:1516)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/u/y/z/yzeng58/micl/trials/~/anaconda3/envs/micl/lib/python3.8/site-packages/torch/nn/modules/module.py:1517) else:
-> [1518](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/u/y/z/yzeng58/micl/trials/~/anaconda3/envs/micl/lib/python3.8/site-packages/torch/nn/modules/module.py:1518)     return self._call_impl(*args, **kwargs)
...
---> [32](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/layers.py:32)   x = x + input_embs
     [34](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/layers.py:34) if isinstance(self.model, nn.ModuleList):
     [35](https://vscode-remote+ssh-002dremote-002b7b22686f73744e616d65223a224f4c56492d32227d.vscode-resource.vscode-cdn.net/afs/cs.wisc.edu/u/y/z/yzeng58/micl/models/gill/gill/layers.py:35)   assert len(self.model) == x.shape[1] == self.num_input_tokens, (len(self.model), x.shape, self.num_input_tokens)

RuntimeError: The size of tensor a (8) must match the size of tensor b (15) at non-singleton dimension 1
