Comments (5)
真实值当中也有很多 0。Loss 确实很小,我训练时的结果也只是从 50 多下降到 10多。
这篇论文没有给出开源代码。也没能找到数据集。在应用到新数据集上时,应该要进行调参
模型的实现就在 /libcity/model/traffic_od_prediction/GEML.py
当中,使用喜闻乐见的 pytorch 框架。
(注意模型当中有一个 pSkip-LSTM 层,这个层据作者所说,是跳过 p 步的 LSTM,这个实现论文里没有详细阐述,是我自己理解的。)
我过后会尝试调参,或者检查模型实现。
from bigscity-libcity.
学长您好,我想问下您有单独的GEML模型代码吗?这个lib东西太多我理不清楚,能单独分享一下吗?谢谢
from bigscity-libcity.
@l782993610 感谢您的回复,确实数据集中有很多0值,但也有稀疏的非零值。根据您的回复,我重新clone了代码,并且重新按照给的配置参数训练了模型,同样模型loss从一开始就稳定在7附近,然后我将预测结果npz进行了可视化
,绝大部分预测结果是图中的右侧即全是0的,而不是如真值(左图)一样有零星的非零值。我模型训练的log如下:
GEML-NYCTAXI202004-202006_OD-Dec-02-2021_10-02-30.log
from bigscity-libcity.
GEML 实现时参考代码,这里是另一个人做的实现,他不是作者,我也不是,只是为了作为其论文的 BaseLine 使用。注意他这个实现里面是没有多任务预测部分的,然后两个语义嵌入的融合也是直接相加而不是论文提到的拼接。
只看我那个代码实现文件应该就可以,不需要了解框架的细节。从 GEML 类开始看。
以下是我当时实现的论文笔记。
GEML
输入:T 个历史时间段中 N 个区域间 OD 矩阵构成的张量
模型结构:
- 图卷积层(GCN)
-
构建邻接图:定义了两种邻居,并在这两种邻居定义下构成的两个邻接图。
-
地理邻居
:$\Phi_i$ 网格中相邻节点是邻居。 -
语义邻居
:$\Omega_i^{t'}$ 时间片内存在交通流连接的节点是邻居。注解
:每一个时间步,都会构造一个语义邻接图。
-
-
加权聚合:在两个邻接图上分别对节点特征
$\mathrm{f}^i_{t'}$ 进行聚合操作,作者还引入了一组预加权权重,加速训练。-
地理邻居:以距离作为预加权权重,得到节点地理邻居聚合后的特征 $\mathrm{r}^i_{t'} = \sigma(\mathrm{W}g\cdot(\mathrm{f}{t'}^i+\sum\limits_{g_j\in\Phi_i}^{} \cfrac{\mathrm{dis}(g_i,g_j)}{\sum\limits_{g_k\in \Phi_i} \mathrm{dis}(g_i,g_k)}\mathrm{f}_{t'}^j))$。
-
语义邻居:以度数作为预加权权重,得到节点语义邻居聚合后特征 。
$\mathrm{s}^i_{t'} = \sigma(\mathrm{W}s\cdot(\mathrm{f}{t'}^i+\sum\limits_{g_j\in\Omega^i_{t'}}^{} \cfrac{\mathrm{degree}(g_j)}{\sum\limits_{g_k\in \Phi_i} \mathrm{degree}(g_k)+\epsilon}\mathrm{f}_{t'}^j))$,其中
$\epsilon$ 用于避免除零。 -
通过两层图卷积,得到节点嵌入表示,进行拼接融合
$v^i_t = [r^i_{t'},s^i_{t'}]$ 。 -
输出形状:$\rm \mathbb{R}^{T\times N\times 2E}$。
-
-
- 时序预测层 (P-Skip-LSTM)
-
将各个时间步嵌入表示构成的序列输入 LSTM 网络进行预测。
-
LSTM 网络是经过修改的,跳过 p 个时间步以有效捕获长程模式。
若原始的 LSTM 表示为 $\mathrm{h}t^i = \mathrm{LSTM}(v_t^i,h{t-1})$,
则新 P-Skip-LSTM 表示为 $\mathrm{h}t^i = \mathrm{LSTM}^p(v_t^i,h{t-p})$。
-
输出形状:$\rm \mathbb{R}^{N\times 2E}$。
-
- 多任务预测层
- 多任务学习:同时对始发地需求,目的地需求,OD 流进行预测。
- 始发地需求预测值:对嵌入特征加权得 $\hat{p}i = \mathrm{w}{in}^\mathrm{T}\mathrm{h}_t^i$。
- 目的地需求预测值:对嵌入特征加权得 $\hat{p}i = \mathrm{w}{in}^\mathrm{T}\mathrm{h}_t^i$。
- OD 流预测值:$\hat{m}_{ij}= (\mathrm{W}_m \mathrm{h}_t^j)^\mathrm{T}\mathrm{h}_t^i,\mathrm{W}_m\in \mathbb{R}^{\rm 2E\times 2E}$。
- 目标函数:对三个任务的预测分别使用 MSE 后进行加权,权重为 0.25, 0.25, 0.5
输出:OD 矩阵预测值
对我的 shit mount 代码不理解可联系 vx 17786408192。
from bigscity-libcity.
俺们自己训练的结果
GEML-NYCTAXI202004-202006_OD-Nov-23-2021_10-12-07.log
@YichenWHU
这个模型模型的想法非常新颖,但是最大的问题就是没找着原论文的数据集。这个 NYC 数据集是我们自己处理的,没调过参。
from bigscity-libcity.
Related Issues (20)
- CSTN缺少weather embedding HOT 1
- 刚刚问题的补充:内存问题 HOT 7
- 关于STAN模型运行一个epoch的时间问题 HOT 21
- 你好 运行DCRNN时为什么填GRU 是文档错误吗 HOT 23
- 作者您好,MTGNN运行报错 HOT 7
- windows系统是否可以安装webtool呢?必须linux系统吗 HOT 1
- Question about the evaluation metrics in traj_loc_pred task HOT 2
- 关于传统模型 HOT 5
- How to select traffic flow / speed / occupancy in PEMSD4 & 8 HOT 1
- How to set the evaluation mode for traffic prediction? single or average HOT 2
- 请问test文件夹下提供的四个非深度学习模型ARIMA和SVR等能否修改为只预测PEMS数据集中的流量特征 HOT 3
- How to train DCRNN, STGCN, MTGNN model use different datasets HOT 7
- 请问如何把评价指标MAPE调出来,现在显示的是inf HOT 1
- 针对PDFormer模型的相关问题 HOT 2
- 关于ASTGCN模型运行PEMSD模型计算拉普拉斯矩阵时出现错误; HOT 2
- 关于如何用其它模型跑PDFormer中的grid-based dataset NYCTaxi的问题 HOT 2
- PDFormer如何同时得到inflow和outflow的evaluation metrics HOT 1
- 关于libcity news中提到的几篇论文的实验结果差距 HOT 2
- How to adapt the models to grid data and the correspondence between data and models? HOT 3
- 请问如何画出像Pdformer一样或类似的热力图 HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from bigscity-libcity.