

CaffreyR commented on September 2, 2024

By the way, what exactly does bs mean in your code? @muqeeth

    # Excerpt from a model method; assumes `import torch` and
    # `import torch.nn.functional as F`.
    if not self.config.split_option_at_inference:
        # choices_ids: [bs, num_choices, choice_len]; bs is the batch size
        bs, num_choices = choices_ids.size()[:2]
        flat_choices_ids = choices_ids.flatten(0, 1)  # [bs * num_choices, choice_len]
        attention_mask = (input_ids != self.tokenizer.pad_token_id).float()  # [bs, max_seq_len]
        encoder_hidden_states = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Reuse each example's encoder output once per answer choice.
        encoder_hidden_states = encoder_hidden_states.unsqueeze(dim=1).repeat(1, num_choices, 1, 1).flatten(0, 1)
        attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, num_choices, 1).flatten(0, 1)
        # Teacher forcing: shift targets right, starting from token id 0
        # (the decoder start token for T5-style models).
        decoder_input_ids = torch.cat([torch.zeros_like(flat_choices_ids[:, :1]), flat_choices_ids[:, :-1]], dim=1)
        decoder_attention_mask = (decoder_input_ids == decoder_input_ids).float()  # all ones
        # Pad positions become -100 (F.cross_entropy's default ignore_index) when pad_token_id == 0.
        lm_target = flat_choices_ids - 100 * (flat_choices_ids == self.tokenizer.pad_token_id).long()

        model_output = self.model(
            attention_mask=attention_mask,
            encoder_outputs=[encoder_hidden_states],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        # Summed token-level NLL per choice: [bs, num_choices]; lower = more likely.
        choices_scores = (
            F.cross_entropy(model_output.logits.flatten(0, 1), lm_target.flatten(0, 1), reduction="none")
            .view(bs, num_choices, -1)
            .sum(dim=-1)
        )
        if self.config.length_norm > 0:
            choices_scores = choices_scores / torch.pow(
                (choices_ids != self.tokenizer.pad_token_id).sum(dim=-1), self.config.length_norm
            )
        pred_score, prediction = choices_scores.min(dim=1)

    score_gt = choices_scores[range(bs), labels]
    choices_scores[range(bs), labels] = choices_scores.max(dim=-1)[0]
    score_cand = choices_scores.min(dim=-1)[0]

from t-few.
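To make the scoring step concrete, here is a minimal, self-contained sketch of how per-choice scores fall out of token-level cross-entropy. It uses random logits in place of the model output and assumes a pad id of 0 (as in T5 tokenizers); the function name score_choices is hypothetical, not part of t-few.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # assumption: pad token id 0, as in T5 tokenizers


def score_choices(logits, choices_ids, length_norm=1.0):
    """Hypothetical helper: per-choice scores (lower = more likely) and predictions.

    logits:      [bs, num_choices, seq_len, vocab] (stand-in for model output)
    choices_ids: [bs, num_choices, seq_len] token ids, padded with PAD_ID
    """
    bs, num_choices, seq_len, vocab = logits.shape
    # Padding positions become -100, which F.cross_entropy ignores by default.
    lm_target = choices_ids.masked_fill(choices_ids == PAD_ID, -100)
    token_nll = F.cross_entropy(
        logits.reshape(-1, vocab), lm_target.reshape(-1), reduction="none"
    )  # ignored positions contribute 0
    scores = token_nll.view(bs, num_choices, seq_len).sum(dim=-1)
    if length_norm > 0:
        # Divide by (number of real tokens) ** length_norm.
        lengths = (choices_ids != PAD_ID).sum(dim=-1)
        scores = scores / lengths.float().pow(length_norm)
    _, prediction = scores.min(dim=1)  # lowest score wins
    return scores, prediction
```

Using masked_fill instead of subtracting 100 gives -100 at pad positions for any pad id, whereas the subtraction trick in the snippet only yields -100 when the pad id is 0.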

muqeeth commented on September 2, 2024

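bs here is the batch size: choices_ids has shape [bs, num_choices, choice_len], so size()[:2] unpacks the batch size and the number of answer choices.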

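A quick shape check shows why bs is the batch size: choices_ids is laid out as [batch, choices, tokens], and each example's encoder output is copied once per choice before the batch and choice dimensions are merged. The sizes below are arbitrary toy values, not anything taken from t-few.

```python
import torch

bs, num_choices, choice_len = 2, 3, 5  # toy sizes (assumed)
src_len, d_model = 7, 4

choices_ids = torch.randint(1, 100, (bs, num_choices, choice_len))
encoder_hidden_states = torch.randn(bs, src_len, d_model)

# Same unpacking as in the snippet: the first two dims are batch and choices.
bs_, num_choices_ = choices_ids.size()[:2]

# Copy each example's encoder states once per choice, then merge dims 0 and 1
# so the decoder sees bs * num_choices independent sequences.
expanded = (
    encoder_hidden_states.unsqueeze(dim=1)
    .repeat(1, num_choices, 1, 1)
    .flatten(0, 1)
)
flat_choices_ids = choices_ids.flatten(0, 1)

print(expanded.shape)          # torch.Size([6, 7, 4])
print(flat_choices_ids.shape)  # torch.Size([6, 5])
```

Row k of the flattened tensors belongs to example k // num_choices and choice k % num_choices, which is why the scores can later be reshaped back with .view(bs, num_choices, -1).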

CaffreyR commented on September 2, 2024

And what do score_cand and score_gt mean? @muqeeth


HaokunLiu commented on September 2, 2024

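score_gt is the score of the correct answer, and score_cand is the best (lowest) score among the wrong answers. Each score is the summed token-level cross-entropy of that answer, divided by its length raised to length_norm when length_norm > 0 (a per-token average when length_norm is 1), so lower means more likely.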

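The last three lines of the snippet use a small trick: after reading off the correct answer's score, they overwrite it with the row maximum, so the row-wise minimum can only come from a wrong answer (barring exact ties). A toy run with hand-picked scores, where lower is better; torch.arange replaces range for the row index but behaves the same here.

```python
import torch

# Per-choice scores (negative log-likelihoods; lower = more likely).
choices_scores = torch.tensor([
    [2.0, 0.5, 1.0],  # example 0: gold is choice 2
    [0.3, 0.9, 0.4],  # example 1: gold is choice 0
])
labels = torch.tensor([2, 0])
rows = torch.arange(choices_scores.size(0))

score_gt = choices_scores[rows, labels]  # score of the gold choice
# Like the original snippet, this mutates choices_scores in place.
choices_scores[rows, labels] = choices_scores.max(dim=-1)[0]
score_cand = choices_scores.min(dim=-1)[0]  # best-scoring wrong choice

print(score_gt)    # tensor([1.0000, 0.3000])
print(score_cand)  # tensor([0.5000, 0.4000])
```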

PastelBelem8 commented on September 2, 2024

@HaokunLiu Imagine I'd like to persist the scores as probabilities. Is it safe to assume that torch.exp(score_gt) + torch.exp(score_cand) < 1?


HaokunLiu commented on September 2, 2024


Ha, you've spotted a real issue. If we want a probability distribution over all the choices (correct and incorrect), the scores should be treated as negative logits rather than as probabilities. They correspond to $-\beta(x, y)$ from Eq. 2 in the paper.

from t-few.
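Following that reading, a distribution over the answer choices can be sketched by negating the scores and applying a softmax. The scores below are made up; note that this normalizes only over the listed choices, so each value is a probability relative to those options, not the model's probability of generating that string.

```python
import torch

# Scores as produced by the snippet: lower = more likely.
choices_scores = torch.tensor([
    [2.0, 0.5, 1.0],
    [0.3, 0.9, 0.4],
])

# Negated scores act as logits (-score plays the role of beta(x, y)).
probs = torch.softmax(-choices_scores, dim=-1)

print(probs.sum(dim=-1))     # each row sums to 1 (up to float rounding)
print(probs.argmax(dim=-1))  # tensor([1, 0]), same as argmin of the scores
```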
