Comments (15)
Thanks for your kind words!
What does the CE loss from here stand for? I.e., does it refer to the l_p loss in the paper? If not, which loss from the paper does it refer to?
This and the loss defined on line 508 together make up the l_p loss from the paper.
From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 to be ignored from calculating loss. Is my understanding correct? If it is, how are we learning embeddings for other [IMG{r}] tokens (r={2,3,...,8})?
That's right, and the reason for this is that we force the generation of the r={2,3,...,8} tokens as the next 7 tokens whenever the model produces [IMG0] (since we always need all 8 tokens for generation/retrieval, so it doesn't make sense to have a partial set of the [IMG] tokens). The embeddings of the [IMG2]...[IMG8] tokens are therefore only learnt through the other losses (in particular, the generation loss). So the model will never produce [IMG2]...[IMG8] organically, but their representations are still helpful for feeding into the GILLMapper module for image generation.
Hope that makes sense!
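A rough sketch of this forced decoding, assuming (purely for illustration) that the [IMG] token ids are consecutive and start at the id of [IMG0] — this is not the repo's actual generation code:

```python
def append_img_tokens(generated_ids, img0_id, num_img_tokens=8):
    """If the last generated token is [IMG0], force-append the remaining
    [IMG] tokens instead of sampling them, so the model always emits the
    full set of num_img_tokens tokens needed for generation/retrieval.
    Assumes [IMG] token ids are consecutive, starting at img0_id."""
    if generated_ids and generated_ids[-1] == img0_id:
        remaining = [img0_id + r for r in range(1, num_img_tokens)]
        return generated_ids + remaining
    return generated_ids
```

This keeps the sampled prefix untouched and only deterministically completes the [IMG] block once [IMG0] appears.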
from gill.
As I was trying to train my own model and run inference using the saved checkpoint, I noticed a few things; please verify (might be helpful for other users).
- Because GILLArgs() only has local variables and no attributes, none of them get saved in the model_args.json file unless specifically set after instantiating. One such attribute is text_emb_layers. Turning all local variables into class attributes can solve this.
- Currently the main.py script also saves the scheduler state, which pretty much saves the entire model (probable reason) and therefore results in a large checkpoint.
- The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires changing either this line or this line to run inference with the produced checkpoint. For example,
img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()[-model_kwargs['num_tokens']:, :]
You're absolutely right, thanks for pointing this out! We'll fix this in the paper soon. The correct scheme should be that the loss is only considered for the first [IMG0]
token. The part about forcing generation of the remaining tokens during inference is still true.
Thanks for your quick response. Reopening this issue for another query regarding the pipeline (didn't want to unnecessarily create a new issue).
If I am not wrong, this line makes the entire OPT embedding layer trainable. It is also evident from the param_count.txt file your scripts generate. However, according to the paper, only the [IMG] embedding matrix E_img was supposed to be trainable. Did I miss anything here?
First few lines from param_count.txt :
| Module | Trainable | Shape | Param Count |
| model.logit_scale | True | () | 1 |
| model.lm.model.decoder.embed_tokens.weight | True | (50274, 4096) | 205,922,304 |
You're right, they become trainable, which is why we zero out the gradients of the non-[IMG] embeddings here:
Lines 578 to 587 in 53fdcf2
This is not super ideal, but I think it is overall cleaner than concatenating a trainable embedding matrix with a frozen one.
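The gradient-zeroing trick amounts to something like the following sketch (illustrative, not the exact repo lines; it assumes the [IMG] tokens occupy the last rows of the embedding matrix):

```python
import torch


def zero_non_img_grads(embedding_weight, num_img_tokens=8):
    """Zero the gradient rows of all non-[IMG] embeddings after backward(),
    so optimizer.step() only updates the last num_img_tokens rows
    (the trainable [IMG] token embeddings)."""
    if embedding_weight.grad is not None:
        embedding_weight.grad[:-num_img_tokens].zero_()
```

This would be called between loss.backward() and optimizer.step() on each training step, which is simpler than concatenating a trainable embedding matrix with a frozen one.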
Thanks again. Unfortunately, I missed this section of the script.
Is it also correct to say that for the l_p loss, you are considering the loss for generating each token of the input text (caption), i.e. the negative log likelihood of generating token s_t conditioned on s_1, ..., s_{t-1}, where t = 1, ..., T?
Yes that's right!
In that case, I believe equation 2 is slightly misleading, as the summation there goes over i from 1 to r. This practically says that we are considering the loss for generating all 8 [IMG] tokens.
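Under the corrected scheme described above (loss on the caption tokens plus only the first [IMG0] token), the loss would presumably read as follows; the notation here is assumed from this thread, not copied from the paper:

```latex
% NLL over caption tokens s_1..s_T, plus only the single [IMG0] term
l_p = -\sum_{t=1}^{T} \log p(s_t \mid s_1, \dots, s_{t-1})
      \; - \; \log p(\text{[IMG0]} \mid s_1, \dots, s_T)
```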
Thanks for sharing this!
- The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires changing either this line or this line to run inference with the produced checkpoint. For example,
You're right, and I also realized that I hadn't uploaded the script used to prune the checkpoints (keeping just the trained weights, and discarding the pretrained model weights). I just did that here: https://github.com/kohjingyu/gill/blob/main/scripts/prune_model_ckpt.py
I think this is essentially the same as the changes you probably made locally, though I haven't tested this script in a while.
Thanks for sharing the script! Just noticed a few things:
- These arguments no longer exist in the current version. Could it be that you coalesced them into num_tokens? Please verify.
- This line gives an error, as it tries to mutate an ordered dict during the for loop. This can be avoided by making an empty dict first and then copying everything there (just like it's done here in models.py).
- I believe the example usage should be python scripts/prune_model_ckpt.py runs/gill_exp, given the location of the script.
- What's the use of share_ret_gen? I could not find any use of it in the models.py, validate.py or main.py scripts.
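The fix suggested in the second point (build a new dict instead of mutating the OrderedDict while iterating over it) could look like this sketch; the function name and the key substrings to keep are hypothetical:

```python
def prune_state_dict(state_dict,
                     keep_substrings=('input_embeddings', 'logit_scale')):
    """Copy only the trained weights into a fresh dict, rather than
    deleting keys from the original dict while iterating over it
    (which raises a RuntimeError in Python)."""
    pruned = {}
    for key, value in state_dict.items():
        if any(s in key for s in keep_substrings):
            pruned[key] = value
    return pruned
```

Since iteration happens over the original dict and writes go to a new one, there is no mutation-during-iteration error, and the pruned checkpoint keeps just the trainable weights.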
Thanks for the notes! Sorry about this, it's what happens when you don't test before you upload...
These arguments no longer exist in the current version. Could it be that you probably coalesced them into num_tokens? Please verify.
That's right.
What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.
share_ret_gen doesn't exist anymore; I think it was something used during debugging previously. I've updated the script accordingly, so hopefully it works as expected now. Thanks for your help in debugging this!
Another small update: I could not find warmup-scheduler==0.3.2 (as mentioned in the requirements.txt file); the currently available version is probably 0.3. Will it be compatible with your scripts? (I can verify that training continues with this version.)
Ah, looks like it should be pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
instead. The link you provided should still work though.
I have another question. As mentioned above, the l_p loss includes the negative log likelihood (NLL) of generating each token of the input text (caption). Did you find this helpful for the overall model performance? I am asking because, from the name and purpose of this loss, I would assume it was intended to consider only the NLL of generating the [IMG] tokens.
I have not run this particular ablation, sorry. I would guess that it does not have a significant effect on performance on the tasks we evaluated on.