Comments (6)
Btw what is the exact meaning of bs in your code? @muqeeth
if not self.config.split_option_at_inference:
    bs, num_choices = choices_ids.size()[:2]
    flat_choices_ids = choices_ids.flatten(0, 1)
    attention_mask = (input_ids != self.tokenizer.pad_token_id).float()  # [bs, max_seq_len]
    encoder_hidden_states = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
    encoder_hidden_states = encoder_hidden_states.unsqueeze(dim=1).repeat(1, num_choices, 1, 1).flatten(0, 1)
    attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, num_choices, 1).flatten(0, 1)
    decoder_input_ids = torch.cat([torch.zeros_like(flat_choices_ids[:, :1]), flat_choices_ids[:, :-1]], dim=1)
    decoder_attention_mask = (decoder_input_ids == decoder_input_ids).float()
    lm_target = flat_choices_ids - 100 * (flat_choices_ids == self.tokenizer.pad_token_id).long()
    model_output = self.model(
        attention_mask=attention_mask,
        encoder_outputs=[encoder_hidden_states],
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
    choices_scores = (
        F.cross_entropy(model_output.logits.flatten(0, 1), lm_target.flatten(0, 1), reduction="none")
        .view(bs, num_choices, -1)
        .sum(dim=-1)
    )
    if self.config.length_norm > 0:
        choices_scores = choices_scores / torch.pow(
            (choices_ids != self.tokenizer.pad_token_id).sum(dim=-1), self.config.length_norm
        )
    pred_score, prediction = choices_scores.min(dim=1)
    score_gt = choices_scores[range(bs), labels]
    choices_scores[range(bs), labels] = choices_scores.max(dim=-1)[0]
    score_cand = choices_scores.min(dim=-1)[0]
from t-few.
I think `bs` in this code is the batch size: `choices_ids.size()[:2]` unpacks into the batch size and the number of answer choices.
But what do `score_cand` and `score_gt` mean? @muqeeth
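`score_gt` and `score_cand` are the scores of the correct answer and of the best-scoring wrong answer (the one with the lowest loss), respectively. Each score is the summed token-level cross-entropy of a choice, divided by its length raised to the power `length_norm` when `length_norm > 0`, so lower means more likely.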
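To make the two quantities concrete, here is a minimal pure-Python sketch of how the quoted code appears to derive them: it reads off the score at the label index, overwrites that entry with the row maximum, and takes the minimum of what remains. The helper name `gt_and_cand_scores` and the toy numbers are made up for illustration and are not part of t-few.

```python
def gt_and_cand_scores(choices_scores, labels):
    """Return (score_gt, score_cand) per example; a lower score means a more likely choice."""
    score_gt, score_cand = [], []
    for scores, label in zip(choices_scores, labels):
        # Score of the correct (ground-truth) choice.
        score_gt.append(scores[label])
        # Overwrite the correct choice with the row maximum, then take the minimum,
        # mirroring the in-place assignment in the quoted snippet.
        masked = list(scores)
        masked[label] = max(scores)
        score_cand.append(min(masked))
    return score_gt, score_cand


# Toy batch with bs = 2 and num_choices = 3 (hypothetical values).
choices_scores = [[0.7, 1.9, 1.2], [2.0, 0.5, 1.1]]
labels = [0, 2]
score_gt, score_cand = gt_and_cand_scores(choices_scores, labels)
print(score_gt)    # [0.7, 1.1]
print(score_cand)  # [1.2, 0.5]
```

In the second example the wrong choice scores lower than the correct one (0.5 vs. 1.1), which is the case where the prediction would be wrong.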
@HaokunLiu imagine I'd like to persist the scores as probabilities. Is it safe to assume that `torch.exp(score_gt) + torch.exp(score_cand) < 1`?
Ha, you found this issue. In fact, if we are going to compute a probability distribution over all the choices (correct and incorrect), the scores should be treated as negative logits (-logits) rather than as probabilities. They correspond to $-\beta(x, y)$ from eq. 2 in the paper.
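A minimal sketch of what that implies, assuming the goal is one probability per choice: negate the scores and apply a softmax across the choices of each example. The helper name `choice_probabilities` and the numbers are hypothetical. With tensors, the equivalent would presumably be `torch.softmax(-choices_scores, dim=-1)`, computed before the ground-truth entry of `choices_scores` is overwritten in place.

```python
import math


def choice_probabilities(scores):
    """Softmax over negated scores: p_i = exp(-s_i) / sum_j exp(-s_j)."""
    neg = [-s for s in scores]
    m = max(neg)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in neg]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.7, 1.9, 1.2]  # made-up per-choice scores for one example
probs = choice_probabilities(scores)
print(probs)       # the lowest score gets the highest probability
print(sum(probs))  # sums to 1 up to floating-point rounding
```

Exponentiating two individual scores and adding them gives no such guarantee, because nothing normalizes them across the set of choices.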