I guess you don't train this model for 100 epochs. MTGNN proposes a curriculum learning strategy for training the multi-step model: the task_level is increased by 1 every step_size iterations. Since we predict 12 steps, the model must iterate enough rounds for the task_level to reach 12, so that all steps are fully trained. The related code is here: MTGNN.
If you transpose the original shape (batch_size, input_window, num_nodes, feature_dim) into (batch_size, feature_dim, num_nodes, input_window), then in our code the maximum effective task_level equals feature_dim = 1. In that case it is not necessary to train many rounds for the model to be fully trained, so the results appear normal.
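The curriculum schedule described above can be sketched as follows. This is a simplified illustration, not LibCity's actual code; the function name `task_level` and the default `step_size=2500` (taken from MTGNN's training script) are assumptions for the example.

```python
def task_level(iteration, step_size=2500, max_level=12):
    """Number of output steps included in the loss at this training iteration.

    The level starts at 1 and grows by 1 every `step_size` iterations,
    capped at the full prediction horizon `max_level`.
    """
    return min(iteration // step_size + 1, max_level)

# With step_size=2500, the 12th step is only trained from iteration 27500 on,
# which is why too few epochs leave the later horizons under-trained.
```

With roughly a few hundred iterations per epoch, reaching level 12 takes dozens of epochs, which is why a short run only trains the first few steps.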
However, this is inconsistent with the paper authors' approach. In the original author's code, transpose(1,3) is executed because the same transpose is performed on the prediction target trainy; to make the predicted value match the shape of the true value, the model output's shape must be changed accordingly. This shape change does not affect the output of the scaler's inverse_transform and transform. The original code is here: Origin code.
In our code, no changes are made to the shape of trainy, and therefore no changes are needed to the shape of the model's output y_pred. The situation you posted is caused by task_level; just increase the number of training epochs or reduce the batch_size to reproduce the original performance.
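The claim that the transpose does not affect transform/inverse_transform can be checked directly: a z-score scaler is purely elementwise, so it commutes with any transpose. The sketch below uses a minimal stand-in scaler (the class and the METR-LA-like mean/std are illustrative assumptions, not LibCity's actual implementation):

```python
import numpy as np

class StandardScaler:
    """Minimal sketch of a z-score scaler (LibCity's scaler behaves similarly
    when mean and std are scalars): both directions are elementwise."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std
    def transform(self, x):
        return (x - self.mean) / self.std
    def inverse_transform(self, x):
        return x * self.std + self.mean

scaler = StandardScaler(mean=54.4, std=19.5)   # illustrative statistics
y = np.random.rand(8, 12, 207, 1)              # (batch, input_window, num_nodes, feature_dim)

# Transposing before or after inverse_transform yields identical values:
a = scaler.inverse_transform(y).transpose(0, 3, 2, 1)
b = scaler.inverse_transform(y.transpose(0, 3, 2, 1))
```

So the transpose only matters for operations that are sensitive to which axis is which, such as the task_level slicing discussed below, not for the scaling itself.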
from bigscity-libcity.
Bug found:
In MTGNN, the tensor should be transposed to (batch_size, feature_dim, num_nodes, input_window) before being fed into inverse_transform.
See the code provided by the original author here:
https://github.com/nnzhan/MTGNN/blob/f811746fa7022ebf336f9ecd2434af5f365ecbf6/trainer.py#L48
That is, the original code at
Bigscity-LibCity/libcity/model/traffic_speed_prediction/MTGNN.py,
lines 522 to 523 in c1c95f9,
should be changed to:

```python
y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim].transpose(1, 3))
y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim].transpose(1, 3))
```
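A quick shape sanity check of that change, using dummy NumPy arrays in place of torch tensors (the batch/node sizes are illustrative; `transpose(0, 3, 2, 1)` is the NumPy analogue of torch's `transpose(1, 3)`):

```python
import numpy as np

output_dim = 1
y = np.random.rand(8, 12, 207, 2)                 # (batch, output_window, num_nodes, features)

# Slice the predicted features, then swap the time and feature axes:
y_t = y[..., :output_dim].transpose(0, 3, 2, 1)   # -> (batch, features, num_nodes, output_window)
```

After the change the tensor matches the (batch_size, feature_dim, num_nodes, input_window) layout used in the original author's trainer.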
The performance now looks fine after applying those changes:
MAE | MAPE | MSE | RMSE | masked_MAE | masked_MAPE | masked_MSE | masked_RMSE | R2 | EVAR |
---|---|---|---|---|---|---|---|---|---|
8.940899849 | inf | 425.4656067 | 20.6268177 | 2.323023558 | 0.057006676 | 17.02191544 | 4.125762463 | 0.184079458 | 0.288268507 |
9.279856682 | inf | 440.0405579 | 20.97714424 | 2.607111931 | 0.066930175 | 24.59722519 | 4.959558964 | 0.156134462 | 0.263421774 |
9.541911125 | inf | 453.8035278 | 21.3026638 | 2.801694632 | 0.074403904 | 30.08820343 | 5.485271454 | 0.12967786 | 0.241285801 |
9.75474453 | inf | 465.2220459 | 21.56900597 | 2.958072186 | 0.080727793 | 34.65478516 | 5.88683176 | 0.10787711 | 0.223937809 |
9.918874741 | inf | 473.8552551 | 21.76821709 | 3.082611322 | 0.08576522 | 38.60670471 | 6.213429451 | 0.09118306 | 0.209788263 |
10.09046078 | inf | 484.2091675 | 22.00475311 | 3.204313755 | 0.090506367 | 42.44840622 | 6.515244007 | 0.071050926 | 0.193283856 |
10.19396973 | inf | 489.6725464 | 22.12854576 | 3.287316561 | 0.094344415 | 45.54912567 | 6.749009132 | 0.060823213 | 0.184867501 |
10.30573463 | inf | 496.2115479 | 22.27580643 | 3.372331858 | 0.097648136 | 48.44382095 | 6.960159302 | 0.048238196 | 0.175677776 |
10.43615055 | inf | 506.2624817 | 22.50027657 | 3.444404364 | 0.100696735 | 51.01102066 | 7.142199993 | 0.029032076 | 0.160097897 |
10.46782017 | inf | 505.6869507 | 22.48748398 | 3.502105713 | 0.102586389 | 53.26374817 | 7.298201561 | 0.030235945 | 0.160259724 |
10.54193497 | inf | 509.4328003 | 22.57061768 | 3.561977863 | 0.105110899 | 55.16071701 | 7.427026272 | 0.023078795 | 0.154398263 |
10.56889248 | inf | 508.2257385 | 22.5438633 | 3.615408182 | 0.107188798 | 57.01773834 | 7.551009178 | 0.025389521 | 0.155566633 |
Also, to add: the author's code supports task_level as well; the code is here. But because of the transpose(1,3), his task_level slices the fourth dimension, predict[:, :, :, :self.task_level], while ours slices the second dimension, y_predicted[:, :self.task_level, :, :]. This means that even with the original author's code, the later time steps cannot be trained without running enough rounds.
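The two slicing conventions select the same time steps, just along different axes. A sketch with dummy tensors (shapes are illustrative; `task_level = 3` is an arbitrary example):

```python
import numpy as np

task_level = 3
y_author = np.random.rand(8, 1, 207, 12)     # author's layout after transpose(1, 3): (B, F, N, T)
y_libcity = y_author.transpose(0, 3, 2, 1)   # LibCity's layout: (B, T, N, F)

a = y_author[:, :, :, :task_level]           # author: slice the 4th (time) axis
b = y_libcity[:, :task_level, :, :]          # LibCity: slice the 2nd (time) axis

# Both keep only the first `task_level` prediction steps:
same = np.allclose(a.transpose(0, 3, 2, 1), b)
```

Either way, steps beyond the current task_level are excluded from the loss, so they stay untrained until the curriculum reaches them.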
Later I will add a log of my training process. If you still have questions, please continue to leave comments. Feel free to give us a STAR if our code is helpful, thanks!
MAE | MAPE | MSE | RMSE | masked_MAE | masked_MAPE | masked_MSE | masked_RMSE | R2 | EVAR |
---|---|---|---|---|---|---|---|---|---|
8.909078 | inf | 427.3758 | 20.67307 | 2.253886 | 0.054309 | 15.49193 | 3.935979 | 0.172218 | 0.268965 |
9.164458 | inf | 437.0187 | 20.90499 | 2.50127 | 0.062337 | 21.68376 | 4.656582 | 0.153564 | 0.252859 |
9.338023 | inf | 443.2457 | 21.0534 | 2.675945 | 0.068277 | 26.53755 | 5.151461 | 0.141458 | 0.241343 |
9.482356 | inf | 448.7668 | 21.18412 | 2.815347 | 0.073929 | 30.62461 | 5.53395 | 0.130882 | 0.23367 |
9.605302 | inf | 454.0782 | 21.30911 | 2.930262 | 0.078487 | 34.21933 | 5.849729 | 0.120484 | 0.224556 |
9.713979 | inf | 457.7052 | 21.39404 | 3.044403 | 0.082614 | 37.49809 | 6.123569 | 0.113217 | 0.218411 |
9.792989 | inf | 460.9681 | 21.47017 | 3.121814 | 0.086119 | 40.06687 | 6.32984 | 0.107152 | 0.212963 |
9.876019 | inf | 464.0353 | 21.54148 | 3.205873 | 0.089727 | 42.66326 | 6.531712 | 0.101188 | 0.208306 |
9.954038 | inf | 468.3024 | 21.6403 | 3.273369 | 0.092512 | 45.13892 | 6.71855 | 0.093004 | 0.201936 |
10.02133 | inf | 471.5396 | 21.71496 | 3.33448 | 0.095099 | 47.27936 | 6.875999 | 0.086842 | 0.196892 |
10.09106 | inf | 475.0183 | 21.79491 | 3.394903 | 0.097416 | 49.14698 | 7.010491 | 0.080146 | 0.191286 |
10.16447 | inf | 478.4005 | 21.87237 | 3.459933 | 0.09999 | 51.05228 | 7.145088 | 0.073609 | 0.186053 |
86194-MTGNN-METR_LA-May-22-2022_11-58-10.log
You can mainly focus on the increasing task_level records in the log:
2022-05-22 11:58:47,795 - INFO - Training: task_level increase from 0 to 1
...
2022-05-22 12:44:58,146 - INFO - Training: task_level increase from 11 to 12
Thanks for your patient response!
Yes, you are right. I only trained for 30 epochs, and that is why only the first few steps are captured well.
I totally missed the curriculum learning part in MTGNN; thanks for reminding me! You can close this issue now. Apologies for my misleading message.
Again, a very nice library, and thank you for your time in building it! I have starred it, and I hope it becomes more popular!
Thank you!