Comments (18)
I've expanded the `ogbn-arxiv` example above into a repository at joshcarty/tfgnn-ogb.
A small model trained on my laptop is getting around 56% Top 1 and 88% Top 5 accuracy on the validation set. While it's not going to challenge the Leaderboard, I'm confident it's learning.
The example shows: sampling a NetworkX graph into GraphTensors, optionally saving them to tf.Examples, loading them as a `tf.data.Dataset`, batching, and training a model using `tfgnn.keras.ConvGNNBuilder`.
Hope it helps!
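For anyone wondering what the batching step amounts to: as far as I understand it, a batch of small sampled graphs ends up merged into one larger graph whose components are the individual samples (this is what `GraphTensor.merge_batch_to_components` is for). The essential trick is shifting each sample's edge indices by the number of nodes that precede it. Below is a plain-NumPy sketch of just that idea; `merge_graphs` and its `(features, source, target)` input format are hypothetical and not part of tfgnn or the repo.

```python
import numpy as np


def merge_graphs(graphs):
    """Merges (node_features, source, target) triples into one graph.

    Hypothetical helper for illustration. Each graph's edge indices are
    shifted by the number of nodes that come before it, so they still point
    at the right rows after concatenation. Returns merged features, sources,
    targets and the per-graph node counts.
    """
    features, sources, targets, sizes = [], [], [], []
    offset = 0
    for node_features, source, target in graphs:
        features.append(node_features)
        sources.append(np.asarray(source, dtype=np.int64) + offset)
        targets.append(np.asarray(target, dtype=np.int64) + offset)
        sizes.append(len(node_features))
        offset += len(node_features)
    return (
        np.concatenate(features),
        np.concatenate(sources),
        np.concatenate(targets),
        np.array(sizes),
    )


# Two tiny samples: 2 nodes with edge 0->1, and 3 nodes with edges 0->1->2.
g1 = (np.zeros((2, 4)), [0], [1])
g2 = (np.ones((3, 4)), [0, 1], [1, 2])
feats, src, tgt, sizes = merge_graphs([g1, g2])
print(src, tgt, sizes)  # [0 2 3] [1 3 4] [2 3]
```

In tfgnn itself I'd expect you to batch the `tf.data.Dataset` and then map `merge_batch_to_components` over each batch, rather than do this by hand.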
from gnn.
Hi Loreto,
This is something that we're actively working on. Stay tuned :)
I've made some progress testing the forward pass with the Open Graph Benchmark `ogbn-arxiv` dataset and networkx. A similar approach should work for TUDatasets, @loretoparisi, but I couldn't find a good example with node features to test on.
```python
from typing import Iterator

import networkx as nx
import numpy as np
import tensorflow as tf
import tensorflow_gnn as tfgnn
from ogb.nodeproppred import NodePropPredDataset


def ogb_as_networkx_graph(name: str) -> nx.Graph:
    """Loads an OGB node property prediction dataset as a NetworkX graph."""
    dataset = NodePropPredDataset(name)
    ogb_graph, labels = dataset[0]
    num_nodes = ogb_graph["num_nodes"]
    ogb_features = ogb_graph["node_feat"]
    ogb_edgelist = ogb_graph["edge_index"]  # shape [2, num_edges]
    ogb_node_indices = np.arange(num_nodes)

    graph = nx.from_edgelist(ogb_edgelist.T)
    # Attach each node's feature vector and label as node attributes.
    node_attributes = {
        node: {"features": node_features, "label": label}
        for node, node_features, label in zip(ogb_node_indices, ogb_features, labels)
    }
    nx.set_node_attributes(graph, values=node_attributes)
    return graph


def generate_graph_samples(
    graph: nx.Graph, neighbours: int = 2
) -> Iterator[tfgnn.GraphTensor]:
    """Yields one GraphTensor per node: its `neighbours`-hop ego graph."""
    for node in graph.nodes:
        subgraph = nx.ego_graph(graph, node, radius=neighbours)
        # Relabel nodes 0..n-1 so edge endpoints index into this subgraph's
        # node set rather than into the full graph.
        subgraph = nx.convert_node_labels_to_integers(subgraph)
        num_edges = subgraph.number_of_edges()
        adjacency = np.asarray(list(subgraph.edges), dtype=np.int64).reshape(-1, 2)
        edges = tfgnn.EdgeSet.from_fields(
            sizes=[num_edges],
            adjacency=tfgnn.Adjacency.from_indices(
                # Must name the node set used in from_pieces below ("nodes").
                source=("nodes", adjacency[:, 0]),
                target=("nodes", adjacency[:, 1]),
            ),
        )
        node_data = [
            (attrs["features"], attrs["label"])
            for _, attrs in subgraph.nodes(data=True)
        ]
        features, labels = zip(*node_data)
        nodes = tfgnn.NodeSet.from_fields(
            features={
                # Convolution layers read the "hidden_state" feature by default.
                "hidden_state": np.asarray(features),
                "labels": np.asarray(labels),
            },
            sizes=[subgraph.number_of_nodes()],
        )
        yield tfgnn.GraphTensor.from_pieces(
            edge_sets={"edges": edges}, node_sets={"nodes": nodes}
        )


graph = ogb_as_networkx_graph("ogbn-arxiv")
graph_samples = generate_graph_samples(graph, neighbours=2)
graph_tensor = next(graph_samples)
type_spec = graph_tensor.spec

inputs = tf.keras.layers.Input(type_spec=type_spec)
gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge: tfgnn.keras.layers.SimpleConvolution(tf.keras.layers.Dense(32)),
    lambda node: tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(32)),
)
update = gnn.Convolve()(inputs)
hidden = tfgnn.keras.layers.Readout(node_set_name="nodes")(update)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs, output)

y = model(graph_tensor)
print(y)
```
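One thing worth spelling out about the sampling: `nx.ego_graph` keeps the node ids of the full graph, so edge endpoints taken straight from `subgraph.edges` would index past the end of the subgraph's node set unless they are remapped to local positions (which is what `nx.convert_node_labels_to_integers` does). Here is a dependency-free sketch of the whole step (BFS out to `radius` hops, take the induced subgraph, relabel to `0..n-1`); `ego_subgraph` and the dict-of-lists adjacency format are assumptions for illustration only.

```python
from collections import deque


def ego_subgraph(adjacency, center, radius):
    """Returns (nodes, local_edges) for the radius-hop neighbourhood of center.

    Hypothetical helper. adjacency maps each node id to its neighbour ids
    (undirected graph). `nodes` holds the original ids in BFS order (center
    first); each edge in `local_edges` is a (u, v) pair of positions in
    `nodes`, listed once. Self-loops are dropped.
    """
    distance = {center: 0}
    queue = deque([center])
    while queue:
        node = queue.popleft()
        if distance[node] == radius:
            continue  # don't expand past the requested number of hops
        for neighbour in adjacency[node]:
            if neighbour not in distance:
                distance[neighbour] = distance[node] + 1
                queue.append(neighbour)

    nodes = list(distance)
    local = {node: i for i, node in enumerate(nodes)}
    # Induced subgraph: keep every edge whose endpoints were both reached,
    # remapped from global ids to local positions.
    local_edges = sorted(
        (local[u], local[v])
        for u in nodes
        for v in adjacency[u]
        if v in local and local[u] < local[v]
    )
    return nodes, local_edges


adjacency = {10: [11, 12], 11: [10, 13], 12: [10], 13: [11, 14], 14: [13]}
print(ego_subgraph(adjacency, center=10, radius=2))
# ([10, 11, 12, 13], [(0, 1), (0, 2), (1, 3)])
```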
It does seem like you need to call your features `hidden_state` for `_get_features_ref[feature_name]` to succeed in `NodeSetUpdate`. I'd need to investigate further where this is (or could be) parameterised, but I hope this helps you, @thilograffe.
Now to see if it fits...
Seconded. Perhaps an example tensorflow-datasets implementation? I'd be happy to contribute tfds implementations for many common research datasets if a standardized FeatureConnector could be provided for the relevant graph types.
We are hyped to use this lib. Any updates regarding the examples?
I'm trying to put together a minimal example but have got stuck:
```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

tf.debugging.disable_traceback_filtering()

schema = tfgnn.read_schema("tensorflow_gnn/testdata/homogeneous/citrus.pbtxt")
spec = tfgnn.create_graph_spec_from_schema_pb(schema)
graph = tfgnn.random_graph_tensor(spec)

gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge: tfgnn.keras.layers.SimpleConvolution(tf.keras.layers.Dense(32)),
    lambda node: tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(32)),
)

inputs = tf.keras.layers.Input(type_spec=spec)
update = gnn.Convolve()(inputs)
output = tfgnn.keras.layers.Readout(node_set_name="fruits")(update)
model = tf.keras.Model(inputs, output)
model(graph)
```
but am getting the following:
```
Traceback (most recent call last):
  File "/Users/joshuac/Developer/tf-gnn-playground/main.py", line 16, in <module>
    update = gnn.Convolve()(inputs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1019, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1160, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 885, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 930, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
    raise e
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 224, in call
    update_fn(graph, node_set_name=node_set_name)))
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1083, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 145, in error_handler
    raise new_e.with_traceback(e.__traceback__) from None
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 389, in call
    _get_feature_or_features(graph.node_sets[node_set_name],
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/keras/layers/graph_update.py", line 513, in _get_feature_or_features
    return features[names]
  File "/Users/joshuac/Developer/tf-gnn-playground/.env/lib/python3.9/site-packages/tensorflow_gnn/graph/graph_tensor.py", line 34, in __getitem__
    return self._get_features_ref[feature_name]
KeyError: 'Exception encountered when calling layer "node_set_update" (type NodeSetUpdate).\n\nhidden_state\n\nCall arguments received:\n • graph=<tensorflow_gnn.graph.graph_tensor.GraphTensor object at 0x1486cbcc0>\n • node_set_name=\'fruits\''
```
The debugger shows `self._get_features_ref[feature_name]` failing to index for `hidden_state`.
Are you able to help?
Thanks for releasing this - I'm also really excited to start using it!
Thanks for sharing @thilograffe. Here's an example using dummy data where I've got the forward pass to work. It's modified from some of the Keras tests in the package:
```python
from typing import Tuple

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.graph.graph_constants import FieldsNest


def build_graph_tensor(
    num_edges: int = 10, num_nodes: int = 10, num_features: int = 5
) -> tfgnn.GraphTensor:
    # Random [source, target] pairs, each index drawn from [0, num_nodes).
    adjacency = tf.random.stateless_uniform(
        shape=[num_edges, 2], seed=[0, 0], minval=0, maxval=num_nodes, dtype=tf.int32
    )
    edges = tfgnn.EdgeSet.from_fields(
        # One scalar weight per edge, shaped [num_edges, 1] so that it
        # broadcasts over the node feature dimension.
        features={"edge_weight": tf.random.normal(shape=[num_edges, 1])},
        sizes=[num_edges],
        adjacency=tfgnn.HyperAdjacency.from_indices(
            indices={
                tfgnn.SOURCE: ("node", adjacency[:, 0]),
                tfgnn.TARGET: ("node", adjacency[:, 1]),
            }
        ),
    )
    nodes = tfgnn.NodeSet.from_fields(
        features={"hidden_state": tf.random.normal(shape=[num_nodes, num_features])},
        sizes=[num_nodes],
    )
    return tfgnn.GraphTensor.from_pieces(
        edge_sets={"edge": edges}, node_sets={"node": nodes}
    )


class MultiplyNodeEdge(tf.keras.layers.Layer):
    """Edge state = edge feature * state of the incident node (SOURCE by default)."""

    def __init__(self, edge_feature: str, node_feature: str = tfgnn.SOURCE) -> None:
        super().__init__()
        self.node_feature = node_feature
        self.edge_feature = edge_feature

    def call(self, inputs: Tuple[FieldsNest, FieldsNest, FieldsNest]) -> tf.Tensor:
        # inputs is (edge features, incident node features, context features).
        edge_inputs, node_inputs, _ = inputs
        return tf.multiply(
            edge_inputs[self.edge_feature], node_inputs[self.node_feature]
        )


graph_tensor = build_graph_tensor()
spec = graph_tensor.spec

inputs = tf.keras.layers.Input(type_spec=spec)
update = tfgnn.keras.layers.GraphUpdate(
    edge_sets={
        "edge": tfgnn.keras.layers.EdgeSetUpdate(
            next_state=MultiplyNodeEdge(edge_feature="edge_weight"),
            edge_input_feature=["edge_weight"],
        )
    },
    node_sets={
        "node": tfgnn.keras.layers.NodeSetUpdate(
            edge_set_inputs={"edge": tfgnn.keras.layers.Pool(tfgnn.TARGET, "sum")},
            next_state=tfgnn.keras.layers.NextStateFromConcat(
                tf.keras.layers.Dense(16)
            ),
        )
    },
)
graph = update(inputs)
hidden = tfgnn.keras.layers.Readout(node_set_name="node")(graph)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs, output)

y = model(graph_tensor)
print(y)
```
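To make explicit what the `GraphUpdate` above computes, here is my reading of it in plain NumPy: each edge's message is its `edge_weight` times the source node's state, messages are summed at target nodes (`Pool(tfgnn.TARGET, "sum")`), and the new node state is a dense layer over the old state concatenated with the pooled messages (`NextStateFromConcat`). The `message_pass` function and the explicit `kernel`/`bias` standing in for `Dense(16)` (no activation, as with `Dense`'s default) are hypothetical; this sketches the math, not how tfgnn implements it, and it assumes `edge_weight` has shape `[num_edges, 1]`.

```python
import numpy as np


def message_pass(node_states, edge_weight, source, target, kernel, bias):
    """One weighted-sum message-passing round followed by a linear update.

    Hypothetical helper. node_states: [num_nodes, d]; edge_weight:
    [num_edges, 1]; source, target: [num_edges] node indices;
    kernel: [2 * d, units]; bias: [units].
    """
    # Edge update: scale each edge's source-node state by its weight.
    messages = edge_weight * node_states[source]
    # Pooling: sum the messages arriving at each target node
    # (np.add.at accumulates correctly when a target repeats).
    pooled = np.zeros((len(node_states), messages.shape[1]))
    np.add.at(pooled, target, messages)
    # Next state: linear layer over [old state, pooled messages].
    return np.concatenate([node_states, pooled], axis=1) @ kernel + bias


node_states = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_weight = np.array([[2.0], [3.0], [1.0]])
source, target = np.array([0, 1, 0]), np.array([1, 2, 2])
# A kernel that ignores the old state and just returns the pooled messages.
kernel = np.vstack([np.zeros((2, 2)), np.eye(2)])
# Pooled messages: node 0 -> [0, 0], node 1 -> [2, 0], node 2 -> [1, 3].
print(message_pass(node_states, edge_weight, source, target, kernel, np.zeros(2)))
```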
I still haven't had any luck with real-looking data, but I hope the above helps you and others.
If you have the time, we'd really appreciate some support from @sibonli, @arnoegw or @janpfeifer, in particular on the `KeyError` described above and on using `ConvGNNBuilder` in general. Thanks!
We had to adapt the type spec (as seen in the screenshot). Otherwise, `fit` throws an error because the model's type spec does not match that of the given training sample.
tensorflow-gnn 0.2.0 comes with several Colab notebook examples for end-to-end models as well as extended documentation, which should answer many of the questions raised above.
The MUTAG colab uses one example dataset from TUDatasets.
I realize this does not fully address the feature request to integrate all of them, but in the interest of keeping issues manageable, please allow me to close this one, because it has come to cover so many things. Better support for datasets is on the roadmap...
Thanks @joshcarty for your example! The next step is working on some real data.
Perhaps there is a quick fix for the `KeyError` by using `ConvGNNBuilder` differently; we will see...
@joshcarty
Thanks for your quick answer. The issue was indeed the wrong version. So there is no need to discuss further on this topic.
Such great work from you with the other repo!
Hi joshcarty,
we are facing the exact same problem with the KeyError. We tried different examples, but all are failing.
@joshcarty
Wow, that's exactly what we are looking for. We tried to train different models, but failed to fit them because we couldn't get the data into the right format. I think your example will help to fix our problems.
Is that using joshcarty/tfgnn-ogb and the latest version of `tensorflow_gnn`? I've just tested with a fresh virtualenv and it works for me.
You'll want to install `tfgnn-ogb/requirements.txt` rather than having a clone of `tensorflow_gnn`. Along with other dependencies, that does a pip install of the current version of the package, e.g.
```shell
$ pip install git+https://github.com/tensorflow/gnn@f63478419152cc2f22886173b45302323f4bfc56
```
which works since the builds were fixed here.
Upgrading to the latest version of `tensorflow_gnn` did require a change: I had to add an empty set of context features to prevent a `KeyError` in `GraphTensor.merge_batch_to_components`. The empty context is added to both the data and the type spec, so the two should match.
Maybe best to discuss in a new issue in that repository?
```
File "/home/sidney/py_proj/tf-gnn/tensorflow_gnn/graph/schema_utils.py", line 9, in
    import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2
ModuleNotFoundError: No module named 'tensorflow_gnn.proto.graph_schema_pb2'
```
This error occurs when I run most of the test cases. Do you not get it when you run them?
`schema_utils.py` contains `import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2`, but there is no `graph_schema_pb2` file in this repo. How do you solve this problem, then?
Have you installed the package with the command in the Installation Instructions?
```shell
cd tensorflow_gnn && python3 -m pip install .
```
Even though it looks like you've cloned the package locally, you still need to install it with pip. Part of the installation process builds a number of protocol buffers into modules that you can access from Python. Once you've installed it, those modules should be available and you won't get that error anymore.
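If you want to check whether the generated protocol-buffer modules are actually importable after installing, a small standard-library helper like this can tell you without digging through tracebacks. `module_available` is a hypothetical name, not part of tensorflow_gnn.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Returns True if the (possibly dotted) module `name` can be found.

    Hypothetical helper. For a dotted name, find_spec imports the parent
    packages first, so a missing or broken parent raises; treat that as
    "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False
```

After a successful `pip install`, I'd expect `module_available("tensorflow_gnn.proto.graph_schema_pb2")` to return `True`. Note that checking a dotted name imports its parent packages (and hence TensorFlow) first, so it can take a moment.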
For anyone interested in why features need to be called `hidden_state` when passed into convolution layers: I suspect it's because you're expected to preprocess them with a `MapFeatures` layer. That layer lets you specify a function to preprocess features on nodes and edges, and it can return them as a single `DEFAULT_STATE_NAME` (`hidden_state`) feature.
Hope that's useful to you @thilograffe and others.
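As an illustration of the kind of preprocessing meant here, this plain-NumPy sketch folds several raw features into one `hidden_state` array while passing others (such as labels) through untouched. It is not the `MapFeatures` API itself, which is a Keras layer that takes your own per-node-set and per-edge-set functions; `to_hidden_state` and its `keep` parameter are hypothetical, and concatenating raw values is only a stand-in for real feature encoding.

```python
import numpy as np

HIDDEN_STATE = "hidden_state"  # the default state feature name discussed above


def to_hidden_state(features, keep=()):
    """Concatenates every feature not named in `keep` into HIDDEN_STATE.

    Hypothetical helper. Each feature must have a leading [num_items]
    dimension; trailing dimensions are flattened so features of different
    shapes can be joined. Features are taken in sorted name order so the
    column layout is stable.
    """
    columns = []
    for name in sorted(features):
        if name in keep:
            continue
        value = np.asarray(features[name])
        columns.append(value.reshape(value.shape[0], -1))
    result = {name: features[name] for name in keep if name in features}
    result[HIDDEN_STATE] = np.concatenate(columns, axis=1).astype(np.float32)
    return result


node_features = {
    "year": np.array([2019, 2020]),                   # shape [2]
    "embedding": np.array([[0.1, 0.2], [0.3, 0.4]]),  # shape [2, 2]
    "label": np.array([3, 5]),
}
out = to_hidden_state(node_features, keep=("label",))
print(out[HIDDEN_STATE].shape)  # (2, 3): embedding columns, then year
```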