
Comments (15)

kohjingyu commented on September 13, 2024

Thanks for your kind words!

What does the CE loss from here stand for, i.e., which of the 4 losses from the paper does it refer to?

Does the CE loss from here refer to the $l_p$ loss in the paper? If not, which loss from the paper does it refer to?

This and the loss defined on line 508 make up $l_p$ in the paper (equation 2). This is the loss for training the model to produce the [IMG] tokens at the end of "caption-like" text. Both L506 and L508 are actually the same loss, since the same caption (e.g., "A picture of a dog [IMG0]...[IMG7]") is used for retrieval and generation. This is why they each have a 0.5 multiplier, so that they sum to $l_p$.

From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100, to be ignored when calculating the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r={2,3,...,8})?

That's right, and the reason for this is that we force the generation of the r={2,3,...,8} tokens as the next 7 tokens whenever the model produces [IMG0] (since we always need all 8 tokens for generation/retrieval, so it doesn't make sense to have a partial set of the [IMG] tokens). The embeddings of [IMG2]...[IMG8] tokens are therefore only learnt through the other losses (in particular the generation loss $l_g$), when their embeddings/hidden states are used for computing the generation/retrieval objectives. $l_p$ doesn't affect [IMG2]...[IMG8] tokens. So the model will never produce [IMG2]...[IMG8] organically, but their representations are still helpful for feeding into the GILLMapper module for image generation.
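The label masking described above can be sketched roughly as follows. This is an illustration, not the actual GILL code: the token id, caption length, and helper function are all assumptions.

```python
import torch

IMG0_ID = 50266       # assumed id of the [IMG0] token, for illustration
IGNORE_INDEX = -100   # ignored by torch.nn.CrossEntropyLoss by default

def mask_labels(input_ids: torch.Tensor, caption_len: int) -> torch.Tensor:
    """Keep the loss on the caption tokens and the single [IMG0] token;
    mask everything after [IMG0] (i.e. the remaining [IMG] ids) with -100."""
    labels = input_ids.clone()
    labels[caption_len + 1:] = IGNORE_INDEX
    return labels

# Caption token ids (10..14), then [IMG0] and the seven forced [IMG] ids.
seq = torch.tensor([10, 11, 12, 13, 14, IMG0_ID, 1, 2, 3, 4, 5, 6, 7])
labels = mask_labels(seq, caption_len=5)
# The loss is computed only for the caption tokens and [IMG0]; the trailing
# [IMG] ids are ignored, matching the -100 masking described above.
```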

Hope that makes sense!

from gill.

avipartho commented on September 13, 2024

As I was trying to train my own model and run inference using the saved checkpoint, I noticed a few things; please verify (these might be helpful for other users).

  • Because GILLArgs() only has local variables and no attributes, none of them get saved in the model_args.json file unless specifically set after instantiation. One such attribute is text_emb_layers. Turning all the local variables into class attributes would solve this.
  • Currently the main.py script also saves the scheduler state, which effectively saves the entire model (probable reason) and therefore results in a large checkpoint.
  • The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py produces a checkpoint with an input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires changing either this line or this line to run inference with the produced checkpoint. For example,
img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()[-model_kwargs['num_tokens']:, :]
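As a sanity check, the suggested slicing can be tried on a toy state dict. The sizes below are stand-ins for the real (50274, 4096) matrix and 8 [IMG] tokens:

```python
import torch

# Stand-in sizes: the real matrix is (50274, 4096) with 8 [IMG] tokens.
num_tokens, hidden = 8, 16
full = torch.randn(50, hidden)  # plays the role of the full embedding matrix
state_dict = {'model.input_embeddings.weight': full}

# Same slicing as the suggested fix: keep only the trailing [IMG] rows.
img_token_embeddings = (
    state_dict['model.input_embeddings.weight']
    .cpu().detach()[-num_tokens:, :]
)
# img_token_embeddings now has shape (num_tokens, hidden), matching the
# shape of the pruned checkpoint described above.
```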

kohjingyu commented on September 13, 2024

You're absolutely right, thanks for pointing this out! We'll fix this in the paper soon. The correct scheme should be that the loss is only considered for the first [IMG0] token. The part about forcing generation of the remaining tokens during inference is still true.

avipartho commented on September 13, 2024

Thanks for your quick response. Reopening this issue for another query regarding the pipeline (I didn't want to unnecessarily create a new issue).

If I am not wrong, this line makes the entire OPT embedding layer trainable. This is also evident from the param_count.txt file your scripts generate. However, according to the paper, only the [IMG] embedding matrix $E_{img}$ was supposed to be trainable. Did I miss anything here?

First few lines from param_count.txt:

| Module | Trainable | Shape | Param Count |
| model.logit_scale | True | () | 1 |
| model.lm.model.decoder.embed_tokens.weight | True | (50274, 4096) | 205,922,304 |

kohjingyu commented on September 13, 2024

You're right, they become trainable, which is why we zero out the gradients of the non-[IMG] embeddings here:

gill/main.py, lines 578 to 587 (commit 53fdcf2):

# Zero out gradients of the embedding matrix outside of [IMG].
for param in model.module.model.input_embeddings.parameters():
    assert param.grad.shape[0] == len(tokenizer)
    # Keep other embeddings frozen.
    mask = torch.zeros((param.grad.shape[0], 1)).to(param.grad)
    for ret_idx in args.retrieval_token_idx:
        mask[ret_idx] = 1
    for gen_idx in args.gen_token_idx:
        mask[gen_idx] = 1
    param.grad = param.grad * mask

This is not super ideal, but I think it is overall cleaner than concatenating a trainable embedding matrix with a frozen one.
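The effect of this masking can be checked on a toy embedding layer. The [IMG] row indices below are made up for illustration:

```python
import torch

# Toy embedding: 10 "vocab" rows; pretend rows 8 and 9 are the [IMG] tokens.
emb = torch.nn.Embedding(10, 4)
img_token_idx = [8, 9]  # assumed [IMG] row positions, for illustration only

# A forward/backward pass touches rows 1, 8, and 9.
out = emb(torch.tensor([1, 8, 9])).sum()
out.backward()

# Same masking idea as in main.py: zero gradients for all non-[IMG] rows,
# so only the [IMG] embeddings would move on the optimizer step.
mask = torch.zeros((emb.weight.grad.shape[0], 1))
for idx in img_token_idx:
    mask[idx] = 1
emb.weight.grad = emb.weight.grad * mask
```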

avipartho commented on September 13, 2024

Thanks again. Unfortunately, I missed this section of the script.

Is it also correct to say that for the $l_p$ loss, you are considering the loss for generating each token of the input text (caption), i.e. the negative log likelihood of generating token $s_t$ conditioned on $s_1, \ldots, s_{t-1}$, for $t \in \{1, \ldots, T\}$?
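For reference, that per-token NLL is the standard shifted cross-entropy objective; a minimal sketch (shapes are illustrative, not the real model's):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: a tiny vocab and a short sequence.
vocab, T = 20, 6
logits = torch.randn(1, T, vocab)         # model outputs for s_1..s_T
labels = torch.randint(0, vocab, (1, T))  # the caption token ids

# Standard next-token objective: predict s_t from s_1..s_{t-1}.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
nll = F.cross_entropy(
    shift_logits.reshape(-1, vocab),
    shift_labels.reshape(-1),
)  # mean over t of -log p(s_t | s_1, ..., s_{t-1})
```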

kohjingyu commented on September 13, 2024

Yes that's right!

avipartho commented on September 13, 2024

In that case, I believe equation 2 is slightly misleading, as the summation there goes over i from 1 to r. In practice, that says we are considering the loss for generating all 8 [IMG] tokens.
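For concreteness, given the earlier correction in this thread (the loss is only considered for the first [IMG0] token), a plausible corrected form of equation 2 might look like the following. This is a sketch assuming the paper's notation of caption tokens $s_1, \ldots, s_T$ followed by a single supervised [IMG0] token, not the authors' actual revision:

```latex
% Sketch only: caption NLL plus the NLL of the single [IMG0] token.
l_p = -\sum_{t=1}^{T} \log p\left(s_t \mid s_1, \ldots, s_{t-1}\right)
      \;-\; \log p\left([\mathrm{IMG_0}] \mid s_1, \ldots, s_T\right)
```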

kohjingyu commented on September 13, 2024

Thanks for sharing this!

  • The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py produces a checkpoint with an input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires changing either this line or this line to run inference with the produced checkpoint. For example,

You're right, and I also realized that I hadn't uploaded the script used to prune the checkpoints (keeping just the trained weights, and discarding the pretrained model weights). I just did that here: https://github.com/kohjingyu/gill/blob/main/scripts/prune_model_ckpt.py

I think this is essentially the same as the changes you probably made locally, though I haven't tested this script in a while.
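For anyone reproducing this, the pruning idea (keep the trained weights, drop the frozen LM) can be sketched as follows. The key names and shapes here are illustrative, not the real checkpoint keys or the actual script:

```python
import torch

# Toy "checkpoint": one frozen LM weight plus trainable GILL-specific weights.
ckpt = {
    'model.lm.model.decoder.embed_tokens.weight': torch.randn(50, 16),
    'model.input_embeddings.weight': torch.randn(50, 16),
    'model.gen_text_hidden_fcs.0.weight': torch.randn(16, 16),
}
num_tokens = 8  # number of trainable [IMG] tokens

pruned = {}
for name, weight in ckpt.items():
    if name == 'model.input_embeddings.weight':
        # Keep only the trailing trainable [IMG] rows, as discussed above.
        pruned[name] = weight[-num_tokens:, :].clone()
    elif '.lm.' not in name:
        # Keep non-LM (trained) weights; drop the frozen LM weights.
        pruned[name] = weight
```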

avipartho commented on September 13, 2024

Thanks for sharing the script! Just noticed a few things -

  • These arguments no longer exist in the current version. Did you perhaps coalesce them into num_tokens? Please verify.
  • This line gives an error, as it tries to mutate an OrderedDict inside the for loop. This can be avoided by creating an empty dict first and copying everything into it (just as is done here in models.py).
  • I believe the example usage should be python scripts/prune_model_ckpt.py runs/gill_exp, given the location of the script.
  • What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.
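The OrderedDict fix mentioned in the second bullet can be sketched like this (the state dict contents and the module. prefix stripping are purely illustrative):

```python
from collections import OrderedDict

# Toy state dict; renaming keys by stripping a 'module.' prefix is used
# here only as an example of a mutation that must go into a copy.
state_dict = OrderedDict([('module.fc.weight', 1), ('module.fc.bias', 2)])

# Don't add or remove keys while iterating the dict itself; build a copy.
cleaned = OrderedDict()
for name, value in state_dict.items():
    cleaned[name.replace('module.', '', 1)] = value
```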

kohjingyu commented on September 13, 2024

Thanks for the notes! Sorry about this, it's what happens when you don't test before you upload...

These arguments no longer exist in the current version. Did you perhaps coalesce them into num_tokens? Please verify.

That's right.

What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.

share_ret_gen doesn't exist anymore; I think it was something used during debugging previously. I've updated the script accordingly, so hopefully it works as expected now. Thanks for your help in debugging this!

avipartho commented on September 13, 2024

Another small update: I could not find warmup-scheduler==0.3.2 (as mentioned in the requirements.txt file); the currently available version is 0.3. Will it be compatible with your scripts? (I can verify that training continues with this version.)

kohjingyu commented on September 13, 2024

Ah, looks like it should be pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git instead. The link you provided should still work though.

avipartho commented on September 13, 2024

I have another question. As mentioned above, the $l_p$ loss includes the negative log likelihood (NLL) of generating each token of the input text (caption). Did you find this helpful for the overall model performance? I ask because, from the name and purpose of this loss, I would assume it was intended to only consider the NLL of generating the [IMG] tokens.

kohjingyu commented on September 13, 2024

I have not run this particular ablation, sorry. I would guess that it does not have a significant effect on performance on the tasks we evaluated on.
