
Comments (1)

Nana2929 commented on July 19, 2024

Code

  • 3 masks (3 attentions)
    def build_attention(self, sequence_outputs, speaker_masks=None, reply_masks=None, thread_masks=None):
        """
        sequence_outputs: batch_size, seq_len, hidden_size
        speaker_masks, reply_masks, thread_masks: batch_size, num, num

        All three masks are required in practice: the None defaults would
        fail on the .bool() calls below.
        """
        # Cast to bool and add a singleton dim at position 1
        # (presumably so each mask broadcasts over attention heads).
        speaker_masks = speaker_masks.bool().unsqueeze(1)
        reply_masks = reply_masks.bool().unsqueeze(1)
        thread_masks = thread_masks.bool().unsqueeze(1)

        # One self-attention pass per mask; [0] keeps the first returned value.
        rep = self.reply_attention(sequence_outputs, sequence_outputs, sequence_outputs, reply_masks)[0]
        thr = self.thread_attention(sequence_outputs, sequence_outputs, sequence_outputs, thread_masks)[0]
        sp = self.speaker_attention(sequence_outputs, sequence_outputs, sequence_outputs, speaker_masks)[0]

        # Fuse the three outputs with an element-wise max.
        r = torch.stack((rep, thr, sp), 0)
        r = torch.max(r, 0)[0]
        return r
  • custom forward()
    def forward(self, **kwargs):
        input_ids, input_masks, input_segments = [kwargs[w] for w in ['input_ids', 'input_masks', 'input_segments']]
        ent_matrix, rel_matrix, pol_matrix = [kwargs[w] for w in ['ent_matrix', 'rel_matrix', 'pol_matrix']]
        reply_masks, speaker_masks, thread_masks = [kwargs[w] for w in ['reply_masks', 'speaker_masks', 'thread_masks']]
        sentence_masks, full_masks, dialogue_length = [kwargs[w] for w in ['sentence_masks', 'full_masks', 'dialogue_length']]

        sequence_outputs = self.bert(input_ids, token_type_ids=input_segments, attention_mask=input_masks)[0]

        sequence_outputs = self.merge_sentence(sequence_outputs, input_masks, dialogue_length)
        sequence_outputs = self.dropout(sequence_outputs)

        sequence_outputs = self.build_attention(sequence_outputs, reply_masks=reply_masks, speaker_masks=speaker_masks, thread_masks=thread_masks)

        loss0, tags0 = self.classify_matrix(kwargs, sequence_outputs, ent_matrix, sentence_masks, 'ent')
        loss1, tags1 = self.classify_matrix(kwargs, sequence_outputs, rel_matrix, full_masks, 'rel')
        loss2, tags2 = self.classify_matrix(kwargs, sequence_outputs, pol_matrix, full_masks, 'pol')
      
        return (loss0, loss1, loss2), (tags0, tags1, tags2)
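
For readers who want to try the "3 masks (3 attentions)" fusion outside the model, here is a minimal standalone sketch. It is not the repository's implementation: `masked_self_attention` and `fuse_by_max` are hypothetical helpers, the attention is a plain single-head scaled dot product without learned projections, and the mask convention (True means "may attend") is an assumption. Only the final stack-then-max step mirrors `build_attention` above.

```python
import torch


def masked_self_attention(x, mask):
    """Single-head scaled dot-product self-attention (illustrative only).

    x:    (batch, num, hidden)
    mask: (batch, num, num) bool, True = position may be attended to
          (assumed convention; every row needs at least one True).
    """
    scores = x @ x.transpose(-1, -2) / x.size(-1) ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x


def fuse_by_max(x, masks):
    """Run one masked attention per mask, then take the element-wise max."""
    views = [masked_self_attention(x, m) for m in masks]
    stacked = torch.stack(views, dim=0)  # (n_masks, batch, num, hidden)
    return stacked.max(dim=0).values     # (batch, num, hidden)


torch.manual_seed(0)
batch, num, hidden = 2, 4, 8
x = torch.randn(batch, num, hidden)

# Three toy masks standing in for the reply / thread / speaker masks.
eye = torch.eye(num).bool().expand(batch, num, num)
lower = torch.ones(num, num).tril().bool().expand(batch, num, num)
full = torch.ones(batch, num, num, dtype=torch.bool)

fused = fuse_by_max(x, [eye, lower, full])
print(fused.shape)
```

Because the max is taken per element, each feature of the fused output can come from a different attention view, and the output shape matches the input.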

from dialogue-absa.
