

Graph Wavelet Neural Network


A PyTorch implementation of Graph Wavelet Neural Network (ICLR 2019).

Abstract

We present graph wavelet neural network (GWNN), a novel graph convolutional neural network (CNN), leveraging graph wavelet transform to address the shortcomings of previous spectral graph CNN methods that depend on graph Fourier transform. Different from graph Fourier transform, graph wavelet transform can be obtained via a fast algorithm without requiring matrix eigendecomposition with high computational cost. Moreover, graph wavelets are sparse and localized in vertex domain, offering high efficiency and good interpretability for graph convolution. The proposed GWNN significantly outperforms previous spectral graph CNNs in the task of graph-based semi-supervised classification on three benchmark datasets: Cora, Citeseer and Pubmed.

A reference TensorFlow implementation is accessible [here].

This repository provides an implementation of Graph Wavelet Neural Network as described in the paper:

Graph Wavelet Neural Network. Bingbing Xu, Huawei Shen, Qi Cao, Yunqi Qiu, Xueqi Cheng. ICLR, 2019. [Paper]


Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0
scikit-learn      0.20.0
PyGSP             0.5.1

Datasets

The code takes the **edge list** of the graph in a CSV file. Every row indicates an edge between two nodes, separated by a comma. The first row is a header. Nodes should be indexed starting from 0. A sample graph for `Cora` is included in the `input/` directory. In addition to the edge list, there is a JSON file with the sparse features and a CSV file with the target variable.
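A minimal sketch of reading such an edge list with the standard library into an undirected adjacency structure. The helper `read_edge_list` and the header names in the sample are hypothetical, and the sketch assumes the graph is undirected and both columns hold integer node ids. Nodes without any edge do not show up as keys.

```python
import csv
import io
from collections import defaultdict


def read_edge_list(handle):
    """Read a headered two-column CSV edge list into an undirected adjacency dict.

    Hypothetical helper for illustration; assumes 0-indexed integer node ids.
    """
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    adjacency = defaultdict(set)
    for row in reader:
        if not row:
            continue  # tolerate trailing blank lines
        source, target = int(row[0]), int(row[1])
        # Store both directions because the graph is treated as undirected.
        adjacency[source].add(target)
        adjacency[target].add(source)
    return dict(adjacency)


sample = io.StringIO("id_1,id_2\n0,1\n1,2\n0,2\n")
graph = read_edge_list(sample)
print(sorted(graph[0]))  # [1, 2]
```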

The **feature matrix** is sparse and binary, and it is stored as a JSON file. Nodes are the keys of the JSON, and for each node the ids of its nonzero feature columns are stored as the elements of a list. The feature matrix is structured as:

{ 0: [0, 1, 38, 1968, 2000, 52727],
  1: [10000, 20, 3],
  2: [],
  ...
  n: [2018, 10000]}
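One way to turn this JSON into coordinate (COO) index arrays is sketched below. The helper name is hypothetical; note that real JSON object keys are always strings, so node ids are converted back to integers. The sketch infers the number of feature columns from the largest column id, which is an assumption: with real data a known vocabulary size is safer.

```python
import json

import numpy as np


def read_sparse_features(text):
    """Parse the node -> feature-columns JSON into COO index arrays.

    Hypothetical helper for illustration. JSON object keys are strings,
    so node ids are converted back to integers here.
    """
    raw = json.loads(text)
    rows, cols = [], []
    for node, columns in raw.items():
        for column in columns:
            rows.append(int(node))
            cols.append(int(column))
    rows = np.array(rows, dtype=np.int64)
    cols = np.array(cols, dtype=np.int64)
    values = np.ones(len(rows), dtype=np.float32)  # binary: every stored entry is 1
    n_nodes = max(int(node) for node in raw) + 1   # node ids start at 0
    n_features = int(cols.max()) + 1 if cols.size else 0  # assumption: inferred width
    return rows, cols, values, (n_nodes, n_features)


sample = '{"0": [0, 1, 5], "1": [2], "2": []}'
rows, cols, values, shape = read_sparse_features(sample)

# Densify only for tiny examples; real feature matrices should stay sparse.
dense = np.zeros(shape, dtype=np.float32)
dense[rows, cols] = values
```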

The **target vector** is a CSV file with two columns and a header row: the first column contains the node identifiers and the second the targets. The file is sorted by node identifier, and the target column contains the class memberships indexed from zero.

NODE ID Target
0 3
1 1
2 0
3 1
... ...
n 3
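A minimal sketch of loading the target file into a label vector indexed by node id. The helper `read_targets` and the header names in the sample are hypothetical; the check it performs simply enforces the documented assumptions (sorted rows, ids from 0 without gaps).

```python
import csv
import io

import numpy as np


def read_targets(handle):
    """Read the headered (node id, target) CSV into a label vector.

    Hypothetical helper; raises if node ids are not exactly 0..n-1 in order.
    """
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    pairs = [(int(row[0]), int(row[1])) for row in reader if row]
    node_ids = [node for node, _ in pairs]
    if node_ids != list(range(len(node_ids))):
        raise ValueError("expected node ids 0..n-1 in sorted order")
    return np.array([target for _, target in pairs], dtype=np.int64)


sample = io.StringIO("id,target\n0,3\n1,1\n2,0\n3,1\n")
labels = read_targets(sample)
n_classes = int(labels.max()) + 1  # classes are indexed from zero
```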

Options

Training the model is handled by the src/main.py script which provides the following command line arguments.

Input and output options

  --edge-path        STR   Input graph path.   Default is `input/cora_edges.csv`.
  --features-path    STR   Features path.      Default is `input/cora_features.json`.
  --target-path      STR   Target path.        Default is `input/cora_target.csv`.
  --log-path         STR   Log path.           Default is `logs/cora_logs.json`.

Model options

  --epochs                INT       Number of Adam epochs.         Default is 200.
  --learning-rate         FLOAT     Adam learning rate.            Default is 0.01.
  --weight-decay          FLOAT     Weight decay.                  Default is 5*10**-4.
  --filters               INT       Number of filters.             Default is 16.
  --dropout               FLOAT     Dropout probability.           Default is 0.5.
  --test-size             FLOAT     Test set ratio.                Default is 0.2.
  --seed                  INT       Random seed.                   Default is 42.
  --approximation-order   INT       Chebyshev polynomial order.    Default is 3.
  --tolerance             FLOAT     Wavelet coefficient limit.     Default is 10**-4.
  --scale                 FLOAT     Heat kernel scale.             Default is 1.0.
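The last three options control how the wavelet bases are built. The sketch below is an illustration of the technique, not the repository's code: it assumes a heat kernel g(x) = exp(-scale · x / lmax) (the form quoted in the issues further down) and a combinatorial Laplacian L = D − A, approximates ψ = g(L) with a Chebyshev expansion of order `--approximation-order` so that no eigendecomposition is needed, and then zeroes entries whose magnitude is below `--tolerance`. The helper names are hypothetical, and λmax is replaced by the upper bound 2 · (max degree), used both in the kernel and in the rescaling.

```python
import numpy as np


def chebyshev_coefficients(kernel, order, lmax, n_points=100):
    """Chebyshev coefficients of kernel on [0, lmax] via Chebyshev-Gauss quadrature."""
    theta = np.pi * (np.arange(n_points) + 0.5) / n_points
    # Map the quadrature nodes cos(theta) from [-1, 1] onto [0, lmax].
    samples = kernel((np.cos(theta) + 1.0) * lmax / 2.0)
    return np.array([2.0 / n_points * np.sum(samples * np.cos(j * theta))
                     for j in range(order + 1)])


def chebyshev_wavelets(laplacian, kernel, order, lmax, tolerance=0.0):
    """Approximate kernel(laplacian) by a Chebyshev polynomial (order >= 1), then sparsify."""
    n = laplacian.shape[0]
    coeffs = chebyshev_coefficients(kernel, order, lmax)
    identity = np.eye(n)
    l_tilde = (2.0 / lmax) * laplacian - identity  # spectrum rescaled to [-1, 1]
    t_prev, t_curr = identity, l_tilde             # T_0(L~) and T_1(L~)
    result = 0.5 * coeffs[0] * identity + coeffs[1] * l_tilde
    for j in range(2, order + 1):
        # Three-term recurrence: T_j = 2 L~ T_{j-1} - T_{j-2}
        t_prev, t_curr = t_curr, 2.0 * l_tilde @ t_curr - t_prev
        result += coeffs[j] * t_curr
    result[np.abs(result) < tolerance] = 0.0       # drop tiny wavelet coefficients
    return result


# Path graph on 4 nodes; combinatorial Laplacian L = D - A (assumption).
adjacency = np.array([[0, 1, 0, 0],
                      [1, 0, 1, 0],
                      [0, 1, 0, 1],
                      [0, 0, 1, 0]], dtype=float)
laplacian = np.diag(adjacency.sum(axis=1)) - adjacency
lmax = 2.0 * adjacency.sum(axis=1).max()  # upper bound on the largest eigenvalue
scale = 1.0
phi = chebyshev_wavelets(laplacian, lambda x: np.exp(-scale * x / lmax),
                         order=3, lmax=lmax, tolerance=1e-4)
phi_inverse = chebyshev_wavelets(laplacian, lambda x: np.exp(scale * x / lmax),
                                 order=3, lmax=lmax, tolerance=1e-4)
```

Only matrix products with the Laplacian are needed, which is what makes the approach attractive for sparse graphs; a higher order trades computation for accuracy.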

Examples

The following commands learn the weights of a graph wavelet neural network and save the logs. The first example trains a graph wavelet neural network on the default dataset with the standard hyperparameter settings and saves the logs at the default path.

python src/main.py

Training a model with more filters in the first layer.

python src/main.py --filters 32

Approximating the wavelets with polynomials of order 5.

python src/main.py --approximation-order 5
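Each command overrides only the flags it names; everything else falls back to the defaults listed under Options. A hypothetical `argparse` parser mirroring that table makes the mapping explicit (the real `src/main.py` parser may be organized differently):

```python
import argparse


def parameter_parser(argv=None):
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Run a graph wavelet neural network.")
    # Input and output options
    parser.add_argument("--edge-path", type=str, default="input/cora_edges.csv")
    parser.add_argument("--features-path", type=str, default="input/cora_features.json")
    parser.add_argument("--target-path", type=str, default="input/cora_target.csv")
    parser.add_argument("--log-path", type=str, default="logs/cora_logs.json")
    # Model options
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--learning-rate", type=float, default=0.01)
    parser.add_argument("--weight-decay", type=float, default=5 * 10**-4)
    parser.add_argument("--filters", type=int, default=16)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--test-size", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--approximation-order", type=int, default=3)
    parser.add_argument("--tolerance", type=float, default=10**-4)
    parser.add_argument("--scale", type=float, default=1.0)
    return parser.parse_args(argv)


# argparse turns dashes into underscores: --approximation-order -> approximation_order
args = parameter_parser(["--approximation-order", "5"])
```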

License



graphwaveletneuralnetwork's Issues

Laplacian matrix

Hi! Why do we need to use all the node data (2707 nodes) when constructing the Laplacian matrix instead of just the training set data?
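In the semi-supervised node classification setting used for Cora, Citeseer and Pubmed, the whole graph (including unlabelled and test nodes) is available during training; only the labels of non-training nodes are hidden. A minimal NumPy sketch of that idea, with a random matrix standing in for the wavelet operator (an assumption for illustration; this is not the repository's training loop) and the loss restricted to the training indices:

```python
import numpy as np


def masked_cross_entropy(logits, labels, train_index):
    """Mean cross-entropy computed over the training nodes only."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[train_index, labels[train_index]].mean()


rng = np.random.default_rng(0)
n_nodes, n_features, n_classes = 5, 3, 2
wavelets = rng.random((n_nodes, n_nodes))  # stand-in for a full-graph operator
features = rng.random((n_nodes, n_features))
weights = rng.random((n_features, n_classes))
labels = np.array([0, 1, 0, 1, 1])
train_index = np.array([0, 1])             # only these labels are used

# Every node, labelled or not, contributes to the propagated representation.
logits = wavelets @ features @ weights
loss = masked_cross_entropy(logits, labels, train_index)
```

Test labels never enter the loss, but test-node features do, through the graph operator.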

Question about phi^(-1)

Hi, I noticed that you implement the inverse phi by adding -1 to the list of scales in your code. However, I found that the result of g.heat(tau=[-1]) is not equal to g.inverse(), because of how g.Heat() is defined: the kernel is np.minimum(np.exp(-scale * x / G.lmax), 1). With scale = -1 and x ≥ 0, np.exp(x / G.lmax) ≥ 1, so the kernel always returns 1. As a result I may get a flat line (a constant filter) instead of the actual inverse of filter g.
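The effect described above can be reproduced with plain NumPy, using the kernel expression quoted in the issue (this does not call PyGSP, and the eigenvalues are illustrative):

```python
import numpy as np


def clamped_heat_kernel(x, scale, lmax):
    # Kernel expression quoted in the issue: np.minimum(np.exp(-scale * x / lmax), 1)
    return np.minimum(np.exp(-scale * x / lmax), 1)


def unclamped_heat_kernel(x, scale, lmax):
    # The same exponential without the upper clamp
    return np.exp(-scale * x / lmax)


lmax = 2.0
eigenvalues = np.linspace(0.0, lmax, 5)
forward = clamped_heat_kernel(eigenvalues, 1.0, lmax)
clamped_inverse = clamped_heat_kernel(eigenvalues, -1.0, lmax)  # every entry is 1
true_inverse = unclamped_heat_kernel(eigenvalues, -1.0, lmax)   # equals 1 / forward
```

With scale = -1 the clamp caps every response at 1, so the product with the forward kernel is not 1 except at x = 0; dropping the clamp recovers the reciprocal.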

something about wavelet basis

Hello, thank you for your paper. While reading it, I wondered what the connection is between the wavelet basis and the Fourier basis. Could you give me some tips?
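One way to see the connection: the graph Fourier basis is the eigenvector matrix U of the Laplacian, and a heat-kernel wavelet basis can be written as ψ_s = U diag(e^{-sλ_i}) U^T, i.e. the same eigenvectors with each frequency reweighted by the kernel; the inverse swaps -s for s. The sign convention e^{-sλ} is an assumption chosen to match the kernel quoted in the issue above. This dense, eigendecomposition-based form is only for illustration; the abstract notes that the wavelets can be obtained by a fast algorithm that avoids the eigendecomposition.

```python
import numpy as np

# Small illustrative graph; combinatorial Laplacian L = D - A (assumption).
adjacency = np.array([[0, 1, 1, 0],
                      [1, 0, 1, 0],
                      [1, 1, 0, 1],
                      [0, 0, 1, 0]], dtype=float)
laplacian = np.diag(adjacency.sum(axis=1)) - adjacency

# Fourier basis: orthonormal eigenvectors U; eigenvalues act as frequencies.
eigenvalues, fourier_basis = np.linalg.eigh(laplacian)

s = 1.0  # wavelet scale
# Wavelet basis: the Fourier basis with each frequency reweighted by e^{-s * lambda}.
wavelet_basis = fourier_basis @ np.diag(np.exp(-s * eigenvalues)) @ fourier_basis.T
wavelet_inverse = fourier_basis @ np.diag(np.exp(s * eigenvalues)) @ fourier_basis.T
```

Column i of `wavelet_basis` is a wavelet centred at node i, which decays with graph distance from i instead of spreading over the whole graph like a Fourier eigenvector.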

RuntimeError: the derivative for 'index' is not implemented

Hello,
I was running the example and got this error.

python src/main.py
+---------------------+----------------------------+
|      Parameter      |           Value            |
+=====================+============================+
| Approximation order | 20                         |
+---------------------+----------------------------+
| Dropout             | 0.500                      |
+---------------------+----------------------------+
| Edge path           | ./input/cora_edges.csv     |
+---------------------+----------------------------+
| Epochs              | 300                        |
+---------------------+----------------------------+
| Features path       | ./input/cora_features.json |
+---------------------+----------------------------+
| Filters             | 16                         |
+---------------------+----------------------------+
| Learning rate       | 0.001                      |
+---------------------+----------------------------+
| Log path            | ./logs/cora_logs.json      |
+---------------------+----------------------------+
| Scale               | 1                          |
+---------------------+----------------------------+
| Seed                | 42                         |
+---------------------+----------------------------+
| Target path         | ./input/cora_target.csv    |
+---------------------+----------------------------+
| Test size           | 0.200                      |
+---------------------+----------------------------+
| Tolerance           | 0.000                      |
+---------------------+----------------------------+
| Weight decay        | 0.001                      |
+---------------------+----------------------------+

Wavelet calculation and sparsification started.

100%|███████████████████████████████████████████████████████████████████████████████████| 2708/2708 [00:11<00:00, 237.23it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 2708/2708 [00:11<00:00, 228.91it/s]

Normalizing the sparsified wavelets.

Density of wavelets: 0.2%.
Density of inverse wavelets: 0.04%.

Training.

Loss:   0%|                                                                                          | 0/300 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "src/main.py", line 24, in <module>
    main()
  File "src/main.py", line 18, in main
    trainer.fit()
  File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn.py", line 131, in fit
    prediction = self.model(self.phi_indices, self.phi_values , self.phi_inverse_indices, self.phi_inverse_values, self.feature_indices, self.feature_values)
  File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn.py", line 44, in forward
    deep_features_1 = self.convolution_1(phi_indices, phi_values, phi_inverse_indices, phi_inverse_values, feature_indices, feature_values, self.args.dropout)
  File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn_layer.py", line 55, in forward
    localized_features = spmm(phi_product_indices, phi_product_values, self.ncount, filtered_features)
  File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch_sparse/spmm.py", line 21, in spmm
    out = scatter_add(out, row, dim=0, dim_size=m)
  File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch_scatter/add.py", line 73, in scatter_add
    return out.scatter_add_(dim, index, src)
RuntimeError: the derivative for 'index' is not implemented

Fatal Python error: Segmentation fault

Hi, author.
Recently I have been studying the program, but when I run this code an error occurs. Can you give me some suggestions?

train test split question

I found that the train_test_split function in gwnn.py puts 80% of the nodes in the Cora dataset into the training set. Is this standard practice? Or have I misunderstood, or is your code wrong?
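The 80% figure follows directly from `--test-size`: with the default 0.2, a fifth of the nodes is held out and the rest is used for training. For contrast, the sketch below also includes a per-class split in the style of the common Planetoid benchmark (20 labelled nodes per class, i.e. 140 training nodes for Cora's 7 classes and 2708 nodes). Both functions are hypothetical illustrations, not the repository's code.

```python
import numpy as np


def ratio_split(n_nodes, test_size, seed=42):
    """Random split holding out a test_size fraction of the nodes."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_nodes)
    n_test = int(np.ceil(test_size * n_nodes))  # round the test share up
    return order[n_test:], order[:n_test]


def per_class_split(labels, per_class, seed=42):
    """Random split with per_class training nodes drawn from each class."""
    rng = np.random.default_rng(seed)
    train = []
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)
        train.extend(rng.choice(members, size=per_class, replace=False))
    train = np.array(sorted(train))
    test = np.setdiff1d(np.arange(len(labels)), train)
    return train, test


train, test = ratio_split(2708, 0.2)  # 2166 training nodes, 542 held out
```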

What is the meaning of the "feature matrix"?

Hello author, sorry for a stupid question. The Cora dataset has Cora.cites, which corresponds to your cora_edges.csv, and Cora.content, whose paper index and paper category correspond to your cora_target.csv, so I don't understand the meaning of your cora_features.json. At first I thought it was an adjacency matrix of all nodes (paper indices), but the contents are inconsistent. For example, cora_edges.csv looks like this:
[image]
and cora_features.json looks like this:
[image]
So I am confused and hope for your answer. Thank you very much.
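Per the Datasets section, the JSON is not an adjacency list: for each node it lists the column ids where that node's sparse binary feature row is nonzero. Assuming the rows are binary word-presence vectors like those in Cora.content (an assumption; the README does not say how the JSON was produced), the conversion from a dense 0/1 row to that format looks like this. The helper name and the toy matrix are hypothetical.

```python
import json

import numpy as np


def binary_rows_to_index_lists(binary_matrix):
    """Map each node's 0/1 feature row to the list of columns that are set."""
    return {node: np.flatnonzero(row).tolist() for node, row in enumerate(binary_matrix)}


# Three hypothetical nodes with a 6-word vocabulary (1 = word present).
binary_matrix = np.array([[1, 1, 0, 0, 0, 1],
                          [0, 0, 1, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0]])
index_lists = binary_rows_to_index_lists(binary_matrix)
print(json.dumps(index_lists))  # {"0": [0, 1, 5], "1": [2], "2": []}
```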

the kernel

Hi, author,
There is a variable in the code called diagnoal_weight_filter.
[Screenshot 2021-01-16 204442]
I think this variable should change during training, but it never changed when I was debugging, which is very confusing.
I also wonder whether the diagonal weight filter that plays the same role in the TensorFlow implementation changes during training.
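In a graph wavelet layer of the form ψ · diag(θ) · ψ⁻¹ · X · W, the diagonal θ is a trainable filter, and its gradient is generally non-zero. Assuming the variable above plays the role of θ (the layer code is not shown in this thread), the NumPy sketch below checks that with a central finite-difference gradient. Random positive matrices stand in for the wavelet bases, and the initialisation of θ to ones is hypothetical.

```python
import numpy as np


def wavelet_layer(phi, phi_inverse, features, weight, filter_diagonal):
    """Sketch of phi @ diag(theta) @ phi_inverse @ (features @ weight)."""
    transformed = features @ weight                 # feature transformation
    spectral = phi_inverse @ transformed            # into the wavelet domain
    filtered = filter_diagonal[:, None] * spectral  # diagonal filter = row scaling
    return phi @ filtered                           # back to the vertex domain


def numerical_gradient(loss_fn, theta, eps=1e-6):
    """Central finite-difference gradient of a scalar loss w.r.t. theta."""
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (loss_fn(theta + step) - loss_fn(theta - step)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
n_nodes, n_features, n_filters = 4, 3, 2
phi = rng.random((n_nodes, n_nodes))          # stand-in wavelet basis
phi_inverse = rng.random((n_nodes, n_nodes))  # stand-in inverse basis
features = rng.random((n_nodes, n_features))
weight = rng.random((n_features, n_filters))
theta = np.ones(n_nodes)                      # hypothetical initialisation


def loss_fn(t):
    return np.sum(wavelet_layer(phi, phi_inverse, features, weight, t) ** 2)


grad = numerical_gradient(loss_fn, theta)
```

A non-zero gradient here means an optimizer step would move θ, provided the variable is actually among the parameters being optimized.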
