
Comments (11)

mg2015started commented on August 21, 2024

Hello, I am trying to use a Sonnet GRU or LSTM as the edge or node function. In this case, I want to feed two graphs_tuples into the GRU graph model, but I have no idea how to implement this: if I use one graphs_tuple as cur_state_input, common functions like a linear layer can operate directly on its nodes and edges, but I don't see how to pass in another graphs_tuple as pre_state_input. I am a little confused.

Actually, I want a graph model that can handle two graph inputs and use a GRU as its function, where the node function looks like f((e_{t-1}, v_{t-1}, u_{t-1}), (e_t, v_t, u_t)). I have tried but failed. Can you help me solve this problem? Thank you.


alvarosg commented on August 21, 2024

Hi @mg2015started

We have not currently added official support for recurrent graph networks to the library, since we are still finishing up the design details and experimenting with different approaches internally.

However, an easy workaround would be to have GRUs or LSTMs that operate independently (similar to modules.GraphIndependent) on the nodes, edges, or globals of the graph. For example:

import sonnet as snt
import tensorflow as tf

input_graphs_tuple = ...

# Build three Sonnet LSTM modules for globals, nodes and edges.
lstm_globals = snt.LSTM(hidden_size=7)
lstm_nodes = snt.LSTM(hidden_size=7)
lstm_edges = snt.LSTM(hidden_size=7)

# Build initial states according to the total number of globals, nodes or
# edges in the batch.
state_globals = lstm_globals.zero_state(
    batch_size=tf.shape(input_graphs_tuple.globals)[0], dtype=tf.float32)
state_nodes = lstm_nodes.zero_state(
    batch_size=tf.shape(input_graphs_tuple.nodes)[0], dtype=tf.float32)
state_edges = lstm_edges.zero_state(
    batch_size=tf.shape(input_graphs_tuple.edges)[0], dtype=tf.float32)

# Process globals, nodes and edges independently; get outputs and updated states.
output_globals, state_globals = lstm_globals(input_graphs_tuple.globals, state_globals)
output_nodes, state_nodes = lstm_nodes(input_graphs_tuple.nodes, state_nodes)
output_edges, state_edges = lstm_edges(input_graphs_tuple.edges, state_edges)

# Put the outputs back into a GraphsTuple.
output_graph = input_graphs_tuple.replace(
    globals=output_globals,
    nodes=output_nodes,
    edges=output_edges,
)
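To connect this to the two-graph question: one could (hypothetically) run the same modules again on the GraphsTuple for the next time step, feeding the updated states forward. Here next_graphs_tuple is an assumed second input with the same numbers of nodes, edges and graphs:

# Hypothetical second step: the LSTM states now carry the information from
# the previous graph, playing the role of (e_{t-1}, v_{t-1}, u_{t-1}).
output_globals, state_globals = lstm_globals(next_graphs_tuple.globals, state_globals)
output_nodes, state_nodes = lstm_nodes(next_graphs_tuple.nodes, state_nodes)
output_edges, state_edges = lstm_edges(next_graphs_tuple.edges, state_edges)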

Hope this helps!


mg2015started commented on August 21, 2024

Oh, @alvarosg, thanks for your detailed reply and example!
And I have another question: does this library support graphs where nodes have different feature dimensions, or where different functions are applied to different nodes? Would I need to modify some code in the original blocks or somewhere else?


alvarosg commented on August 21, 2024

The current Graph Nets blocks assume the same number of features for all nodes. The reason for this is that it would not be obvious how to aggregate nodes of different sizes into globals, or how to do edge computations where the sender and receiver nodes may have any combination of types.

Your best options are (I think):

  • Pad your nodes to a fixed length, and add a one hot encoding of the node types so your MLP (if that is what you are using) can learn to distinguish them.

  • Have a per-node-type encoder that embeds each type of node into a common feature space.

  • Have nodes of different types in different graphs (although this won't work if they have edges connecting them).

Regarding applying different functions to different node types, it is up to you to include the node type as part of the features and write a custom node_model_fn that applies different functions to nodes of different types, for example using tf.gather and tf.scatter operations, or tf.where.
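As a minimal sketch of that (not official library code; the two-MLP setup and the convention that the last feature column is a 0/1 node-type flag are illustrative assumptions):

import sonnet as snt
import tensorflow as tf

def make_typed_node_model():
  # Hypothetical node model that applies a different MLP per node type.
  mlp_type_a = snt.nets.MLP(output_sizes=[16, 16])
  mlp_type_b = snt.nets.MLP(output_sizes=[16, 16])

  def node_model(node_features):
    # Assumed convention: the last feature column is a 0/1 node-type flag.
    node_type = node_features[:, -1]
    out_a = mlp_type_a(node_features)
    out_b = mlp_type_b(node_features)
    # tf.where with a vector condition selects per row: type-1 nodes take the
    # output of mlp_type_b, type-0 nodes that of mlp_type_a.
    return tf.where(node_type > 0.5, out_b, out_a)

  return node_model

Since the blocks call node_model_fn once to build the model, make_typed_node_model could be passed directly as node_model_fn.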


mg2015started commented on August 21, 2024

Thank you so much, it helps me a lot.


ICDI0906 commented on August 21, 2024

(Quoting @alvarosg's LSTM workaround from above.)

Hi, I am confused about the shape of input_graphs_tuple.nodes. Is it [batch_size, timestep, feature]? Looking forward to your reply.


alvarosg commented on August 21, 2024

This code just shows how to do a single step of the recurrent model, hence the inputs don't have a time axis. If you wanted to process a sequence using the recurrent model, then the shape you are suggesting, [batch_size, timestep, feature], is probably good. However, you may have to select the different elements in the sequence manually at different steps of the RNN.
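A minimal sketch of that manual selection (nodes_sequence, its [num_nodes, num_timesteps, feature] layout, and num_timesteps are illustrative assumptions):

import sonnet as snt
import tensorflow as tf

lstm = snt.LSTM(hidden_size=7)
state = lstm.zero_state(
    batch_size=tf.shape(nodes_sequence)[0], dtype=tf.float32)

outputs = []
for t in range(num_timesteps):  # num_timesteps assumed to be known statically
  # Manually select the features for step t, as described above.
  output, state = lstm(nodes_sequence[:, t, :], state)
  outputs.append(output)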


ICDI0906 commented on August 21, 2024

(Quoting the reply above.)

Thank you for your answer. I solved it by using tf.nn.dynamic_rnn(lstm_nodes, input_graphs_tuple.nodes); before that, I reshaped input_graphs_tuple.nodes to [batch_size, timestep, feature]. Luckily it works!
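For reference, a sketch of that approach (batch_size, timestep and feature_size are illustrative names, and it assumes the flat node features can be validly regrouped into a per-node time series):

import tensorflow as tf

# Regroup the flat node features into [batch, time, features], then let
# tf.nn.dynamic_rnn drive the Sonnet LSTM core over the time axis.
nodes = tf.reshape(input_graphs_tuple.nodes, [batch_size, timestep, feature_size])
outputs, final_state = tf.nn.dynamic_rnn(lstm_nodes, nodes, dtype=tf.float32)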


felbecker commented on August 21, 2024

Hi,

is there any update on the matter of LSTMs as node or edge functions?

The "independent" workaround is not enough, if you want to condition inputs as usual based on the graph topology. For instance the edge LSTM could have edge, 2 respective adjacent nodes, global and hidden edge state as inputs (as opposed to just edges and hidden edge state).

So I guess a step in a complete GraphNetwork with identity edge, node, and global functions (which does all the concatenations), followed by an "independent LSTM" step as described above, would be an easy way to achieve this?
Are there still plans for official support for recurrent functions that take an input plus a hidden state?

Best, ~Ungod


alvarosg commented on August 21, 2024

Thank you for your message Ungod,

if you want to condition inputs as usual based on the graph topology

Note that if you put the independent RNN update right after the GraphNetwork, the output representations after the GraphNetwork will already be based on the graph topology. Said differently, it should not matter much whether you use an identity (as you are describing in your strategy) or a function approximator, especially if you interleave many layers of this.

The code of modules.GraphNetwork is pretty simple (about 10 lines of actual code). Since you have something specific in mind, I think it should be reasonably easy to fork modules.GraphNetwork, pass identity functions to the EdgeBlock, NodeBlock, and GlobalBlock, and run the corresponding RNNs right after each block (essentially the identity trick you are describing, but inside of modules.GraphNetwork).
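A rough sketch of that fork, using the identity trick with the building blocks from graph_nets.blocks (this is not official library code, and the state wiring is just one possible choice):

import sonnet as snt
import tensorflow as tf
from graph_nets import blocks

def recurrent_graph_network_step(graph, lstm_edges, lstm_nodes, lstm_globals,
                                 state_edges, state_nodes, state_globals):
  # Identity EdgeBlock: collects and concatenates edge, sender, receiver and
  # global features without transforming them.
  graph = blocks.EdgeBlock(edge_model_fn=lambda: tf.identity)(graph)
  edges, state_edges = lstm_edges(graph.edges, state_edges)
  graph = graph.replace(edges=edges)

  # Identity NodeBlock: concatenates node features with aggregated received
  # edges and globals.
  graph = blocks.NodeBlock(node_model_fn=lambda: tf.identity)(graph)
  nodes, state_nodes = lstm_nodes(graph.nodes, state_nodes)
  graph = graph.replace(nodes=nodes)

  # Identity GlobalBlock: concatenates globals with aggregated nodes and edges.
  graph = blocks.GlobalBlock(global_model_fn=lambda: tf.identity)(graph)
  globals_, state_globals = lstm_globals(graph.globals, state_globals)
  graph = graph.replace(globals=globals_)

  return graph, (state_edges, state_nodes, state_globals)

Since the identity models have no variables, re-creating the blocks at every step is harmless; the LSTM modules are passed in so that their variables are shared across steps.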

Beyond that, we do not have plans in the short term to add support for recurrent states to modules.GraphNetwork, because there are a few different approaches to integrating RNNs within a GraphNet, and we do not want to restrict users to a single one.

Hope this helps!


felbecker commented on August 21, 2024

This was exactly what I had in mind. I just forgot to take the aggregations into account when writing my message. Of course, you have to operate block-wise: apply the LSTM, then let the next block do its respective aggregations. It does not require a large amount of code either; I'm now rewriting modules.GraphNetwork a bit.

Thank you for your fast reply!

PS: In fact, I don't think this is so "specific". It's just the general graph-network approach as described in your paper, but with additional hidden LSTM states for edges, nodes, and globals.
