joshcarty commented on June 5, 2024

I've expanded the ogbn-arxiv example above into a repository at joshcarty/tfgnn-ogb.

A small model trained on my laptop gets around 56% top-1 and 88% top-5 accuracy on the validation set. While it's not going to challenge the OGB leaderboard, I'm confident it's learning.
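Top-1 and top-5 accuracy count a prediction as correct when the true class is the single highest-scoring class, or among the five highest-scoring classes. A minimal NumPy sketch of the metric (the function name `top_k_accuracy` is illustrative and not taken from joshcarty/tfgnn-ogb):

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of examples whose true class is among the k highest scores."""
    # Column indices of the k largest scores in each row (unordered).
    top_k = np.argpartition(logits, -k, axis=1)[:, -k:]
    # labels may be shaped (N,) or (N, 1), as OGB node labels are.
    hits = (top_k == labels.reshape(-1, 1)).any(axis=1)
    return float(hits.mean())


logits = np.array([
    [0.1, 0.7, 0.2],  # true class 1 has the highest score
    [0.5, 0.3, 0.2],  # true class 2 has the lowest score
])
labels = np.array([[1], [2]])
print(top_k_accuracy(logits, labels, k=1))  # 0.5
print(top_k_accuracy(logits, labels, k=3))  # 1.0
```

`argpartition` avoids a full sort, which matters little for the 40 classes of ogbn-arxiv but keeps the sketch cheap for larger label sets.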

The example shows: sampling a NetworkX graph into GraphTensors, optionally saving them to tf.Examples, loading them as a tf.data.Dataset, batching, and training a model using tfgnn.keras.ConvGNNBuilder.
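The core of the sampling step is extracting a k-hop neighbourhood around each node and re-indexing its edges so that endpoints point into the sample's own node list; ids from the full graph would be out of range for a small per-sample node set. A dependency-free sketch, assuming an undirected graph stored as a dict of neighbour sets (`ego_subgraph` is a hypothetical helper, not code from the repository):

```python
from collections import deque
from typing import Dict, List, Set, Tuple


def ego_subgraph(
    adjacency: Dict[int, Set[int]], center: int, radius: int
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Nodes within `radius` hops of `center`, plus the edges among them.

    Edge endpoints are positions in the returned node list (local indices),
    not the original node ids.
    """
    hops = {center: 0}
    queue = deque([center])
    while queue:  # breadth-first search, stopping at `radius` hops
        node = queue.popleft()
        if hops[node] == radius:
            continue
        for neighbour in sorted(adjacency[node]):
            if neighbour not in hops:
                hops[neighbour] = hops[node] + 1
                queue.append(neighbour)

    nodes = list(hops)  # BFS discovery order, center first
    local = {node: i for i, node in enumerate(nodes)}
    # Each undirected edge appears in both neighbour sets; keep it once (u < v).
    edges = sorted(
        (local[u], local[v])
        for u in nodes
        for v in adjacency[u]
        if v in local and u < v
    )
    return nodes, edges


adjacency = {10: {20}, 20: {10, 30, 50}, 30: {20, 40}, 40: {30}, 50: {20}}
nodes, edges = ego_subgraph(adjacency, center=10, radius=2)
print(nodes)  # [10, 20, 30, 50]
print(edges)  # [(0, 1), (1, 2), (1, 3)]
```

Node 40 is three hops from node 10, so it and its edge are dropped; like `nx.ego_graph`, the result is the induced subgraph on the nodes that were reached.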

Hope it helps!

from gnn.

sibonli commented on June 5, 2024

Hi Loreto,

This is something that we're actively working on. Stay posted :)

joshcarty commented on June 5, 2024

I've made some progress testing the forward pass with the Open Graph Benchmark ogbn-arxiv dataset and NetworkX. A similar approach should work for TUDatasets, @loretoparisi, but I couldn't find a good example with node features to test on.

from typing import Iterator

import networkx as nx
import numpy as np
import tensorflow_gnn as tfgnn
import tensorflow as tf
from ogb.nodeproppred import NodePropPredDataset


def ogb_as_networkx_graph(name: str) -> nx.Graph:
    dataset = NodePropPredDataset(name)
    ogb_graph, labels = dataset[0]
    num_nodes = ogb_graph["num_nodes"]
    ogb_features = ogb_graph["node_feat"]
    ogb_edgelist = ogb_graph["edge_index"]
    ogb_node_indices = np.arange(num_nodes)

    graph = nx.from_edgelist(ogb_edgelist.T)
    data = zip(ogb_node_indices, ogb_features, labels)
    features = {
        node: {"features": features, "label": label}
        for node, features, label in data
    }
    nx.set_node_attributes(graph, values=features)
    return graph


def generate_graph_samples(
    graph: nx.Graph, neighbours: int = 2
) -> Iterator[tfgnn.GraphTensor]:
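    for node in graph.nodes:
        # Relabel the ego graph's nodes to 0..n-1 (in node iteration order) so
        # that edge endpoints index into this sample's node set rather than
        # pointing at node ids of the full graph. Node attributes are kept.
        subgraph = nx.convert_node_labels_to_integers(
            nx.ego_graph(graph, node, radius=neighbours)
        )

        num_edges = subgraph.number_of_edges()
        # reshape(-1, 2) keeps a (0, 2) shape for ego graphs without edges.
        adjacency = np.asarray(list(subgraph.edges), dtype=np.int64).reshape(-1, 2)
        edges = tfgnn.EdgeSet.from_fields(
            sizes=[num_edges],
            adjacency=tfgnn.Adjacency.from_indices(
                # The node set name must match the key used in from_pieces below.
                source=("nodes", adjacency[:, 0]),
                target=("nodes", adjacency[:, 1]),
            ),
        )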

        data = (
            (data["features"], data["label"])
            for _, data in subgraph.nodes(data=True)
        )
        features, labels = zip(*data)
        num_nodes = subgraph.number_of_nodes()
        nodes = tfgnn.NodeSet.from_fields(
            features={
                "hidden_state": np.asarray(features),
                "labels": np.asarray(labels),
            },
            sizes=[num_nodes],
        )

        graph_tensor = tfgnn.GraphTensor.from_pieces(
            edge_sets={"edges": edges}, node_sets={"nodes": nodes}
        )

        yield graph_tensor


graph = ogb_as_networkx_graph("ogbn-arxiv")
graph_samples = generate_graph_samples(graph, neighbours=2)
graph_tensor = next(graph_samples)
type_spec = graph_tensor.spec

inputs = tf.keras.layers.Input(type_spec=type_spec)
gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge: tfgnn.keras.layers.SimpleConvolution(
        tf.keras.layers.Dense(32)
    ),
    lambda node: tfgnn.keras.layers.NextStateFromConcat(
        tf.keras.layers.Dense(32)
    ),
)
update = gnn.Convolve()(inputs)
hidden = tfgnn.keras.layers.Readout(node_set_name="nodes")(update)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs, output)

y = model(graph_tensor)
print(y)

It does seem like you need to name your node features hidden_state for the lookup _get_features_ref[feature_name] in NodeSetUpdate to succeed. I'd need to investigate further where this is (or could be) parameterised, but I hope this helps you @thilograffe.

Now to see if it fits...

jackd commented on June 5, 2024

Seconded. Perhaps an example tensorflow-datasets implementation? I'd be happy to contribute tfds implementations for many common research datasets if a standardized FeatureConnector could be provided for the relevant graph types.

thilograffe commented on June 5, 2024

We are hyped to use this lib. Any updates regarding the examples?

joshcarty commented on June 5, 2024

I'm trying to put together a minimal example but have got stuck:

import tensorflow as tf
import tensorflow_gnn as tfgnn

tf.debugging.disable_traceback_filtering()

schema = tfgnn.read_schema("tensorflow_gnn/testdata/homogeneous/citrus.pbtxt")
spec = tfgnn.create_graph_spec_from_schema_pb(schema)
graph = tfgnn.random_graph_tensor(spec)

gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge: tfgnn.keras.layers.SimpleConvolution(tf.keras.layers.Dense(32)),
    lambda node: tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(32)),
)

inputs = tf.keras.layers.Input(type_spec=spec)
update = gnn.Convolve()(inputs)
output = tfgnn.keras.layers.Readout(node_set_name="fruits")(update)
model = tf.keras.Model(inputs, output)

model(graph)

but am getting the following:

Traceback (most recent call last):
  File "/Users/joshuac/Developer/tf-gnn-playground/main.py", line 16, in <module>
    update = gnn.Convolve()(inputs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1019, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1160, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 885, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 930, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
    raise e
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 224, in call
    update_fn(graph, node_set_name=node_set_name)))
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1083, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 145, in error_handler
    raise new_e.with_traceback(e.__traceback__) from None
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 389, in call
    _get_feature_or_features(graph.node_sets[node_set_name],
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 513, in _get_feature_or_features
    return features[names]
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/graph_tensor.py", line 34, in __getitem__
    return self._get_features_ref[feature_name]
KeyError: 'Exception encountered when calling layer "node_set_update" (type NodeSetUpdate).\n\nhidden_state\n\nCall arguments received:\n  • graph=<tensorflow_gnn.graph.graph_tensor.GraphTensor object at 0x1486cbcc0>\n  • node_set_name=\'fruits\''

The debugger shows self._get_features_ref[feature_name] failing to index for hidden_state.

Are you able to help?

Thanks for releasing this - I'm also really excited to start using it!

joshcarty commented on June 5, 2024

Thanks for sharing @thilograffe. Here's an example using dummy data where I've got the forward pass to work. It's modified from some of the Keras tests in the package:

from typing import Tuple

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.graph.graph_constants import FieldsNest


def build_graph_tensor(
    num_edges: int = 10, num_nodes: int = 10, num_features: int = 5
) -> tfgnn.GraphTensor:
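    # Random (source, target) node indices drawn uniformly from [0, num_nodes).
    adjacency = tf.random.stateless_uniform(
        shape=[num_edges, 2], seed=[0, 0], minval=0, maxval=num_nodes, dtype=tf.int32
    )
    edges = tfgnn.EdgeSet.from_fields(
        # One weight per edge: shape [num_edges, 1] broadcasts against the
        # [num_edges, num_features] source-node states in MultiplyNodeEdge.
        features={"edge_weight": tf.random.normal(shape=[num_edges, 1])},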
        sizes=[num_edges],
        adjacency=tfgnn.HyperAdjacency.from_indices(
            indices={
                tfgnn.SOURCE: ("node", adjacency[:, 0]),
                tfgnn.TARGET: ("node", adjacency[:, 1]),
            }
        ),
    )

    nodes = tfgnn.NodeSet.from_fields(
        features={"hidden_state": tf.random.normal(shape=[num_nodes, num_features])},
        sizes=[num_nodes],
    )

    return tfgnn.GraphTensor.from_pieces(
        edge_sets={"edge": edges}, node_sets={"node": nodes}
    )


class MultiplyNodeEdge(tf.keras.layers.Layer):
    def __init__(self, edge_feature: str, node_feature: str = tfgnn.SOURCE) -> None:
        super().__init__()
        self.node_feature = node_feature
        self.edge_feature = edge_feature

    def call(self, inputs: Tuple[FieldsNest, FieldsNest, FieldsNest]) -> tf.Tensor:
        edge_inputs, node_inputs, _ = inputs
        return tf.multiply(
            edge_inputs[self.edge_feature], node_inputs[self.node_feature]
        )


graph_tensor = build_graph_tensor()
spec = graph_tensor.spec

input = tf.keras.layers.Input(type_spec=spec)
update = tfgnn.keras.layers.GraphUpdate(
    edge_sets={
        "edge": tfgnn.keras.layers.EdgeSetUpdate(
            next_state=MultiplyNodeEdge(edge_feature="edge_weight"),
            edge_input_feature=["edge_weight"],
        )
    },
    node_sets={
        "node": tfgnn.keras.layers.NodeSetUpdate(
            edge_set_inputs={"edge": tfgnn.keras.layers.Pool(tfgnn.TARGET, "sum")},
            next_state=tfgnn.keras.layers.NextStateFromConcat(
                tf.keras.layers.Dense(16)
            ),
        )
    },
)
graph = update(input)
hidden = tfgnn.keras.layers.Readout(node_set_name="node")(graph)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(input, output)

y = model(graph_tensor)
print(y)

I still haven't had any luck with real-looking data but I hope the above helps you and others.

If you have the time, we'd really appreciate some support @sibonli @arnoegw @janpfeifer, in particular with the KeyError described above and with using ConvGNNBuilder in general. Thanks!

thilograffe commented on June 5, 2024

We had to adapt the type spec (see the screenshot). Otherwise, the fit method throws an error because the model's type_spec does not match that of the given training sample.
[Screenshot of the adapted type spec]
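A type spec taken from one concrete sample (for example `graph_tensor.spec`) records that sample's exact sizes, so a model built from it rejects samples of a different size. The pure-Python sketch below mimics TensorFlow's shape-compatibility rule, where `None` marks an unknown dimension; `is_compatible` is a hypothetical helper written for illustration, not a tfgnn function:

```python
from typing import Optional, Sequence


def is_compatible(spec_shape: Sequence[Optional[int]], shape: Sequence[int]) -> bool:
    """True if `shape` fits `spec_shape`, where None matches any size."""
    return len(spec_shape) == len(shape) and all(
        expected is None or expected == actual
        for expected, actual in zip(spec_shape, shape)
    )


# Node-feature shapes of two ego-graph samples with 128-dim features.
first_sample, second_sample = (5, 128), (9, 128)

pinned_spec = first_sample   # spec copied from the first sample
relaxed_spec = (None, 128)   # node count left unknown

print(is_compatible(pinned_spec, second_sample))   # False
print(is_compatible(relaxed_spec, second_sample))  # True
```

Adapting the spec amounts to leaving the size-dependent dimensions unknown while keeping fixed ones, such as the feature width, as they are.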

arnoegw commented on June 5, 2024

tensorflow-gnn 0.2.0 comes with several Colab notebook examples for end-to-end models as well as extended documentation, which should answer many of the questions raised above.

The MUTAG colab has one example dataset from TUDataset.

I realize this does not fully address the feature request to integrate all of them, but in the interest of keeping issues manageable, please allow me to close this one, because it has come to cover so many things. Better support for datasets is on the roadmap...

thilograffe commented on June 5, 2024

Thanks @joshcarty for your example! The next step is working on some real data.

Perhaps there is a quick fix for the KeyError by using ConvGNNBuilder differently; we will see...

thilograffe commented on June 5, 2024

@joshcarty
Thanks for your quick answer. The issue was indeed the wrong version, so there is no need to discuss this topic further.
Such great work from you with the other repo!

thilograffe commented on June 5, 2024

Hi joshcarty,

we are facing the exact same problem with the KeyError. We tried different examples, but all are failing.

thilograffe commented on June 5, 2024

@joshcarty
Wow, that's exactly what we are looking for. We tried to train different models but failed to fit them with data in the right format. I think your example will help fix our problems.

joshcarty commented on June 5, 2024

Is that using joshcarty/tfgnn-ogb and the latest version of tensorflow_gnn? I've just tested with a fresh virtualenv and it works for me.

You'll want to install tfgnn-ogb/requirements.txt rather than using a local clone of tensorflow_gnn. Along with the other dependencies, that pip-installs the current version of the package, e.g.

$ pip install git+https://github.com/tensorflow/gnn@f63478419152cc2f22886173b45302323f4bfc56

which works since the builds were fixed here.

Upgrading to the latest version of tensorflow_gnn did require one change: I had to add an empty set of context features to prevent a KeyError in GraphTensor.merge_batch_to_components. The empty context is present in both the data and the type spec, so the two should match.

Maybe best to discuss in a new issue in that repository?

SidneyLann commented on June 5, 2024

File "/home/sidney/py_proj/tf-gnn/tensorflow_gnn/graph/schema_utils.py", line 9, in <module>
import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2
ModuleNotFoundError: No module named 'tensorflow_gnn.proto.graph_schema_pb2'

This error occurs when I run most of the test cases. Don't you get it when running them?

SidneyLann commented on June 5, 2024

import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2 in schema_utils.py

There is no graph_schema_pb2 file in this repo, so how do you solve this problem?

joshcarty commented on June 5, 2024

@SidneyLann

Have you installed the package with the command in the Installation Instructions?

cd tensorflow_gnn && python3 -m pip install .

Even though it looks like you've cloned the package locally, you still need to install it with pip. Part of the installation process compiles the protocol buffer definitions into modules that you can import from Python. Once you install it, those modules should be available and you won't get that error anymore.
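A quick way to check whether the generated protocol-buffer module is importable in the current environment, using only the standard library (`module_available` is a hypothetical helper; it only checks importability, not which version is installed):

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be found."""
    try:
        # For dotted names, find_spec imports the parent packages first.
        return importlib.util.find_spec(name) is not None
    except ImportError:
        # A missing parent package, or one whose own imports fail with
        # ImportError/ModuleNotFoundError, ends up here.
        return False


print(module_available("tensorflow_gnn.proto.graph_schema_pb2"))
```

Printing `False` after `pip install .` suggests that Python is picking up a different copy of tensorflow_gnn than the one you installed.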

joshcarty commented on June 5, 2024

For anyone interested in the reason features need to be called hidden_state when passed into convolution layers, I suspect it's because you're expected to preprocess them with a MapFeatures layer. That layer lets you specify a function to preprocess features on nodes and edges and can return them as a single DEFAULT_STATE_NAME (hidden_state) feature.

Hope that's useful to you @thilograffe and others.
