juyongjiang / trendgcn

[CIKM 2023] This is the official source code of "TrendGCN: Enhancing the Robustness via Adversarial Learning and Joint Spatial-Temporal Embeddings in Traffic Forecasting", implemented in PyTorch.

Home Page: https://arxiv.org/abs/2208.03063

License: MIT License

Enhancing the Robustness via Adversarial Learning and Joint Spatial-Temporal Embeddings in Traffic Forecasting

This is the official PyTorch implementation of our CIKM 2023 paper: "TrendGCN: Enhancing the Robustness via Adversarial Learning and Joint Spatial-temporal Embeddings in Traffic Forecasting".

Figure 1. TrendGCN Model Architecture.

Overview

TrendGCN
├── config                   # configuration files for the six datasets
│   ├── METR-LA.conf
│   ├── PEMS-Bay.conf
│   ├── PEMS03.conf
│   ├── PEMS04.conf
│   ├── PEMS07.conf
│   └── PEMS08.conf
├── dataset                  # place the six dataset folders here
│   ├── METR-LA
│   ├── PEMS-Bay
│   ├── PEMS03
│   ├── PEMS04
│   ├── PEMS07
│   └── PEMS08
├── model
│   ├── discriminator.py
│   └── generator.py
├── utils
│   ├── adj_dis_matrix.py    # construct the adjacency matrix
│   ├── metrics.py           # evaluation metrics
│   ├── norm.py              # data normalization
│   └── util.py              # utility functions
├── dataloader.py            # load dataset
├── LICENSE
├── main.py                  # run
├── README.md                # detailed illustration of model training and testing
├── requirements.yml         # environment dependencies
└── trainer.py               # training and testing procedure

Environment

Make sure you have Python>=3.8 and PyTorch>=1.8 installed on your machine. The code was developed with:

  • PyTorch 1.8.1
  • Python 3.8.*

Install the Python dependencies by running:

conda env create -f requirements.yml
# After creating environment, activate it
conda activate trendgcn

Datasets Preparation

In our work, we evaluate the proposed model on six real-world traffic benchmark datasets: PEMS03, PEMS04, PEMS07, PEMS08, PEMS-Bay, and METR-LA. Obtain the datasets and place each one in its own folder under the dataset directory.
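Before training, it can help to confirm that every dataset folder is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_datasets` is hypothetical, and the folder names are simply those listed in the overview above.

```python
from pathlib import Path

# Folder names copied from the repository overview; adjust if your layout differs.
EXPECTED_DATASETS = ["METR-LA", "PEMS-Bay", "PEMS03", "PEMS04", "PEMS07", "PEMS08"]


def missing_datasets(root="dataset"):
    """Return the expected dataset folders that do not exist under `root`."""
    root_path = Path(root)
    return [name for name in EXPECTED_DATASETS if not (root_path / name).is_dir()]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All six dataset folders found.")
```

The check only looks for directories; it does not validate the files inside them.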

Train and Test

Step 1:

Modify the following variables in the main.py script.

#********************************************************#
Mode = 'Train'     # or Test (loading best_model.pth to evaluate on test dataset)
DATASET = 'PEMS04' # PEMS03 or PEMS04 or PEMS07 or PEMS08
#********************************************************#

Step 2:

Modify the configuration for the chosen dataset in config/dataset_name.conf, e.g., config/PEMS04.conf.

[data]
num_nodes = 307
lag = 12
horizon = 12
val_ratio = 0.2
test_ratio = 0.2
tod = False
normalizer = std
column_wise = False
default_graph = True
...
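The configuration file uses INI syntax, so it can be inspected with Python's standard `configparser`. The sketch below is an illustration under assumptions rather than the repository's own loading code: the helper names `load_data_section` and `sliding_windows` are hypothetical, and the reading of `lag` as the number of input steps and `horizon` as the number of predicted steps is the usual convention, assumed here.

```python
import configparser

# The [data] section shown above, inlined so the example is self-contained.
SAMPLE_CONF = """
[data]
num_nodes = 307
lag = 12
horizon = 12
val_ratio = 0.2
test_ratio = 0.2
tod = False
normalizer = std
column_wise = False
default_graph = True
"""


def load_data_section(text):
    """Parse the [data] section into typed Python values."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    data = parser["data"]
    return {
        "num_nodes": data.getint("num_nodes"),
        "lag": data.getint("lag"),
        "horizon": data.getint("horizon"),
        "val_ratio": data.getfloat("val_ratio"),
        "test_ratio": data.getfloat("test_ratio"),
        "tod": data.getboolean("tod"),
        "normalizer": data.get("normalizer"),
        "column_wise": data.getboolean("column_wise"),
        "default_graph": data.getboolean("default_graph"),
    }


def sliding_windows(series, lag, horizon):
    """Assumed convention: `lag` past steps as input, next `horizon` steps as target."""
    pairs = []
    for start in range(len(series) - lag - horizon + 1):
        inputs = series[start:start + lag]
        targets = series[start + lag:start + lag + horizon]
        pairs.append((inputs, targets))
    return pairs


cfg = load_data_section(SAMPLE_CONF)
toy_series = list(range(30))  # stand-in for one sensor's readings
windows = sliding_windows(toy_series, cfg["lag"], cfg["horizon"])
print(cfg["num_nodes"], len(windows))
```

With `lag = 12` and `horizon = 12`, a series of length T yields T - 23 input/target pairs under this convention.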

Step 3:

python -u main.py --gpu_id=1 2>&1 | tee exps/PEMS04.log

For descriptions of further arguments, run python main.py -h. After training, the model is evaluated on the test dataset automatically. The results for prediction horizons 1 to 12 are shown in the terminal and can also be found at the end of exps/PEMS04.log.

Horizon 01, MAE: 17.16, RMSE: 27.69, MAPE: 11.2595%
Horizon 02, MAE: 17.57, RMSE: 28.50, MAPE: 11.4979%
Horizon 03, MAE: 17.98, RMSE: 29.21, MAPE: 11.7343%
Horizon 04, MAE: 18.29, RMSE: 29.76, MAPE: 11.9162%
Horizon 05, MAE: 18.54, RMSE: 30.23, MAPE: 12.0755%
Horizon 06, MAE: 18.80, RMSE: 30.68, MAPE: 12.2436%
Horizon 07, MAE: 19.04, RMSE: 31.09, MAPE: 12.4009%
Horizon 08, MAE: 19.24, RMSE: 31.43, MAPE: 12.5158%
Horizon 09, MAE: 19.43, RMSE: 31.76, MAPE: 12.6333%
Horizon 10, MAE: 19.62, RMSE: 32.05, MAPE: 12.7421%
Horizon 11, MAE: 19.82, RMSE: 32.37, MAPE: 12.8842%
Horizon 12, MAE: 20.20, RMSE: 32.88, MAPE: 13.1226%
Average Horizon, MAE: 18.81, RMSE: 30.68, MAPE: 12.2522%
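The relationship between the per-horizon rows and the "Average Horizon" row can be checked from the log itself. The sketch below parses the twelve rows above and compares a few aggregates. It is an inference from the rounded numbers in the log, not a description of the repository's metric code: the averaged MAE and MAPE match the plain mean of the rows, while the averaged RMSE matches the root of the mean squared per-horizon RMSE (which equals the RMSE over all horizons together when each horizon has the same number of samples) rather than the plain mean.

```python
import math
import re

LOG = """\
Horizon 01, MAE: 17.16, RMSE: 27.69, MAPE: 11.2595%
Horizon 02, MAE: 17.57, RMSE: 28.50, MAPE: 11.4979%
Horizon 03, MAE: 17.98, RMSE: 29.21, MAPE: 11.7343%
Horizon 04, MAE: 18.29, RMSE: 29.76, MAPE: 11.9162%
Horizon 05, MAE: 18.54, RMSE: 30.23, MAPE: 12.0755%
Horizon 06, MAE: 18.80, RMSE: 30.68, MAPE: 12.2436%
Horizon 07, MAE: 19.04, RMSE: 31.09, MAPE: 12.4009%
Horizon 08, MAE: 19.24, RMSE: 31.43, MAPE: 12.5158%
Horizon 09, MAE: 19.43, RMSE: 31.76, MAPE: 12.6333%
Horizon 10, MAE: 19.62, RMSE: 32.05, MAPE: 12.7421%
Horizon 11, MAE: 19.82, RMSE: 32.37, MAPE: 12.8842%
Horizon 12, MAE: 20.20, RMSE: 32.88, MAPE: 13.1226%
Average Horizon, MAE: 18.81, RMSE: 30.68, MAPE: 12.2522%
"""

# Matches only the numbered rows; the "Average Horizon" row has no step number.
ROW = re.compile(r"Horizon (\d+), MAE: ([\d.]+), RMSE: ([\d.]+), MAPE: ([\d.]+)%")


def parse_horizons(log_text):
    """Return a list of (step, mae, rmse, mape) tuples, one per numbered row."""
    rows = []
    for match in ROW.finditer(log_text):
        step, mae, rmse, mape = match.groups()
        rows.append((int(step), float(mae), float(rmse), float(mape)))
    return rows


def summarize(rows):
    """Aggregate the per-horizon rows in a few candidate ways."""
    n = len(rows)
    return {
        "mean_mae": sum(r[1] for r in rows) / n,
        "mean_rmse": sum(r[2] for r in rows) / n,
        # Root of the mean squared RMSE: the pooled RMSE for equal-sized horizons.
        "pooled_rmse": math.sqrt(sum(r[2] ** 2 for r in rows) / n),
        "mean_mape": sum(r[3] for r in rows) / n,
    }


summary = summarize(parse_horizons(LOG))
for name, value in summary.items():
    print(f"{name}: {value:.4f}")
```

Because the inputs are rounded to two decimals, agreement with the logged averages is approximate.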

More prediction results are stored in exps/META-LA.log, exps/PeMS-BAY.log, exps/PEMS03.log, exps/PEMS07.log, and exps/PEMS08.log.

Experimental Results

The average-horizon prediction results of TrendGCN on the six datasets are as follows:

Visualization



Figure 2. Comparison of short-term (12 steps; panels (a), (c), (e), (g)) and long-term (288 steps; panels (b), (d), (f), (h)) prediction curves for STSGCN, AGCRN, and our TrendGCN on a snapshot of the test data from four datasets. Note that the predicted time series for the whole day (288 steps) is obtained simply by concatenating the short-term predictions (12 steps) along the time axis and removing overlaps. This is common practice in the existing literature and gives a clearer view of prediction quality at different times of the day.
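The concatenation described in the caption can be sketched as follows. This is one possible way to remove overlaps, assuming the forecasts are produced with a stride of one step so that the forecast at index i starts at time step i; the function name `stitch_predictions` is hypothetical and the authors' exact procedure may differ.

```python
def stitch_predictions(windows, horizon=12, total_steps=288):
    """Build one long curve by keeping every `horizon`-th forecast window.

    `windows[i]` is assumed to hold the `horizon`-step forecast whose first
    step corresponds to time step i, so windows 0, 12, 24, ... do not overlap.
    """
    stitched = []
    for start in range(0, total_steps, horizon):
        stitched.extend(windows[start][:horizon])
    return stitched[:total_steps]


# Toy forecasts: window i "predicts" the values i, i+1, ..., i+11.
toy_windows = [list(range(i, i + 12)) for i in range(288)]
day_curve = stitch_predictions(toy_windows)
print(len(day_curve))
```

With 12-step windows, 24 non-overlapping windows cover the 288 steps of one day.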



Figure 3. 2D UMAP projections of the spatial embeddings (upper) and heatmaps of the learned graphs (lower) at time steps t = {2, 4, 6, 8, 10, 12}.

Citation

If you use the data or code in this repo, please cite the repo.

@article{jiang2022dynamic,
  title={Enhancing the Robustness via Adversarial Learning and Joint Spatial-Temporal Embeddings in Traffic Forecasting},
  author={Jiang, Juyong and Wu, Binqing and Chen, Ling and Zhang, Kai and Kim, Sunghun},
  journal={arXiv preprint arXiv:2208.03063},
  year={2022}
}

trendgcn's Issues

Experiment results

Hi, thanks for releasing the code. The paper says that the final results are the average of 5 experiments, but the run log in the README of the released code appears to be the best single run, and it matches the PEMS04 results in the paper exactly. Could you clarify which one is reported?
