
Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation

[Figure: model framework]

The official implementation of the ICLR 2024 paper entitled "Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation".

In this project, we propose a novel framework, GPD, which performs generative pre-training on a collection of model parameters optimized with data from source cities. Our proposed approach recasts spatio-temporal graph transfer learning as pre-training a generative hypernetwork, which generates tailored model parameters guided by prompts. Our framework has the potential to revolutionize smart city applications in data-scarce environments and contribute to more sustainable and efficient urban development.

Installation

Environment

  • Tested OS: Linux
  • Python >= 3.8
  • torch == 1.12.0
  • torch_geometric == 2.2.0
  • Tensorboard

Dependencies:

  1. Install PyTorch with the correct CUDA version.
  2. Run `pip install -r requirements.txt` to install the remaining Python modules and packages used in this project.

Data

The data used for training and evaluation can be found in Time-Series data. After downloading, move the files to ./Data.

For each city, we provide the following data:

  • Graph data: the adjacency matrix of the spatio-temporal graph.
  • Time series data: the temporal sequence recorded at each node.

We provide two time-series datasets: crowd flow (including DC, BM, man) and traffic speed (including metr-la, pems-bay, shenzhen, chengdu_m).
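As a rough illustration of how the two parts fit together (shapes and variable names here are hypothetical; check the actual files under ./Data for the real layout), each city pairs an adjacency matrix with a node-wise time-series array:

```python
import numpy as np

# Hypothetical shapes: N nodes, T time steps, C channels.
N, T, C = 207, 288, 2               # e.g. METR-LA has 207 sensors
rng = np.random.default_rng(0)

adj = rng.random((N, N))            # adjacency matrix of the spatio-temporal graph
adj = (adj + adj.T) / 2             # symmetrize for an undirected graph
series = rng.random((T, N, C))      # temporal sequence for each node

assert adj.shape == (N, N)
assert series.shape[1] == adj.shape[0]  # one time series per graph node
```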

The details of these two datasets are as follows:

[Figure: datasets information]

Model Training

To train node-level models with the traffic dataset, run:

cd Pretrain

CUDA_VISIBLE_DEVICES=0 python main.py --taskmode task4 --model v_GWN --test_data metr-la --ifnewname 1 --aftername TrafficData

After training completes, run Pretrain/PrepareParams/model2tensor.py to extract the parameters from the trained models, and put the resulting parameter dataset in ./Data.
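Conceptually, this extraction step flattens each trained model's weights into a single vector so the diffusion model can treat network parameters as data. A minimal NumPy sketch of the idea (the real model2tensor.py may differ in details):

```python
import numpy as np

def flatten_params(state_dict):
    """Concatenate all weight arrays into one 1-D parameter vector."""
    return np.concatenate([w.ravel() for w in state_dict.values()])

# Toy 'state dict' standing in for one trained spatio-temporal model.
params = {"layer1.weight": np.zeros((4, 3)), "layer1.bias": np.zeros(4)}
vec = flatten_params(params)
print(vec.shape)  # (16,)
```

Repeating this over all source-city models yields the parameter dataset that the diffusion model is trained on.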

To train the diffusion model and generate parameters for the target city, run:

cd GPD

CUDA_VISIBLE_DEVICES=0 python 1Dmain.py --expIndex 140 --targetDataset metr-la --modeldim 512 --epochs 80000 --diffusionstep 500 --basemodel v_GWN --denoise Trans1

  • expIndex assigns a special number to the experiment.
  • targetDataset specifies the target dataset, which can be selected from ['DC', 'BM', 'man', 'metr-la', 'pems-bay', 'shenzhen', 'chengdu_m'].
  • modeldim specifies the hidden dim of the Transformer.
  • epochs specifies the number of iterations.
  • diffusionstep specifies the total steps of the diffusion process.
  • basemodel specifies the spatio-temporal graph model, which can be selected from ['v_STGCN5', 'v_GWN'].
  • denoise specifies the conditioning strategy, which can be selected from ['Trans1', 'Trans2', 'Trans3', 'Trans4', 'Trans5'].
    • Trans1: Pre-conditioning with inductive bias.
    • Trans2: Pre-conditioning.
    • Trans3: Pre-adaptive conditioning.
    • Trans4: Post-adaptive Conditioning.
    • Trans5: Adaptive norm conditioning.
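For intuition about the --diffusionstep flag: it sets how many noising steps the forward diffusion process uses. A minimal DDPM-style forward process on a flattened parameter vector looks like this (a NumPy sketch of the standard technique, not this project's actual implementation):

```python
import numpy as np

steps = 500                                  # --diffusionstep
betas = np.linspace(1e-4, 0.02, steps)       # linear noise schedule
alpha_bar = np.cumprod(1.0 - betas)          # cumulative signal retention

rng = np.random.default_rng(0)
x0 = rng.standard_normal(512)                # a flattened parameter vector

def q_sample(x0, t):
    """Sample x_t ~ q(x_t | x_0) in closed form."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

x_noisy = q_sample(x0, steps - 1)            # near pure noise at the final step
```

The denoiser (one of the Trans1-Trans5 variants above) is trained to invert this process, guided by the prompt condition.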

[Figure: conditioning strategies]

The sampled results are saved in GPD/Output/expXX/.

Finetune and Evaluate

To finetune the generated parameters of the target city and evaluate, run:

cd Pretrain

CUDA_VISIBLE_DEVICES=0 python main.py --taskmode task7 --model v_GWN --test_data metr-la --ifnewname 1 --aftername finetune_7days --epochs 600 --target_days 7

  • taskmode 'task7' means finetuning after diffusion sampling.
  • model specifies the spatio-temporal graph model, which can be selected from ['v_STGCN5', 'v_GWN'].
  • test_data specifies the dataset, which can be selected from ['DC', 'BM', 'man', 'metr-la', 'pems-bay', 'shenzhen', 'chengdu_m'].
  • ifnewname set to 1 to better distinguish the results of the current experiment.
  • aftername used together with --ifnewname 1 to attach an identifying name to the log file and results folder of the current experiment.
  • epochs specifies the number of iterations.
  • target_days specifies the amount of target-city data used in the finetuning stage.
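The few-shot finetuning stage starts from the generated parameters and takes a small number of gradient steps on the scarce target-city data. A toy NumPy sketch of that idea, using a linear model in place of the real spatio-temporal network (all names and sizes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny stand-in for the few-shot setting: scarce target-city data.
X = rng.standard_normal((50, 8))             # ~a few days of samples
w_true = rng.standard_normal(8)
y = X @ w_true
w = w_true + 0.5 * rng.standard_normal(8)    # 'generated' (imperfect) parameters

def mse(w):
    return float(np.mean((X @ w - y) ** 2))

loss_before = mse(w)
for _ in range(600):                          # --epochs 600
    grad = 2 * X.T @ (X @ w - y) / len(X)     # gradient of the MSE loss
    w -= 0.01 * grad                          # small finetuning steps
loss_after = mse(w)
assert loss_after < loss_before               # finetuning refines the start point
```

Because the generated parameters already sit close to a good solution, a short finetuning run on little data suffices.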

Overall Instructions

Here is an end-to-end example. To set 'metr-la' as the target city:

  • In pretrain: set test_data to 'pems-bay', 'chengdu_m', and 'shenzhen' respectively to pretrain the models of the three source cities.
  • In diffusion: set targetDataset to 'metr-la'.
  • In finetune: set test_data to 'metr-la'.
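The three stages above can be strung together as a script. The sketch below only assembles the command lines shown earlier in this README (run them from the Pretrain and GPD directories respectively; adjust flags to your setup):

```python
# Three-stage pipeline for target city 'metr-la'.
source_cities = ["pems-bay", "chengdu_m", "shenzhen"]

pretrain = [
    f"python main.py --taskmode task4 --model v_GWN --test_data {c} "
    "--ifnewname 1 --aftername TrafficData"
    for c in source_cities
]
diffuse = ("python 1Dmain.py --expIndex 140 --targetDataset metr-la "
           "--modeldim 512 --epochs 80000 --diffusionstep 500 "
           "--basemodel v_GWN --denoise Trans1")
finetune = ("python main.py --taskmode task7 --model v_GWN --test_data metr-la "
            "--ifnewname 1 --aftername finetune_7days --epochs 600 --target_days 7")

for cmd in pretrain + [diffuse, finetune]:
    print(cmd)
```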

Since finetuning and pretraining share the same code framework and the same set of parameter names, this can be a little confusing; we will try to make the distinction clearer in later versions of the code.
