
tpb's Introduction

Cross-city Few-shot Traffic Forecasting via Traffic Pattern Bank

[CIKM 2023] This repository contains the code of "Cross-city Few-shot Traffic Forecasting via Traffic Pattern Bank" (TPB).

TPB framework overview (figure)

Data

The data is available at https://drive.google.com/drive/folders/1UrKTgR27YmP9PjJ-FWv4SCDH3zUxtc5R?usp=share_link. Please download it and place the files in ./data.
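
A quick sanity check after downloading (a minimal sketch; the actual file names inside ./data depend on the contents of the Google Drive folder):

import os

# Verify that the downloaded data has been placed in ./data.
data_dir = "./data"
assert os.path.isdir(data_dir), "Download the data and place it in ./data first"
print("Files in ./data:", sorted(os.listdir(data_dir)))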

Environment

The code is implemented with PyTorch 1.10.0, CUDA 11.3, and Python 3.7.0.

pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
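
After installation, you can verify the environment with a short check (a minimal sketch that only confirms the PyTorch build and CUDA availability):

import torch

# Expected: 1.10.0+cu113 / CUDA 11.3 on a machine with a visible GPU.
print("PyTorch:", torch.__version__)
print("CUDA build:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())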

Reproducibility

The default configs of the four datasets are set in ./config. To reproduce the results, please run the following command:

bash train.sh

or run the experiment on a specific dataset (PEMS-BAY as an example):

CUDA_VISIBLE_DEVICES=0 nohup python -u train.py  --config_file ./configs/config_pems.yaml > train_pems.out 2>&1 &
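
Each run is driven by a YAML config file passed via --config_file. The sketch below only shows how such a file is loaded; the actual keys are defined by the files in ./configs:

import yaml

# Load the dataset-specific config (PEMS-BAY as an example).
with open("./configs/config_pems.yaml") as f:
    config = yaml.safe_load(f)
print(config)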

Pre-trained stuff

The pre-trained patch encoder and the traffic pattern bank are included in this repository. The pre-trained patch encoder is in ./save/pretrain_model and the traffic pattern bank is in ./pattern.
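
Loading them typically looks like the sketch below (the file names encoder.pt and pattern_bank.npy are placeholders, and the storage formats are assumptions; use the actual files under ./save/pretrain_model and ./pattern):

import numpy as np
import torch

# Placeholder file names and formats; substitute the files shipped in the repository.
encoder_state = torch.load("./save/pretrain_model/encoder.pt", map_location="cpu")
pattern_bank = np.load("./pattern/pattern_bank.npy")
print(type(encoder_state), pattern_bank.shape)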

You can also pre-train the encoder and generate the traffic pattern bank on your own by:

# Pre-train
python -u ./pretrain.py --test_dataset ${test_dataset} --data_list ${data_list}
wait
python -u ./patch_devide.py --test_dataset ${test_dataset} --data_list ${data_list}
wait
python -u ./pattern_clustering.py --data_list $data_list --test_dataset ${test_dataset} --sim ${sim} --K ${K}

${data_list} specifies the source datasets. For example, if you want to pre-train the encoder on Chengdu, METR-LA, and PEMS-BAY, then ${data_list} is chengdu_metr_pems.

${test_dataset} is the dataset you want to build the target data on. If you want to build the target data on Shenzhen, then ${test_dataset} is shenzhen.

${sim} and ${K} are the clustering hyper-parameters. You can set them on your own; a conceptual sketch of this clustering step is given below.
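
For intuition, the pattern bank is obtained by clustering the patch representations produced by the pre-trained encoder into K patterns. The snippet below is only a conceptual sketch of that step, not the repository's pattern_clustering.py; the input file name, the use of scikit-learn KMeans, and treating the centroids as the bank are assumptions:

import numpy as np
from sklearn.cluster import KMeans

# patch_emb: (num_patches, emb_dim) embeddings from the pre-trained patch encoder.
patch_emb = np.load("./pattern/patch_embeddings.npy")   # placeholder file name

K = 10                                                   # number of traffic patterns
kmeans = KMeans(n_clusters=K, random_state=0).fit(patch_emb)

# Use the cluster centroids as the traffic pattern bank; ${sim} would further
# filter patches by similarity before clustering (assumption about its role).
pattern_bank = kmeans.cluster_centers_                   # (K, emb_dim)
np.save("./pattern/pattern_bank.npy", pattern_bank)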


tpb's Issues

pattern folder

Hello! May I ask how the files in the pattern folder were obtained?
Wish you all the best!

MAML is used instead of Reptile?

The paper states that Reptile is used as the meta-learning method, but it seems that MAML is used in the code instead of Reptile.
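
For reference, the two meta-learning update rules differ as in the sketch below (generic first-order forms with illustrative names, not the repository's implementation):

import copy
import torch

def inner_adapt(model, x_spt, y_spt, loss_fn, lr=1e-3, steps=5):
    # Plain SGD adaptation on one task, starting from a copy of the shared init.
    adapted = copy.deepcopy(model)
    opt = torch.optim.SGD(adapted.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss_fn(adapted(x_spt), y_spt).backward()
        opt.step()
    return adapted

def reptile_update(model, adapted, outer_lr=0.1):
    # Reptile: move the initialization toward the task-adapted weights.
    with torch.no_grad():
        for p, q in zip(model.parameters(), adapted.parameters()):
            p += outer_lr * (q - p)

def fomaml_update(model, adapted, x_qry, y_qry, loss_fn, outer_opt):
    # First-order MAML: evaluate the query-set gradient at the adapted weights
    # and apply it to the shared initialization through an outer optimizer.
    for q in adapted.parameters():
        q.grad = None
    loss_fn(adapted(x_qry), y_qry).backward()
    outer_opt.zero_grad()
    for p, q in zip(model.parameters(), adapted.parameters()):
        p.grad = q.grad.clone()
    outer_opt.step()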

Code error to be corrected

There seems to be something wrong in the code.
When I try to run it, the following error cannot be eliminated:

Traceback (most recent call last):
  File "train.py", line 128, in <module>
    model_loss ,mse_loss, rmse_loss, mae_loss, mape_loss = rep_model.meta_train_revise(data_spt, matrix_spt, data_qry, matrix_qry)
  File "/root/autodl-nas/TPB/./model/Meta_Models/rep_model_final.py", line 251, in meta_train_revise
    out, y, meta_graph = self.model(data_spt[i], A_gnd)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/autodl-nas/TPB/./model/Meta_Models/rep_model_final.py", line 114, in forward
    raw_emb, Ax = STmodel(raw_x,A_list)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/autodl-nas/TPB/./model/Meta_Models/reconstruction.py", line 384, in forward
    gate = self.gate_convs[i](residual)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 302, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 298, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [16, 32, 207, 13]

Your work is extremely valuable to me. I would greatly appreciate it if you could make an effort to resolve the bug.
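
For context, the error is a shape mismatch: torch.nn.Conv1d only accepts 2D or 3D input, while the tensor reaching gate_convs has four dimensions [batch, channels, nodes, time]. The snippet below merely reproduces and illustrates that behaviour; it is not the fix applied in the repository:

import torch
import torch.nn as nn

x = torch.randn(16, 32, 207, 13)            # [batch, channels, nodes, time], as in the traceback

conv1d = nn.Conv1d(32, 32, kernel_size=2)
try:
    conv1d(x)                               # fails: Conv1d expects 3D (batched) input
except RuntimeError as err:
    print("Conv1d:", err)

conv2d = nn.Conv2d(32, 32, kernel_size=(1, 2))
print("Conv2d output:", conv2d(x).shape)    # a (1, k) kernel Conv2d accepts the 4D tensor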

How is dataset_expand obtained?

Hello! I would like to ask: if I want to reproduce the results on another dataset, e.g. PEMS04, how should I obtain dataset_expand and how should the data be processed?

Question about clustering parameter K

Thanks for your nice work and code.

I would like to know about the K value. Is your final choice K = 10 for all datasets?
Sorry, I didn't find it in the paper, but from Figure 5 I guess K = 10.

Looking forward to your reply.
