TensorFlow or PyTorch? Both!
GraphGallery is a gallery for benchmarking Graph Neural Networks (GNNs) and Graph Adversarial Learning with TensorFlow 2.x and PyTorch backends. PyTorch Geometric (PyG) and Deep Graph Library (DGL) backends are now also available in GraphGallery.
Adversarial attacks are also integrated into this project; for examples, please refer to Graph Adversarial Learning examples.
pip install -U graphgallery
or
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
pip install -e .
GraphGallery has been tested on:
- CPU, CUDA 10.1, CUDA 11.0
- TensorFlow 2.1~2.4 (2.1.2 is recommended)
- PyTorch 1.4~1.7
- PyTorch Geometric (PyG) 1.6.1
- DGL 0.5.2, 0.5.3
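To check a local environment against the versions listed above, a small stand-alone sketch like the following can be used. It is not part of GraphGallery: the helper names (`major_minor`, `within_tested`, `report`) are hypothetical, it compares only major.minor versions, and it assumes the PyPI distribution names `tensorflow`, `torch`, `torch-geometric` and `dgl`. CUDA builds of DGL may be published under another name (for example `dgl-cu101`), in which case the sketch reports DGL as not installed.

```python
import re
from importlib import metadata  # Python 3.8+


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.1.2' or '1.6.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def within_tested(version: str, low: str, high: str) -> bool:
    """True if the major.minor of `version` lies in the inclusive range [low, high]."""
    return major_minor(low) <= major_minor(version) <= major_minor(high)


# Tested ranges from the list above, compared at major.minor level only
# (assumption: the patch-level PyG and DGL pins are approximated by their minor version).
TESTED = {
    "tensorflow": ("2.1", "2.4"),
    "torch": ("1.4", "1.7"),
    "torch-geometric": ("1.6", "1.6"),
    "dgl": ("0.5", "0.5"),
}


def report() -> None:
    """Print, for each package, whether the installed version is in the tested range."""
    for package, (low, high) in TESTED.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        status = "tested" if within_tested(installed, low, high) else "untested"
        print(f"{package} {installed}: {status}")


if __name__ == "__main__":
    report()
```

An "untested" result does not mean a version is incompatible, only that it falls outside the ranges listed above.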
Please refer to the examples directory (the examples are still being updated).
For more details, please refer to GraphData.
from graphgallery.gallery import GCN
# `graph` (a graphgallery Graph) and `splits` (train/val/test node indices)
# are assumed to have been loaded beforehand (see GraphData)
# initialize a GNN trainer
trainer = GCN(graph)
# process the inputs, e.g., convert them to tensors
trainer.process()
# build the GCN model with default hyper-parameters
trainer.build()
# train the model; splits.train_nodes and splits.val_nodes are NumPy arrays
# verbose takes 0, 1, 2, 3, 4
history = trainer.train(splits.train_nodes, splits.val_nodes, verbose=1, epochs=100)
# evaluate the model on the test nodes
# verbose takes 0, 1, 2
results = trainer.test(splits.test_nodes, verbose=1)
print(f'Test loss {results.loss:.5}, Test accuracy {results.accuracy:.2%}')
Other models in the gallery are used in the same way.
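The `train` and `test` calls take arrays of node indices. If a dataset does not come with a split, one common approach is a random permutation of the node indices; the sketch below does this with NumPy. The function `random_node_split` and its default ratios are hypothetical and are not GraphGallery's own splitting API; the returned object only mirrors the `splits.train_nodes` and `splits.val_nodes` attribute names used above and adds a `test_nodes` array.

```python
from types import SimpleNamespace

import numpy as np


def random_node_split(num_nodes, train_size=0.1, val_size=0.1, seed=None):
    """Randomly split node indices 0..num_nodes-1 into disjoint train/val/test arrays.

    Hypothetical helper: sizes are fractions of num_nodes, the rest becomes the test set.
    """
    if train_size + val_size > 1:
        raise ValueError("train_size + val_size must not exceed 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)  # shuffled node indices
    num_train = int(round(train_size * num_nodes))
    num_val = int(round(val_size * num_nodes))
    return SimpleNamespace(
        train_nodes=np.sort(perm[:num_train]),
        val_nodes=np.sort(perm[num_train:num_train + num_val]),
        test_nodes=np.sort(perm[num_train + num_val:]),
    )
```

For example, `random_node_split(2708, train_size=0.05, val_size=0.2, seed=42)` returns disjoint arrays of 135, 542 and 2031 node indices. Fixing `seed` makes the split reproducible across runs.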
>>> import graphgallery
>>> graphgallery.backend()
TensorFlow 2.1.2 Backend
>>> graphgallery.set_backend("pytorch")
PyTorch 1.6.0+cu101 Backend
# DGL PyTorch backend
>>> graphgallery.set_backend("dgl")
# DGL TensorFlow backend
>>> graphgallery.set_backend("dgl-tf")
Your code does not need to change when you switch backends.
This dataset format is motivated by gnn-benchmark.
from graphgallery.data import Graph
# Load the adjacency matrix A, attribute matrix X and labels vector y
# A - scipy.sparse.csr_matrix of shape [num_nodes, num_nodes]
# X - scipy.sparse.csr_matrix or np.ndarray of shape [num_nodes, num_attrs]
# y - np.ndarray of shape [num_nodes]
mydataset = Graph(adj_matrix=A, node_attr=X, node_label=y)
# save dataset
mydataset.to_npz('path/to/mydataset.npz')
# load dataset
mydataset = Graph.from_npz('path/to/mydataset.npz')
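If the graph starts out as an edge list, the `A`, `X` and `y` inputs described above can be assembled with NumPy and SciPy roughly as follows. This is a sketch under stated assumptions: `build_inputs` is a hypothetical helper, edges are treated as undirected (each pair is stored in both directions), and any edge weights are dropped in favour of a binary adjacency matrix.

```python
import numpy as np
import scipy.sparse as sp


def build_inputs(edges, num_nodes, node_attr, node_label):
    """Return (A, X, y) in the shapes expected by Graph(adj_matrix=..., ...).

    Hypothetical helper; assumes `edges` is a sequence of (src, dst) pairs
    describing an undirected, unweighted graph.
    """
    edges = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
    # store every edge in both directions so that A is symmetric
    rows = np.concatenate([edges[:, 0], edges[:, 1]])
    cols = np.concatenate([edges[:, 1], edges[:, 0]])
    data = np.ones(len(rows), dtype=np.float32)
    # duplicate (row, col) pairs are summed during CSR conversion ...
    A = sp.csr_matrix((data, (rows, cols)), shape=(num_nodes, num_nodes))
    # ... so reset every stored entry to 1 to keep the matrix binary
    A.data[:] = 1.0

    X = np.asarray(node_attr, dtype=np.float32)  # [num_nodes, num_attrs]
    y = np.asarray(node_label, dtype=np.int64)   # [num_nodes]
    if X.shape[0] != num_nodes or y.shape != (num_nodes,):
        raise ValueError("node_attr and node_label must have num_nodes rows")
    return A, X, y
```

The returned triple can then be passed to `Graph(adj_matrix=A, node_attr=X, node_label=y)` as in the example above.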
- Add support for PyTorch trainers
- Add support for other frameworks (PyG and DGL)
- Add more GNN trainers (TensorFlow and PyTorch backends)
- Support more tasks, e.g., graph classification and link prediction
- Support more types of graphs, e.g., heterogeneous graphs
- Add docstrings and documentation (in progress)
- Comprehensive tutorials
This project is inspired by PyTorch Geometric, TensorFlow Geometric, StellarGraph, DGL and others, as well as the authors' original implementations. Thanks for their excellent work!