
pytorch_han's Introduction

pytorch_HAN

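A helpful GitHub user pointed out that the earlier problem was overfitting. After modifying the relevant parts of utils.py, the results are as follows:

[Image 1] [Image 2]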


Paper address:
https://github.com/Jhy1993/Representation-Learning-on-Heterogeneous-Graph
A PyTorch implementation of the Heterogeneous Graph Attention Network (HAN). If you want to reproduce the performance reported in the original paper, this may not be suitable for you, because one problem remains: the training loss decreases, but the validation loss increases.

If you just want to understand the basic principles of HAN and how to port TensorFlow code to PyTorch, then this is for you. The implementation follows the structure of the original TensorFlow code.

If you want higher performance, please refer to:
Official TensorFlow implementation: https://github.com/Jhy1993/HAN
DGL implementation: https://github.com/dmlc/dgl/tree/master/examples/pytorch/han

The result

ACM dataset: the preprocessed ACM data can be downloaded from https://pan.baidu.com/s/1V2iOikRqHPtVvaANdkzROw (extraction code: 50k2)

You can use the following command to run:

python main.py

Training result:

600 300 2125
y_train:(3025, 3), y_val:(3025, 3), y_test:(3025, 3), train_idx:(1, 600), val_idx:(1, 300), test_idx:(1, 2125)
2
model: pre_trained/acm/acm_allMP_multi_fea_.ckpt
fea_list[0].shape torch.Size([1, 1870, 3025])
biases_list[0].shape: torch.Size([1, 3025, 3025])
3
2
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
Number of training nodes: 600
Number of validation nodes: 300
Number of test nodes: 2125
epoch:001, loss:1.1004, TrainAcc:0.3517, ValLoss:1.1022, ValAcc:0.4000
epoch:002, loss:1.0762, TrainAcc:0.4250, ValLoss:1.1980, ValAcc:0.0533
epoch:003, loss:1.0007, TrainAcc:0.6300, ValLoss:1.4572, ValAcc:0.0533
epoch:004, loss:0.8876, TrainAcc:0.6583, ValLoss:2.0040, ValAcc:0.0500
epoch:005, loss:0.8145, TrainAcc:0.6350, ValLoss:2.7091, ValAcc:0.0500
epoch:006, loss:0.7897, TrainAcc:0.6267, ValLoss:3.2186, ValAcc:0.0500
epoch:007, loss:0.7804, TrainAcc:0.6150, ValLoss:3.4550, ValAcc:0.0500
epoch:008, loss:0.7527, TrainAcc:0.6150, ValLoss:3.5096, ValAcc:0.0500
epoch:009, loss:0.7404, TrainAcc:0.6117, ValLoss:3.5125, ValAcc:0.0600
epoch:010, loss:0.7329, TrainAcc:0.6633, ValLoss:3.5349, ValAcc:0.0400
epoch:011, loss:0.7169, TrainAcc:0.6983, ValLoss:3.5743, ValAcc:0.0133
epoch:012, loss:0.6934, TrainAcc:0.6917, ValLoss:3.6612, ValAcc:0.0033
epoch:013, loss:0.6711, TrainAcc:0.6750, ValLoss:3.7738, ValAcc:0.0033
epoch:014, loss:0.6645, TrainAcc:0.6733, ValLoss:3.9418, ValAcc:0.0200
epoch:015, loss:0.6652, TrainAcc:0.6833, ValLoss:4.0934, ValAcc:0.0300
epoch:016, loss:0.6515, TrainAcc:0.6883, ValLoss:4.2498, ValAcc:0.0300
epoch:017, loss:0.6238, TrainAcc:0.7050, ValLoss:4.4304, ValAcc:0.0300
epoch:018, loss:0.6082, TrainAcc:0.7317, ValLoss:4.5820, ValAcc:0.0333
epoch:019, loss:0.6030, TrainAcc:0.7517, ValLoss:4.7110, ValAcc:0.0367
epoch:020, loss:0.5933, TrainAcc:0.7850, ValLoss:4.8053, ValAcc:0.0400
epoch:021, loss:0.5824, TrainAcc:0.8267, ValLoss:4.8781, ValAcc:0.0333
epoch:022, loss:0.5655, TrainAcc:0.8017, ValLoss:4.9006, ValAcc:0.0267
epoch:023, loss:0.5333, TrainAcc:0.8083, ValLoss:4.9148, ValAcc:0.0167
epoch:024, loss:0.5175, TrainAcc:0.8050, ValLoss:4.8788, ValAcc:0.0100
epoch:025, loss:0.4994, TrainAcc:0.8117, ValLoss:4.7670, ValAcc:0.0033
epoch:026, loss:0.4888, TrainAcc:0.8333, ValLoss:4.5965, ValAcc:0.0033

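This is where the problem lies: the training loss falls steadily, while the validation loss climbs from 1.10 to about 4.9 and the validation accuracy never exceeds 0.06 after the first epoch.
If you know how to solve this problem, please don't hesitate to tell me.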
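One common way to limit the damage when the validation loss keeps rising is early stopping. The following is a minimal sketch in plain Python, not code from this repository: the class name `EarlyStopping` and the parameters `patience` and `min_delta` are hypothetical, and the loss values are the first six ValLoss entries from the log above. It does not explain why the validation accuracy is so low; it only shows where such a rule would have stopped the run.

```python
class EarlyStopping:
    """Stop once the validation loss has not improved for `patience` epochs.

    Hypothetical helper for illustration; not part of this repository.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = 0
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch and return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # New best validation loss: remember it and reset the counter.
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# ValLoss for epochs 1-6, copied from the training log above.
val_losses = [1.1022, 1.1980, 1.4572, 2.0040, 2.7091, 3.2186]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.step(epoch, val_loss):
        stopped_at = epoch
        break

print("best epoch:", stopper.best_epoch, "stopped at epoch:", stopped_at)
```

In a real training loop, the model weights would be saved whenever `best_epoch` is updated and restored after stopping.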

pytorch_han's People

Contributors

taishan1994

pytorch_han's Issues

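Providing the dataset

Hello, could you also provide the dataset used for testing, ACM3025.mat?

Some questions about the model implementation

Hello, I have recently been trying to build the HAN network into an idea of my own and came across your work, which I hope to use as a reference. A few places in the model code confuse me, and I would like to ask about them:

First, how are the meta-paths defined or computed in your HAN reimplementation? I could not find the relevant processing in the code.

Second, I would like to confirm: did you implement HAN's hierarchical attention mechanism by borrowing the multi-head attention mechanism?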
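For reference on the second question: in the HAN paper, multi-head attention is used at the node level within each meta-path, and a separate semantic-level attention then weights the meta-paths against each other. The NumPy sketch below follows the paper's semantic-level formulation (score each meta-path by the mean over nodes of q^T tanh(W z + b), then apply a softmax over meta-paths). It is an illustration under assumptions, not necessarily what this repository's code does: the function name `semantic_attention`, the toy shapes, and the random parameters are all hypothetical.

```python
import numpy as np


def semantic_attention(embeddings, W, b, q):
    """Fuse per-meta-path node embeddings with semantic-level attention.

    embeddings: (P, N, D) array, one (N, D) embedding matrix per meta-path.
    W: (D, H) projection, b: (H,) bias, q: (H,) semantic attention vector.
    Returns (beta, fused): meta-path weights (P,) and fused embeddings (N, D).
    """
    # Importance of each meta-path: mean over nodes of q^T tanh(W z + b).
    hidden = np.tanh(embeddings @ W + b)   # (P, N, H)
    scores = (hidden @ q).mean(axis=1)     # (P,)
    # Numerically stable softmax over the P meta-paths.
    exp_scores = np.exp(scores - scores.max())
    beta = exp_scores / exp_scores.sum()
    # Weighted sum of the per-meta-path embeddings.
    fused = np.tensordot(beta, embeddings, axes=1)  # (N, D)
    return beta, fused


# Toy sizes chosen only for illustration: 2 meta-paths, 5 nodes, 4 features.
rng = np.random.default_rng(0)
emb = rng.normal(size=(2, 5, 4))
W = rng.normal(size=(4, 3))
b = np.zeros(3)
q = rng.normal(size=3)

beta, fused = semantic_attention(emb, W, b, q)
print("meta-path weights:", beta)
```

In a trained model, `W`, `b`, and `q` would be learned parameters shared across all meta-paths.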

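Questions about the model input biases_list and the adj_to_bias method

I have two questions about the adj_to_bias method:

  1. mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1]))): what does this multiplication mean? Why is the diagonal of the adjacency matrix set to 1 and the result then multiplied with a matrix initialized by np.empty?
  2. Elements with mt[g][i][j] > 0.0 are set to 1, and the function finally returns -1e9 * (1.0 - mt). What is the purpose of this operation?

What is the purpose of obtaining biases_list?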
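A possible reading of these questions, assuming this repository's adj_to_bias matches the function of the same name in the original GAT TensorFlow utilities (an assumption; check utils.py): there, np.empty only allocates the output, and mt[g] is set to the identity matrix before the multiplication. Adding np.eye to the adjacency matrix gives every node a self-loop, and multiplying nhood times marks every node reachable within nhood hops. Returning -1e9 * (1.0 - mt) then yields 0 for allowed pairs and a large negative number for all others, so that adding this bias to the attention logits before the softmax drives the weights of non-neighbors to zero. The sketch below uses the hypothetical name `adj_to_bias_sketch` and a toy 3-node graph.

```python
import numpy as np


def adj_to_bias_sketch(adj, nhood=1):
    """Turn adjacency matrices into additive attention masks.

    adj: (G, N, N) array of adjacency matrices.
    Returns a (G, N, N) array: 0 where attention is allowed, -1e9 elsewhere.
    Illustrative rewrite; assumed to mirror the GAT utility of the same name.
    """
    n = adj.shape[1]
    eye = np.eye(n)
    bias = np.empty(adj.shape)  # allocation only; every entry is overwritten
    for g in range(adj.shape[0]):
        reach = eye.copy()
        for _ in range(nhood):
            # (A + I) keeps each node itself and adds its direct neighbors.
            reach = reach @ (adj[g] + eye)
        mask = (reach > 0.0).astype(float)  # 1 = reachable within nhood hops
        bias[g] = -1e9 * (1.0 - mask)
    return bias


# Toy graph: nodes 0 and 1 are connected, node 2 is isolated.
adj = np.array([[[0, 1, 0],
                 [1, 0, 0],
                 [0, 0, 0]]], dtype=float)
bias = adj_to_bias_sketch(adj)

# Attention of node 0 with uniform raw logits: the mask removes node 2.
logits = np.zeros(3) + bias[0, 0]
exp_logits = np.exp(logits - logits.max())
weights = exp_logits / exp_logits.sum()
print(weights)
```

Under this reading, biases_list would hold one such mask per meta-path adjacency matrix, which is consistent with the (1, 3025, 3025) shape printed in the training log.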
