Comments (1)
Code
- 3 masks (3 attentions)
def build_attention(self, sequence_outputs, speaker_masks=None, reply_masks=None, thread_masks=None):
    """
    sequence_outputs: batch_size, seq_len, hidden_size
    speaker_masks, reply_masks, thread_masks: batch_size, num, num
    """
    # All three masks must be passed: the None defaults would fail at .bool().
    # unsqueeze(1) adds a singleton dimension (presumably the attention-head axis) for broadcasting.
    speaker_masks = speaker_masks.bool().unsqueeze(1)
    reply_masks = reply_masks.bool().unsqueeze(1)
    thread_masks = thread_masks.bool().unsqueeze(1)
    # One self-attention pass per mask (query = key = value = sequence_outputs).
    rep = self.reply_attention(sequence_outputs, sequence_outputs, sequence_outputs, reply_masks)[0]
    thr = self.thread_attention(sequence_outputs, sequence_outputs, sequence_outputs, thread_masks)[0]
    sp = self.speaker_attention(sequence_outputs, sequence_outputs, sequence_outputs, speaker_masks)[0]
    # Fuse the three views with an element-wise max.
    r = torch.stack((rep, thr, sp), 0)
    r = torch.max(r, 0)[0]
    return r
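To make the fusion step in build_attention concrete, here is a minimal, self-contained sketch. It is not the repository's implementation: `masked_self_attention` and `fuse_views` are hypothetical names, the attention is single-head scaled dot-product with no learned projections, and it assumes that `True` in a mask means "may attend" (the convention used by self.reply_attention and the other attention modules is not shown above).

```python
import torch


def masked_self_attention(x, mask):
    """Single-head scaled dot-product self-attention (illustrative only).

    x:    (batch, num, hidden) float tensor
    mask: (batch, num, num) bool tensor, True = position may be attended to
          (assumed convention; every row needs at least one True entry)
    """
    scores = x @ x.transpose(-1, -2) / x.size(-1) ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ x


def fuse_views(x, masks):
    """Run one attention pass per mask, then take the element-wise max."""
    views = [masked_self_attention(x, m) for m in masks]
    return torch.stack(views, 0).max(dim=0).values


torch.manual_seed(0)
x = torch.randn(2, 4, 8)                        # batch=2, num=4, hidden=8
full = torch.ones(2, 4, 4, dtype=torch.bool)    # every position sees every other
lower = torch.tril(full)                        # each position sees earlier ones
eye = torch.eye(4, dtype=torch.bool).expand(2, 4, 4)  # each position sees only itself

fused = fuse_views(x, [eye, lower, full])
print(fused.shape)
```

With the identity mask, that view reproduces `x` unchanged, so the fused output is at least `x` in every coordinate. This illustrates what the max fusion does: each hidden unit keeps whichever view gives it the largest value.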
- custom forward()
def forward(self, **kwargs):
    # Unpack token inputs, gold label matrices, and masks from the batch.
    input_ids, input_masks, input_segments = [kwargs[w] for w in ['input_ids', 'input_masks', 'input_segments']]
    ent_matrix, rel_matrix, pol_matrix = [kwargs[w] for w in ['ent_matrix', 'rel_matrix', 'pol_matrix']]
    reply_masks, speaker_masks, thread_masks = [kwargs[w] for w in ['reply_masks', 'speaker_masks', 'thread_masks']]
    sentence_masks, full_masks, dialogue_length = [kwargs[w] for w in ['sentence_masks', 'full_masks', 'dialogue_length']]

    # Encode with BERT; [0] selects the first output (token-level hidden states).
    sequence_outputs = self.bert(input_ids, token_type_ids=input_segments, attention_mask=input_masks)[0]
    sequence_outputs = self.merge_sentence(sequence_outputs, input_masks, dialogue_length)
    sequence_outputs = self.dropout(sequence_outputs)

    # Apply the reply / speaker / thread masked attentions and fuse them.
    sequence_outputs = self.build_attention(sequence_outputs, reply_masks=reply_masks, speaker_masks=speaker_masks, thread_masks=thread_masks)

    # One classification pass per label matrix: ent, rel, pol.
    loss0, tags0 = self.classify_matrix(kwargs, sequence_outputs, ent_matrix, sentence_masks, 'ent')
    loss1, tags1 = self.classify_matrix(kwargs, sequence_outputs, rel_matrix, full_masks, 'rel')
    loss2, tags2 = self.classify_matrix(kwargs, sequence_outputs, pol_matrix, full_masks, 'pol')
    return (loss0, loss1, loss2), (tags0, tags1, tags2)
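forward() returns the three losses as a tuple and leaves the combination to the caller. How the original training loop combines them is not shown here. One common option is a weighted sum, sketched below with a hypothetical helper `total_loss` and stand-in scalar losses, so that a single backward pass covers all three heads.

```python
import torch


def total_loss(losses, weights=(1.0, 1.0, 1.0)):
    """Weighted sum of per-task losses (hypothetical helper, equal weights assumed)."""
    return sum(w * l for w, l in zip(weights, losses))


# Stand-ins for (loss0, loss1, loss2); each depends on a shared parameter.
param = torch.tensor(2.0, requires_grad=True)
losses = (param * 1.0, param * 2.0, param * 3.0)

loss = total_loss(losses)
loss.backward()  # gradients from all three tasks accumulate in param.grad
print(loss.item(), param.grad.item())
```

Unequal weights could be passed through `weights` if one task dominates training. Whether that is needed depends on the data and is not something the code above decides.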
from dialogue-absa.