Hi, @JoanFM
vilt_module.py doesn't say much about its internal workings.
Yes, if you want to let ViLT's infer method take care of the embedding part, the default values are there for that.
The details on text-related data can be found here and here.
- text_ids are the token ids converted from the raw text by self.tokenizer in get_text().
- text_labels are the labels for the MLM task and are generated by the Hugging Face mlm_collator; they aren't used at all if the model doesn't perform the MLM task (the default value is [-100] * seq_len).
- text_masks mask out the padding that is added to batch sentences of various lengths (the default value is [1] * seq_len).
- batch["image"] is a doubly nested list that also accounts for multiple views. The details are here. At the release of ViLT, we adopt a single-view generation (augmentation) policy, so index batch["image"] with 0 to get the first view of the images.
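As a rough illustration of the three text fields described above, here is a minimal pure-Python sketch of how they could be assembled for sentences of different lengths. The helper name build_text_fields is hypothetical (not part of ViLT), the token ids are made up, and padding with id 0 is an assumption; the real tokenizer and collate function decide the actual padding id.

```python
from typing import Dict, List

# Label meaning "not an MLM target", matching the [-100] * seq_len default above.
IGNORE_INDEX = -100


def build_text_fields(
    token_ids: List[List[int]], pad_id: int = 0
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper (not part of ViLT): pad tokenized sentences to a common
    length and build text_ids, text_masks and text_labels, one row per sentence.

    pad_id=0 is an assumption; the real tokenizer/collate code defines the padding id.
    """
    max_len = max(len(ids) for ids in token_ids)
    text_ids, text_masks, text_labels = [], [], []
    for ids in token_ids:
        n_pad = max_len - len(ids)
        text_ids.append(list(ids) + [pad_id] * n_pad)
        # 1 keeps a real token, 0 masks out padding added only to batch sentences
        text_masks.append([1] * len(ids) + [0] * n_pad)
        # no MLM task: every position gets the ignore label
        text_labels.append([IGNORE_INDEX] * max_len)
    return {"text_ids": text_ids, "text_masks": text_masks, "text_labels": text_labels}


# made-up token ids for two sentences of different lengths
fields = build_text_fields([[5, 16, 32], [5, 16, 32, 55, 7]])
print(fields["text_ids"])    # [[5, 16, 32, 0, 0], [5, 16, 32, 55, 7]]
print(fields["text_masks"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Wrapping each nested list in torch.tensor(...) would give (batch, seq_len) tensors like the ones used in the scripts in this thread.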
You can do batch inference by properly generating the batch dictionary. Follow the collate function to learn how to generate the proper batch.
None of them are for the similarity metric. It can only be acquired by passing cls_feats to the itm_score head.
Hello @dandelin.
I have a question.
I have written this code, and I do not understand two things:
- Why do I get a different score every time? Is this the right way to load the VisionTransformer from a checkpoint?
- Also, why do I get two scores in the output? Shouldn't there be only one, as a similarity metric? How do I get the similarity metric from this?
import torch
from vilt import config
from vilt.modules import ViLTransformerSS
conf = config.config()
conf.load_path = 'vilt_irtr_f30k.ckpt'
conf.test_only = True
image_vilt = torch.ones(3, 224, 224)
batch = {}
batch['image'] = [image_vilt.unsqueeze(0)]
batch['text_ids'] = torch.IntTensor([[1, 0, 16, 32, 55]])  # random sentence tokens
batch['text_masks'] = torch.IntTensor([[1, 1, 1, 1, 1]])  # no masking
batch['text_labels'] = None
with torch.no_grad():
    vilt = ViLTransformerSS(conf)
    vilt.train(mode=False)
    out = vilt(batch)
    score = vilt.itm_score(out['cls_feats'])
    print(f' score {score}')
Another question I have about the batch processing you suggest using collate:
As I see in the code, inference only considers the first image (view) in batch[image_key], so I do not understand how batching can work in this case. Do you mean batching the same image against different texts?
I would be more interested in matching the same text against different images.
Hi @JoanFM
I wrote some pedagogical code with comments.
I hope it answers your questions.
import torch
import copy
from vilt import config
from vilt.modules import ViLTransformerSS
# Sacred config is an immutable object, so you need to deepcopy it.
conf = copy.deepcopy(config.config())
conf["load_path"] = "vilt_irtr_coco.ckpt"
conf["test_only"] = True
# You need to properly configure loss_names to initialize heads (0.5 means it initializes the head but ignores the loss during training)
loss_names = {
    "itm": 0.5,
    "mlm": 0,
    "mpp": 0,
    "vqa": 0,
    "imgcls": 0,
    "nlvr2": 0,
    "irtr": 1,
    "arc": 0,
}
conf["loss_names"] = loss_names
# two different random images
image_vilt = torch.randn(2, 3, 224, 224)
batch = {}
batch["image"] = [image_vilt]
# repeated random sentence tokens
batch["text_ids"] = torch.IntTensor([[1, 0, 16, 32, 55], [1, 0, 16, 32, 55]])
# no masking
batch["text_masks"] = torch.IntTensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]])
batch["text_labels"] = None
with torch.no_grad():
    vilt = ViLTransformerSS(conf)
    vilt.train(mode=False)
    out = vilt(batch)
    itm_logit = vilt.itm_score(out["cls_feats"]).squeeze()
    print(
        f"itm logit, two logits (logit_neg, logit_pos) for an image-text pair.\n{itm_logit}"
    )
    itm_score = itm_logit.softmax(dim=-1)[:, 1]
    print(f"itm score, one score for an image-text pair.\n{itm_score}")
    # You should see "rank_output" head if loss_names["irtr"] > 0
    score = vilt.rank_output(out["cls_feats"]).squeeze()
    print(f"unnormalized irtr score, one score for an image-text pair.\n{score}")
    normalized_score = score.softmax(dim=0)
    print(
        f"normalized (relative) irtr score, one score for an image-text pair.\n{normalized_score}"
    )
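The script above applies two different normalizations: itm_logit.softmax(dim=-1)[:, 1] normalizes within each image-text pair over its two classes, while score.softmax(dim=0) normalizes the IRTR scores across all pairs in the batch. A minimal pure-Python sketch mirroring both, using made-up logit values rather than real model outputs:

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    """Plain softmax over a list; subtracting the max keeps exp() numerically stable."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# made-up (logit_neg, logit_pos) rows for two image-text pairs
itm_logit = [[0.0, 0.0], [1.0, 3.0]]
# per pair: probability of the "matched" class, like softmax(dim=-1)[:, 1]
itm_score = [softmax(row)[1] for row in itm_logit]

# made-up unnormalized IRTR scores, one per pair
score = [2.0, 2.0]
# across pairs: relative scores summing to 1 over the batch, like softmax(dim=0)
normalized_score = softmax(score)

print(itm_score)
print(normalized_score)  # [0.5, 0.5]
```

Because the IRTR softmax runs over the batch dimension, its values are only meaningful relative to the other candidates scored in the same batch, whereas the ITM probability is computed for each pair on its own.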
Hey @dandelin, thank you for the quick and clear response.
It answers my questions but still leaves some doubts.
- What is the logit about? What are logit_pos and logit_neg?
- In the example, where the text is the same, shouldn't the logits be symmetric?
- What is the difference between itm_score and score? If I do text-to-image retrieval and pass one text against 20 images to the model, which score should I look at?
- Also, how can I batch this call of one query text against 20 images?
Thank you very much
Another thing: when trying to run your script, I get:
I guess the problem is that I am using a PyTorch version newer than the one used for the model? Can you give more details about the torch version used during training?
Thank you very much,
Best regards
Joan
Forget about this last message. This was my bad! Sorry for the inconvenience!
Hi @JoanFM
- The ITM objective is a 2-way classification: the pair is either matched (pos) or mismatched (neg). The logits from vilt.itm_score are the logits for this 2-way classification.
- I used the same text [1, 0, 16, 32, 55]; however, since the two images are different random tensors (image_vilt = torch.randn(2, 3, 224, 224)), we get different logits for each pair.
- itm_logit comes from vilt.itm_score as a result of the ITM objective (dimension = (#pair, 2)), and itm_score is its softmax probability for the positive class; score comes from vilt.rank_output as a result of the IRTR objective (dimension = (#pair, 1)). So if you do T2I with one text against 20 images, score will have dimension (20, 1).
- Put the images in the shape (#images, #channel, H, W), as in image_vilt = torch.randn(2, 3, 224, 224) (2 images).
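Following the answer above, text-to-image retrieval for one query against N images gives a rank_output score of shape (N, 1). A minimal sketch of turning that output into a ranking of image indices, with made-up scores as nested Python lists instead of tensors; the helper name rank_images is hypothetical.

```python
from typing import List


def rank_images(score: List[List[float]]) -> List[int]:
    """Hypothetical helper: rank candidate images for a single text query.

    `score` has shape (N, 1), one rank_output value per image-text pair.
    Returns image indices ordered from best to worst match.
    """
    flat = [row[0] for row in score]  # (N, 1) -> (N,), like .squeeze()
    # descending score; Python's sort is stable, so ties keep their original order
    return sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)


score = [[0.3], [2.1], [-1.0], [1.5]]  # made-up scores for 4 images
print(rank_images(score))  # [1, 3, 0, 2]
```

Since softmax is monotonic, ranking by the raw score or by score.softmax(dim=0) yields the same order; the softmax only rescales the values.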
Hello @dandelin,
Let me clarify.
- On top of the cross-attention pooled features you add an ITM head that outputs 2 logits, one for each class (they match, and they do not match)? Then with the softmax you get a probability distribution over (pos, neg). Is this correct? Intuitively, wouldn't a single score be enough to achieve the same?
- The rank_output is just a layer where you do this:
self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :]
self.rank_output.bias.data = self.itm_score.fc.bias.data[1:]
So I guess you are using only one of the logits to compute the ranking?
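The two weight-copy lines quoted above can be checked with a toy example. A minimal pure-Python sketch, with made-up weights and a made-up 3-dimensional vector standing in for cls_feats, showing that a one-output layer built from row 1 of a 2-way linear layer returns exactly that layer's second logit (logit_pos, under the (logit_neg, logit_pos) ordering used in this thread). This illustrates the initialization only; it is not ViLT code.

```python
from typing import List


def linear(weight: List[List[float]], bias: List[float], x: List[float]) -> List[float]:
    """y = W x + b for a single feature vector x (one output per weight row)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# made-up 2-way ITM head: row 0 -> logit_neg, row 1 -> logit_pos
itm_weight = [[0.5, -1.0, 2.0], [1.0, 0.25, -0.5]]
itm_bias = [0.1, -0.2]

# rank head initialized like weight = fc.weight[1:, :] and bias = fc.bias[1:]
rank_weight = itm_weight[1:]
rank_bias = itm_bias[1:]

cls_feats = [2.0, 4.0, 1.0]  # made-up pooled feature
logit_neg, logit_pos = linear(itm_weight, itm_bias, cls_feats)
(rank_score,) = linear(rank_weight, rank_bias, cls_feats)
print(rank_score == logit_pos)  # True
```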
As for retrieving with one text against several images, I am trying to run this:
import copy
import torch
from vilt import config
from vilt.modules import ViLTransformerSS
conf = copy.deepcopy(config.config())
conf["load_path"] = 'vilt_irtr_f30k.ckpt'
conf["test_only"] = True
# You need to properly configure loss_names to initialize heads (0.5 means it initializes the head but ignores the loss during training)
loss_names = {
    "itm": 0.5,
    "mlm": 0,
    "mpp": 0,
    "vqa": 0,
    "imgcls": 0,
    "nlvr2": 0,
    "irtr": 1,
    "arc": 0,
}
conf["loss_names"] = loss_names
# two dummy images (all ones)
image_vilt = torch.ones(2, 3, 224, 224)  # 2 images in database
batch = {}
batch["image"] = [image_vilt]
# a single random sentence (not repeated)
batch["text_ids"] = torch.IntTensor([[1, 0, 16, 32, 55]])  # 1 single text query
# no masking
batch["text_masks"] = torch.IntTensor([[1, 1, 1, 1, 1]])  # 1 single text query
batch["text_labels"] = None
with torch.no_grad():
    vilt = ViLTransformerSS(conf)
    vilt.train(mode=False)
    out = vilt(batch)
    # You should see "rank_output" head if loss_names["irtr"] > 0
    score = vilt.rank_output(out["cls_feats"]).squeeze()
    print(f"unnormalized irtr score, one score for an image-text pair.\n{score}")
    normalized_score = score.softmax(dim=0)
    print(
        f"normalized (relative) irtr score, one score for an image-text pair.\n{normalized_score}"
    )
This is failing with:
Traceback (most recent call last):
  File "src/model/vilt_demo.py", line 52, in <module>
    out = vilt(batch)
  File "/home/joan/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/joan/.local/lib/python3.7/site-packages/vilt/modules/vilt_module.py", line 191, in forward
    ret.update(self.infer(batch))
  File "/home/joan/.local/lib/python3.7/site-packages/vilt/modules/vilt_module.py", line 158, in infer
    co_embeds = torch.cat([text_embeds, image_embeds], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0 (The offending index is 1)
Again, I highly appreciate your incredible support!
Regards,
Joan
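The traceback points at torch.cat([text_embeds, image_embeds], dim=1) in infer, which requires the text and image embeddings to have the same size in dimension 0 (the batch). In the script above, the text fields have one row while image_vilt holds two images. One way to match them, sketched here with plain lists and a hypothetical helper repeat_text_for_images, is to repeat the single query once per image, the same layout as the repeated text rows in the earlier pedagogical script.

```python
from typing import Dict, List


def repeat_text_for_images(
    text_fields: Dict[str, List[List[int]]], n_images: int
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: copy a single-query text batch (one row per field)
    n_images times so its batch size matches the image batch."""
    repeated = {}
    for key, rows in text_fields.items():
        if len(rows) != 1:
            raise ValueError(f"{key}: expected exactly one query row, got {len(rows)}")
        # list(...) gives each image its own independent copy of the row
        repeated[key] = [list(rows[0]) for _ in range(n_images)]
    return repeated


query = {"text_ids": [[1, 0, 16, 32, 55]], "text_masks": [[1, 1, 1, 1, 1]]}
batch_text = repeat_text_for_images(query, n_images=2)
print(batch_text["text_ids"])  # [[1, 0, 16, 32, 55], [1, 0, 16, 32, 55]]
```

With tensors, batch["text_ids"].repeat(2, 1) and batch["text_masks"].repeat(2, 1) should give the same (2, seq_len) layout; this is a suggestion that has not been tested against the ViLT code.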
What is the total number of layers in the ViLT model's Transformer encoder?