Comments (20)

imelnyk commented on August 19, 2024

Hi, here is a rough sketch on how to implement this:

  1. Create a trainable query_embed = nn.Embedding(max_nodes, hidden_dim)

  2. When calling T5 transformer

    output = self.transformer(input_ids=text,
    change it to remove decoder-related inputs and pass decoder_inputs_embeds=query_embed

  3. The output node features coming from transformer should be processed by GRUDecoder (similar as for EdgesGen)

    class EdgesGen(nn.Module):

  4. Finally, similarly as currently done, the node features should also be passed for edge generation.

Note: Optionally, you can also implement Matcher class which estimates permutation matrix to match logits_nodes with target_nodes (since in general their order may not match). The main idea here is to create all_pairwise_distance matrix where [i,j] element is cross_entropy(logits_nodes[i], target_nodes[j]). Then use linear_sum_assignment function to find the best match and construct a permutation matrix. Finally, apply this permutation matrix to logits_nodes.
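
A minimal sketch of such a Matcher, simplified so that each node is a single class id (in practice nodes are token sequences, so the cross-entropy would be summed over tokens); the function name and shapes are illustrative, not the repo's code:

```python
import torch
from scipy.optimize import linear_sum_assignment

def match_nodes(logits_nodes: torch.Tensor, target_nodes: torch.Tensor) -> torch.Tensor:
    """Reorder predicted node logits so they line up with the target order.

    logits_nodes: (N, C) unnormalized class scores for N predicted nodes
    target_nodes: (N,) target class id for each of the N target nodes
    """
    N = logits_nodes.size(0)
    # all_pairwise_distance[i, j] = cross_entropy(logits_nodes[i], target_nodes[j])
    log_probs = logits_nodes.log_softmax(dim=-1)   # (N, C)
    cost = -log_probs[:, target_nodes]             # (N, N) pairwise CE matrix
    # Hungarian matching: prediction row[i] is assigned to target col[i]
    row, col = linear_sum_assignment(cost.detach().cpu().numpy())
    # construct the permutation matrix and apply it to the logits
    perm = torch.zeros(N, N, dtype=logits_nodes.dtype)
    perm[torch.as_tensor(col), torch.as_tensor(row)] = 1.0
    return perm @ logits_nodes                     # row j now matches target j
```

The matched logits can then go into the usual cross-entropy loss against target_nodes.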

from grapher.

Jasaxion commented on August 19, 2024

Excuse me, @imelnyk, sorry to disturb you. I am also interested in the code for the query node section. Could you please add it to the repository or send it to me? I really wonder how it is implemented. Thank you a lot.

Jasaxion commented on August 19, 2024

Thank you very much, I will give it a try.

Jasaxion commented on August 19, 2024

@imelnyk I am so sorry to disturb you. I added query_embed in litgrapher's __init__:

self.query_embed = nn.Embedding(self.max_nodes, self.model.hidden_dim)  # query_node

and then passed it as a parameter to grapher:

logits_nodes, logits_edges = self.model(text_input_ids,
                                        text_input_attn_mask,
                                        target_nodes,
                                        target_nodes_mask,
                                        self.query_embed.weight,
                                        target_edges)

but inference seems to fail (e.g. failed_edge --> failed_node) after training, because the shape of the query_node is (max_nodes x hidden_dim), while T5's decoder_inputs_embeds expects (batch_size x sequence_length x hidden_dim).
I tried to unsqueeze the query_node to (batch_size x max_nodes x hidden_dim), but it still failed after a few epochs of training.
So I wonder where I should add the query_node in the code.
Thank you for your answering!

imelnyk commented on August 19, 2024

Hi,
It is hard to say what went wrong, but here are some ideas: make sure that the transformer only gets input_ids, attention_mask, and decoder_inputs_embeds (this is your query_embed). Once you get the features out of the transformer, pass them through an MLP (similarly to the edges), and then apply the cross-entropy (CE) loss.
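
A runnable sketch of that recipe, with hypothetical sizes and a random stand-in for the transformer output (the commented-out call shows where the real T5 forward would go):

```python
import torch
import torch.nn as nn

# hypothetical sizes; in the repo these would come from the model config
batch_size, max_nodes, hidden_dim, vocab_size = 4, 8, 64, 1000

# 1. trainable query embeddings, one per possible node
query_embed = nn.Embedding(max_nodes, hidden_dim)

# expand the (max_nodes, hidden_dim) weight into a per-batch decoder input
queries = query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)

# 2. the real transformer call (needs a loaded T5, shown here commented out):
# output = self.transformer(input_ids=text,
#                           attention_mask=text_mask,
#                           decoder_inputs_embeds=queries)
# node_features = output.last_hidden_state     # (batch, max_nodes, hidden_dim)
node_features = torch.randn(batch_size, max_nodes, hidden_dim)  # stand-in

# 3. MLP head producing node logits, then cross-entropy against target ids
node_head = nn.Linear(hidden_dim, vocab_size)
logits_nodes = node_head(node_features)        # (batch, max_nodes, vocab_size)
target_nodes = torch.randint(0, vocab_size, (batch_size, max_nodes))
loss = nn.functional.cross_entropy(logits_nodes.view(-1, vocab_size),
                                   target_nodes.view(-1))
```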

mengdawn025 commented on August 19, 2024

@imelnyk I am sorry to bother you again.
First, I wonder if the N (number of nodes) d-dimensional node features mentioned in the paper are the state of the last layer of the decoder.

joint_features = output.decoder_hidden_states[-1]

Second, are the node features that are decoded into node logits the same as the node features used to generate edges?

Third, should the sample function in grapher.py

 def sample(self, text, text_mask):

be modified after the query_embed is added?

imelnyk commented on August 19, 2024
  1. Yes, you should get the last hidden state
    joint_features = transformer(input_ids=text, attention_mask=text_mask, decoder_inputs_embeds=query_embed).last_hidden_state

  2. No, node features and edge features are different. Otherwise, how can you convert the same features into different objects (nodes and edges)?

  3. Yes, it needs to be modified since you need to pass learned query_embed to your transformer during inference.

mengdawn025 commented on August 19, 2024

@imelnyk
OK. Thank you for your answer. Two other questions I have are:

  1. Are the node features decoded into node logits the state of the last layer of the decoder?
     joint_features = transformer(input_ids=text, attention_mask=text_mask, decoder_inputs_embeds=query_embed).last_hidden_state
  2. How can I obtain the node features used to generate edges?

imelnyk commented on August 19, 2024
  1. Yes, the last hidden state from the decoder is used to get the node features, which are then used to get the node logits.
  2. The edge pipeline can remain the same as is currently done in the code - each pair of node features is combined and passed through MLP to get edge logits.
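
The pairwise combination described in point 2 might look like this sketch (the shapes, names, and concatenation-based MLP here are illustrative assumptions, not the repo's exact code):

```python
import torch
import torch.nn as nn

# illustrative sizes
batch, N, d, num_edge_classes = 2, 5, 64, 10
node_features = torch.randn(batch, N, d)

# build features for every ordered pair (i, j): concat node i with node j
fi = node_features.unsqueeze(2).expand(-1, -1, N, -1)   # (batch, N, N, d)
fj = node_features.unsqueeze(1).expand(-1, N, -1, -1)   # (batch, N, N, d)
pairs = torch.cat([fi, fj], dim=-1)                     # (batch, N, N, 2d)

# MLP mapping each node pair to edge logits (including a "no edge" class)
edge_mlp = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(),
                         nn.Linear(d, num_edge_classes))
logits_edges = edge_mlp(pairs)                          # (batch, N, N, num_edge_classes)
```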

mengdawn025 commented on August 19, 2024

@imelnyk
OK. Thank you for your answer. One other question I have is:
I need to modify the sample function in grapher.py after query_embed is added. Should query_embed be passed to self.transformer.generate in the sample function?

 def sample(self, text, text_mask):
        output = self.transformer.generate(input_ids=text,
                                           max_length=150,
                                           attention_mask=text_mask,
                                           output_hidden_states=True,
                                           output_scores=True,
                                           return_dict_in_generate=True)

If yes, which parameter should query_embed be passed to? decoder_inputs_embeds? But the generate function does not seem to accept a decoder_inputs_embeds parameter.

imelnyk commented on August 19, 2024

Yes, this part is a bit tricky; generate is not applicable here. One option would be to use the same setup as in training, with the learned query_embed.
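
A sketch of what such a generate-free sample could look like, assuming the model exposes the transformer, the learned query_embed, and a node_head MLP (all names are hypothetical, and the greedy per-position argmax is a stand-in for the actual GRU-based node decoding):

```python
import torch

def sample(model, text, text_mask):
    """Inference without generate(): reuse the training-style forward pass,
    feeding the learned query embeddings as decoder inputs."""
    batch_size = text.size(0)
    # expand the learned query table to one copy per batch element
    queries = model.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
    output = model.transformer(input_ids=text,
                               attention_mask=text_mask,
                               decoder_inputs_embeds=queries)
    node_features = output.last_hidden_state       # (batch, max_nodes, hidden)
    logits_nodes = model.node_head(node_features)  # (batch, max_nodes, vocab)
    return logits_nodes.argmax(dim=-1)             # greedy per-position node ids
```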

mengdawn025 commented on August 19, 2024

OK. Thank you very much, I will give it a try.

Jasaxion commented on August 19, 2024

Hi @imelnyk, I apologize for reaching out again. In Section 2.2 "Node Generation: Query Nodes" of the paper, it is mentioned that the node features are encoded as Fn ∈ R^{d×N}. May I ask how, in detail, to pass this to the GRUDecoder to generate node logits (seq_len x vocab_size x num_nodes)? I have attempted numerous modifications, but the issue persists. If you still have the code archive, I would be grateful if you could share it with me. I am very interested in your implementation of this part. Thank you for patiently teaching me how to make these modifications.

mengdawn025 commented on August 19, 2024

@imelnyk I am sorry to bother you again. One thing I want to figure out:
The paper mentions that the query node variant does not perform well, but I want to know which kind of poor performance it refers to:

  1. Normal triples can be generated, such as:
     Aarhus | leader | Jacob Bundsgaard
     but the accuracy is very low;
  2. It cannot generate normal triples at all, such as:
     <extra_id_0> | failed edge | failed node
  3. Something else.
     If it is the third case, I hope you can explain it.

imelnyk commented on August 19, 2024

Yes, query node training is not easy; you have to train longer and play with learning rates, gradient clipping, etc. For us, the performance was not great, but the model was still able to generate legible nodes and edges. In your case it looks like there might be training problems or even some issues with the implementation.

mengdawn025 commented on August 19, 2024

Okay. Thank you very much.

mengdawn025 commented on August 19, 2024

Hi, @imelnyk I am sorry to bother you again.
How can I view the evaluation metrics, such as precision, recall, and F1 scores, after the model is trained? Is it by running the command tensorboard --logdir output? I did not obtain the evaluation metrics after executing that command.

imelnyk commented on August 19, 2024

Yes, as the model trains, it evaluates itself and saves the results. You can see it here:

self.logger.experiment.add_scalar(f'{split}_score/{k}', v, global_step=iteration)

mengdawn025 commented on August 19, 2024

Okay. Thank you very much.

mengdawn025 commented on August 19, 2024

Hello @imelnyk, I am sorry to bother you again. So far, we are still a little puzzled about the function of the query nodes. Could you please explain it?

