Structure-Aware Transformer

Updates: We have added the script for model visualization (Figure 4 in our paper)!

This repository implements the Structure-Aware Transformer (SAT) in PyTorch Geometric, as described in the following paper:

Dexiong Chen*, Leslie O'Bray*, and Karsten Borgwardt. Structure-Aware Transformer for Graph Representation Learning. ICML 2022.
*Equal contribution

TL;DR: A class of simple and flexible graph transformers built upon a new self-attention mechanism, which incorporates structural information into the original self-attention by extracting a subgraph representation rooted at each node before computing the attention. Our structure-aware framework can leverage any existing GNN to extract the subgraph representation and systematically improves performance relative to the base GNN.

Citation

Please use the following to cite our work:

@InProceedings{Chen22a,
	author = {Dexiong Chen and Leslie O'Bray and Karsten Borgwardt},
	title = {Structure-Aware Transformer for Graph Representation Learning},
	year = {2022},
	booktitle = {Proceedings of the 39th International Conference on Machine Learning~(ICML)},
	series = {Proceedings of Machine Learning Research}
}

A short description of the SAT attention mechanism

Overview figure

We first extract the $k$-hop subgraphs centered at each node (here, $k=1$) and use a structure extractor to compute structure-aware node representations. The structure extractor can, for example, be any GNN. Then, the updated node embeddings are used to compute the query ($\mathbf{Q}$) and key ($\mathbf{K}$) matrices.
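The steps described above can be illustrated with a minimal, dependency-free sketch. It is not the implementation in this repository: `k_hop_nodes`, `toy_structure_extractor` and `structure_aware_attention` are hypothetical names, the structure extractor is a plain feature average over each $k$-hop subgraph standing in for a GNN, the query and key projections are identities, and taking the values from the original node features is an assumption of this sketch.

```python
import math
from collections import deque


def k_hop_nodes(adj, root, k):
    """Return the set of nodes within k hops of `root` (BFS on an adjacency list)."""
    dist = {root: 0}
    queue = deque([root])
    while queue:
        u = queue.popleft()
        if dist[u] == k:
            continue  # do not expand beyond k hops
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return set(dist)


def toy_structure_extractor(adj, x, k):
    """Stand-in for a GNN: average the features over each node's k-hop subgraph."""
    dim = len(x[0])
    out = []
    for root in range(len(x)):
        nodes = k_hop_nodes(adj, root, k)
        out.append([sum(x[u][d] for u in nodes) / len(nodes) for d in range(dim)])
    return out


def structure_aware_attention(adj, x, k=1):
    """Attention weights from structure-aware embeddings, values from raw features."""
    h = toy_structure_extractor(adj, x, k)  # structure-aware node representations
    n, dim = len(x), len(x[0])
    scale = math.sqrt(dim)
    out = []
    for i in range(n):
        # Queries and keys both come from h (identity projections for simplicity).
        scores = [sum(a * b for a, b in zip(h[i], h[j])) / scale for j in range(n)]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
        total = sum(exps)
        weights = [e / total for e in exps]
        # Values are the original features (an assumption of this sketch).
        out.append([sum(w * x[j][d] for j, w in enumerate(weights)) for d in range(dim)])
    return out


# Path graph 0 - 1 - 2 - 3 with 2-dimensional node features.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
print(sorted(k_hop_nodes(adj, 0, 1)))  # nodes within 1 hop of node 0
print(structure_aware_attention(adj, x, k=1))
```

In this toy version, two nodes attend strongly to each other when their $k$-hop neighborhoods have similar averaged features, which is the intuition behind replacing raw node features with subgraph representations in the query and key.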

A quick-start example

Below is a quick-start example on the ZINC dataset; see ./experiments/train_zinc.py for more details.

import torch
from torch_geometric import datasets
from torch_geometric.loader import DataLoader
from sat.data import GraphDataset
from sat import GraphTransformer

# Load the ZINC dataset using our wrapper GraphDataset,
# which automatically creates the fully connected graph.
# For datasets with large graphs, we recommend setting
# return_complete_index=False, which leads to faster computation.
dset = datasets.ZINC('./datasets/ZINC', subset=True, split='train')
dset = GraphDataset(dset)

# Create a PyG data loader
train_loader = DataLoader(dset, batch_size=16, shuffle=True)

# Create a SAT model
dim_hidden = 16
gnn_type = 'gcn' # use GCN as the structure extractor
k_hop = 2 # use a 2-layer GCN

model = GraphTransformer(
    in_size=28, # number of node labels for ZINC
    num_class=1, # regression task
    d_model=dim_hidden,
    dim_feedforward=2 * dim_hidden,
    num_layers=2,
    batch_norm=True,
    gnn_type=gnn_type, # structure extractor type defined above
    use_edge_attr=True,
    num_edge_features=4, # number of edge labels
    edge_dim=dim_hidden,
    k_hop=k_hop,
    se='gnn', # we use the k-subtree structure extractor
    global_pool='add'
)

for data in train_loader:
    output = model(data) # batch_size x 1
    break
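The comment in the example above mentions that GraphDataset creates the fully connected graph and that skipping it is faster for large graphs. The sketch below shows why, using a hypothetical helper `complete_edge_index` that is not part of this repository. Whether the wrapper includes self-loops or stores the index in this layout is an assumption.

```python
def complete_edge_index(num_nodes, self_loops=True):
    """Return [sources, targets] listing every ordered node pair of a graph."""
    sources, targets = [], []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if self_loops or i != j:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


for n in (3, 30, 300):
    pairs = len(complete_edge_index(n)[0])
    print(f"{n} nodes -> {pairs} attention pairs")
```

The number of pairs grows quadratically with the number of nodes, so precomputing and storing this index becomes costly on large graphs.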

Installation

The dependencies are managed by miniconda:

python=3.9
numpy
scipy
pytorch=1.9.1
pytorch-geometric=2.0.2
einops
ogb

Once you have activated the environment and installed all dependencies, run:

source s

Datasets will be downloaded via the PyTorch Geometric and OGB packages.

Train SAT on graph and node prediction datasets

All our experimental scripts are in the folder experiments. To get started, after running source s, run cd experiments. The hyperparameters used below are those we selected as optimal.

Graph regression on ZINC dataset

Train a k-subtree SAT with PNA:

python train_zinc.py --abs-pe rw --se gnn --gnn-type pna2 --dropout 0.3 --k-hop 3 --use-edge-attr

Train a k-subgraph SAT with PNA:

python train_zinc.py --abs-pe rw --se khopgnn --gnn-type pna2 --dropout 0.2 --k-hop 3 --use-edge-attr
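The two ZINC commands above differ only in the `--se` and `--dropout` values. If you prefer to launch them from Python, a small wrapper along these lines may help. It is a sketch that only uses the flags shown above; `ZINC_RUNS`, `build_zinc_command` and `run_all` are hypothetical names, and the commands are only printed unless `dry_run=False` is passed from inside the experiments folder.

```python
import shlex
import subprocess
import sys

# Flag values copied from the two ZINC commands above.
ZINC_RUNS = {
    "k-subtree": {"se": "gnn", "dropout": "0.3"},
    "k-subgraph": {"se": "khopgnn", "dropout": "0.2"},
}


def build_zinc_command(variant):
    """Build the argument list for one of the two ZINC runs."""
    cfg = ZINC_RUNS[variant]
    return [
        sys.executable, "train_zinc.py",
        "--abs-pe", "rw",
        "--se", cfg["se"],
        "--gnn-type", "pna2",
        "--dropout", cfg["dropout"],
        "--k-hop", "3",
        "--use-edge-attr",
    ]


def run_all(dry_run=True):
    """Print each command, and execute it unless dry_run is set."""
    for variant in ZINC_RUNS:
        cmd = build_zinc_command(variant)
        print(f"[{variant}]", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # must be launched from experiments/


run_all(dry_run=True)
```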

Node classification on PATTERN and CLUSTER datasets

Train a k-subtree SAT on PATTERN:

python train_SBMs.py --dataset PATTERN --weight-class --abs-pe rw --abs-pe-dim 7 --se gnn --gnn-type pna3 --dropout 0.2 --k-hop 3 --num-layers 6 --lr 0.0003

and on CLUSTER:

python train_SBMs.py --dataset CLUSTER --weight-class --abs-pe rw --abs-pe-dim 3 --se gnn --gnn-type pna2 --dropout 0.4 --k-hop 3 --num-layers 16 --dim-hidden 48 --lr 0.0005
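The `--abs-pe rw` and `--abs-pe-dim` options in the commands above select a random walk positional encoding and its dimension. As a rough illustration, the sketch below computes one common definition of this encoding: for each node, the probability of returning to itself after 1, 2, ..., `dim` steps of a random walk. The function name `random_walk_pe` is hypothetical, and whether the repository uses exactly this definition is an assumption.

```python
def random_walk_pe(adj_matrix, dim):
    """Return, for each node, the return probabilities after 1..dim random-walk steps."""
    n = len(adj_matrix)
    # Row-normalise the adjacency matrix: RW = D^{-1} A.
    rw = []
    for row in adj_matrix:
        degree = sum(row)
        rw.append([a / degree if degree else 0.0 for a in row])

    pe = [[] for _ in range(n)]
    power = [row[:] for row in rw]  # current power RW^t, starting at t = 1
    for _ in range(dim):
        for i in range(n):
            pe[i].append(power[i][i])  # probability of returning to i after t steps
        # Advance to the next power: RW^{t+1} = RW^t @ RW.
        power = [
            [sum(power[i][m] * rw[m][j] for m in range(n)) for j in range(n)]
            for i in range(n)
        ]
    return pe


# Path graph 0 - 1 - 2.
adjacency = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
print(random_walk_pe(adjacency, dim=2))
```

On this path graph the middle node always returns after two steps, while the end nodes return only half of the time, so the encoding separates nodes by their structural role.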

Graph classification on OGB datasets

--gnn-type can be gcn, gine or pna, where pna obtains the best performance.

# Train SAT on OGBG-PPA
python train_ppa.py --gnn-type gcn --use-edge-attr
# Train SAT on OGBG-CODE2
python train_code2.py --gnn-type gcn --use-edge-attr

Model visualization

Here we show how to visualize the attention weights of the [CLS] node learned by SAT and by a vanilla Transformer with the random walk positional encoding. We provide pre-trained models on the Mutagenicity dataset. To visualize the pre-trained models, install the networkx and matplotlib packages, then run:

python model_visu.py --graph-idx 2003

This will generate the following image, which matches Figure 4 in our paper:

Model_interpretation
