Deep Graph Library (DGL)

Documentation | DGL at a glance | Model Tutorials | Discussion Forum

DGL is a Python package that interfaces between existing tensor libraries and data expressed as graphs.

It makes implementing graph neural networks (including Graph Convolutional Networks, TreeLSTM, and many others) easy while maintaining high computational efficiency.

A summary of model accuracy and training speed with the PyTorch backend (on an Amazon EC2 p3.2xlarge instance with a V100 GPU), compared with the best open-source implementations:

Model                     Reported accuracy   DGL accuracy   Author's epoch time   DGL epoch time   Improvement
GCN                       81.5%               81.0%          0.0051s (TF)          0.0038s          1.34x
SGC                       81.0%               81.9%          n/a                   0.0008s          n/a
TreeLSTM                  51.0%               51.72%         14.02s (DyNet)        3.18s            4.3x
R-GCN (classification)    73.23%              73.53%         0.2853s (Theano)      0.0097s          29.4x
R-GCN (link prediction)   0.158               0.151          2.204s (TF)           0.453s           4.86x
JTNN                      96.44%              96.44%         1826s (PyTorch)       743s             2.5x
LGNN                      94%                 94%            n/a                   1.45s            n/a
DGMG                      84%                 90%            n/a                   238s             n/a

With the MXNet/Gluon backend, we scaled to a graph of 50M nodes and 150M edges on a p3.8xlarge instance, at 160s per epoch, using SSE (Stochastic Steady-state Embedding), a model similar to GCN.

We are currently in the beta stage; more features and improvements are coming.

News

We presented DGL at GTC 2019 in an instructor-led training session. Check out our slides and tutorial materials here!

v0.2 has just been released, with many new features, bug fixes, and performance improvements. See the release notes here.

System requirements

DGL should work on

  • Linux distributions no older than Ubuntu 16.04
  • macOS
  • Windows 10

DGL also requires Python 3.5 or later. Python 2 support is coming.

DGL currently works with PyTorch 0.4.1+ and the MXNet nightly build.
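As a quick pre-install check, the snippet below tests whether the running interpreter meets the Python 3.5+ requirement using only the standard library. It is a minimal sketch: the helper name `python_ok` is hypothetical and not part of DGL, and it does not check the PyTorch or MXNet versions.

```python
import sys

# Minimum interpreter version from DGL's stated requirement (Python 3.5+).
MIN_PYTHON = (3, 5)


def python_ok(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the (major, minor) part of version_info meets minimum."""
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    major, minor = sys.version_info[:2]
    if python_ok():
        print("Python %d.%d satisfies the 3.5+ requirement" % (major, minor))
    else:
        print("Python %d.%d is too old; 3.5 or later is needed" % (major, minor))
```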

Installation

Using anaconda

conda install -c dglteam dgl

Using pip

pip install dgl

From source

Refer to the guide here.

What DGL looks like

A graph can be constructed with feature tensors like this:

import dgl
import torch as th

g = dgl.DGLGraph()
g.add_nodes(5)                          # add 5 nodes
g.add_edges([0, 0, 0, 0], [1, 2, 3, 4]) # add 4 edges 0->1, 0->2, 0->3, 0->4
g.ndata['h'] = th.randn(5, 3)           # assign one 3D vector to each node
g.edata['h'] = th.randn(4, 4)           # assign one 4D vector to each edge
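Conceptually, the graph above amounts to two parallel lists of edge endpoints plus one feature row per node and one per edge. The dependency-free sketch below mirrors that layout with plain Python lists; the names `src`, `dst`, `ndata_h`, `edata_h` and `in_degrees` are hypothetical, and DGL's actual internal storage is different.

```python
import random

# Same graph as above: 5 nodes and 4 edges 0->1, 0->2, 0->3, 0->4.
num_nodes = 5
src = [0, 0, 0, 0]   # source node of each edge
dst = [1, 2, 3, 4]   # destination node of each edge (edge i is src[i] -> dst[i])

# One random 3-D feature row per node and one random 4-D row per edge,
# analogous to g.ndata['h'] and g.edata['h'].
ndata_h = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(num_nodes)]
edata_h = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(len(src))]


def in_degrees(num_nodes, dst):
    """Count incoming edges per node from the destination list."""
    deg = [0] * num_nodes
    for v in dst:
        deg[v] += 1
    return deg


print(in_degrees(num_nodes, dst))  # [0, 1, 1, 1, 1]
```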

This is all it takes to implement a single Graph Convolutional Network (GCN) layer in PyTorch:

import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

# Built-in message function: copy each source node's 'h' into message 'm'.
msg_func = fn.copy_src(src='h', out='m')
# Built-in reduce function: sum the incoming 'm' messages into 'h'.
reduce_func = fn.sum(msg='m', out='h')

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    # Named node_update so it does not shadow nn.Module.apply.
    def node_update(self, nodes):
        return {'h': F.relu(self.linear(nodes.data['h']))}

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(msg_func, reduce_func)
        g.apply_nodes(func=self.node_update)
        return g.ndata.pop('h')
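To spell out what `update_all(msg_func, reduce_func)` computes with these two built-ins, here is a plain-Python sketch of the same message passing over an edge list: every edge u->v sends a copy of node u's feature, and every node sums what it receives. It is illustrative only; the function name `copy_src_sum` is hypothetical, and giving nodes with no incoming edges an all-zero result is an assumption of this sketch rather than documented DGL behavior.

```python
def copy_src_sum(num_nodes, src, dst, h):
    """Plain-Python sketch of update_all(copy_src('h', 'm'), sum('m', 'h')).

    src[i] -> dst[i] is edge i; h[n] is node n's feature vector (a list).
    Nodes without incoming edges get an all-zero vector (sketch assumption).
    """
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        message = h[u]                   # message: copy the source node's feature
        for j in range(dim):
            out[v][j] += message[j]      # reduce: sum messages at the destination
    return out


# Edges 0->1, 0->2, 1->2 with 2-dimensional node features.
result = copy_src_sum(3, [0, 0, 1], [1, 2, 2], [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]])
print(result)  # [[0.0, 0.0], [1.0, 2.0], [4.0, 6.0]]
```

Node 2 receives messages from nodes 0 and 1, so its result is [1, 2] + [3, 4] = [4, 6].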

You can also write custom message and reduce functions. The following code demonstrates a simplified version of a Graph Attention Network (GAT) layer:

import torch as th
import torch.nn as nn
import torch.nn.functional as F

def msg_func(edges):
    # Send the source node's key and value along each edge.
    return {'k': edges.src['k'], 'v': edges.src['v']}

def reduce_func(nodes):
    # nodes.data['q'] has the shape
    #     (number_of_nodes, feature_dims)
    # nodes.mailbox['k'] and nodes.mailbox['v'] have the shape
    #     (number_of_nodes, number_of_incoming_messages, feature_dims)
    # You only need to deal with the case where all nodes have the same number
    # of incoming messages.
    q = nodes.data['q'][:, None]
    k = nodes.mailbox['k']
    v = nodes.mailbox['v']
    # Attention weights: softmax of q . k over the incoming messages.
    s = F.softmax((q * k).sum(-1), 1)[:, :, None]
    # Attention-weighted sum of the incoming values.
    return {'v': th.sum(s * v, 1)}

class GATLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GATLayer, self).__init__()
        self.Q = nn.Linear(in_feats, out_feats)
        self.K = nn.Linear(in_feats, out_feats)
        self.V = nn.Linear(in_feats, out_feats)

    # Apply the nonlinearity to the aggregated values
    # (named node_update so it does not shadow nn.Module.apply).
    def node_update(self, nodes):
        return {'v': F.relu(nodes.data['v'])}

    def forward(self, g, feature):
        g.ndata['v'] = self.V(feature)
        g.ndata['q'] = self.Q(feature)
        g.ndata['k'] = self.K(feature)
        g.update_all(msg_func, reduce_func)
        g.apply_nodes(func=self.node_update)
        return g.ndata['v']
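The heart of the GAT layer is `reduce_func`: for each node, it takes the dot product of the node's query with every incoming key, normalizes the scores with a softmax, and returns the weighted sum of the incoming values. Below is a single-node, dependency-free sketch of that computation. The name `attention_reduce` is hypothetical, and subtracting the maximum score before exponentiating is a standard numerical-stability trick added here; it does not change the softmax result mathematically.

```python
import math


def attention_reduce(q, keys, values):
    """Single-node sketch of reduce_func: softmax(q . k_i)-weighted sum of v_i.

    q is the node's query; keys[i] and values[i] come from incoming message i.
    """
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)                                # stability shift; softmax is unchanged
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]              # attention weights, sum to 1
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


# Equal keys give uniform weights, so the result is the average of the values.
print(attention_reduce([1.0, 0.0], [[1.0, 1.0], [1.0, 1.0]], [[2.0, 0.0], [4.0, 2.0]]))  # [3.0, 1.0]
```

A key strongly aligned with the query takes almost all of the weight, which is what lets the layer focus on the most relevant neighbors.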

For the basics of coding with DGL, please see DGL basics.

For more realistic, end-to-end examples, please see model tutorials.

New to Deep Learning?

Check out the open source book Dive into Deep Learning.

Contributing

Please let us know if you encounter a bug or have any suggestions by filing an issue.

We welcome all contributions, from bug fixes to new features and extensions. We expect all contributions to be discussed in the issue tracker and submitted through pull requests (PRs). Please refer to our contribution guide.

The Team

DGL is developed and maintained by NYU, NYU Shanghai, AWS Shanghai AI Lab, and AWS MXNet Science Team.

License

DGL uses Apache License 2.0.
