This is a TensorFlow implementation of the Diffusion Convolutional Recurrent Neural Network (DCRNN) described in the following paper:
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting, ICLR 2018.
- scipy>=0.19.0
- numpy>=1.12.1
- pandas>=0.19.2
- tensorflow>=1.3.0
- pyaml
Dependencies can be installed using the following command:
pip install -r requirements.txt
The traffic data file for Los Angeles, df_highway_2012_4mon_sample.h5, is available at Google Drive or Baidu Yun and should be placed in the data/METR-LA folder.
The sensor locations are available at data/sensor_graph/graph_sensor_locations.csv.
python -m scripts.generate_training_data --output_dir=data/METR-LA
The generated train/val/test datasets will be saved at data/METR-LA/{train,val,test}.npz.
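As a rough illustration of what this step produces, the sketch below slices a (time, sensors) array into input/target windows and splits them chronologically. The 12-step window lengths, the 70/10/20 split, the synthetic data, and the function names are assumptions made for this example; they are not taken from scripts.generate_training_data.

```python
import numpy as np


def make_windows(data, seq_len=12, horizon=12):
    """Slice a (num_steps, num_sensors) array into (x, y) window pairs."""
    num_samples = data.shape[0] - seq_len - horizon + 1
    if num_samples <= 0:
        raise ValueError("series is too short for the requested windows")
    # x holds the past seq_len steps, y the following horizon steps.
    x = np.stack([data[t:t + seq_len] for t in range(num_samples)])
    y = np.stack([data[t + seq_len:t + seq_len + horizon]
                  for t in range(num_samples)])
    return x, y


def chronological_split(x, y, train_frac=0.7, val_frac=0.1):
    """Split samples in time order so that test data follows training data."""
    n = x.shape[0]
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return {
        "train": (x[:n_train], y[:n_train]),
        "val": (x[n_train:n_train + n_val], y[n_train:n_train + n_val]),
        "test": (x[n_train + n_val:], y[n_train + n_val:]),
    }


# Synthetic stand-in for the real readings: 100 time steps, 3 sensors.
readings = np.arange(300, dtype=np.float32).reshape(100, 3)
x, y = make_windows(readings)
splits = chronological_split(x, y)
```

Each split could then be written to disk with a call such as np.savez_compressed, one file per split.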
python run_demo.py
The generated predictions of DCRNN are in data/results/dcrnn_predictions_[1-12].h5.
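The bracket notation in the path presumably denotes one file per forecasting horizon, 1 through 12. A minimal sketch, under that assumption, for listing the expected paths and checking which ones exist (the helper name prediction_paths is hypothetical):

```python
from pathlib import Path


def prediction_paths(results_dir="data/results", horizons=range(1, 13)):
    """Return the expected prediction file path for each horizon."""
    base = Path(results_dir)
    return [base / f"dcrnn_predictions_{h}.h5" for h in horizons]


paths = prediction_paths()
# Report how many of the expected files are actually present on disk.
missing = [p for p in paths if not p.exists()]
print(f"{len(paths) - len(missing)} of {len(paths)} prediction files found")
```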
python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml
Each epoch takes about 5 minutes on a single GTX 1080 Ti.
Because the current implementation relies on pre-calculated road network distances between sensors, it supports only sensor IDs in Los Angeles (see data/sensor_graph/sensor_info_201206.csv).
python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1 \
    --output_pkl_filename=data/sensor_graph/adj_mx.pkl
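For intuition about this step: the paper builds the sensor graph from pairwise road network distances using a thresholded Gaussian kernel. The sketch below shows one such construction with NumPy on a toy distance matrix. It assumes that --normalized_k is a cutoff applied to the kernel weights (entries below it are set to zero); that interpretation, the function name, the use of inf for unconnected pairs, and the way the kernel width is estimated are all assumptions, and scripts.gen_adj_mx may differ in its details.

```python
import numpy as np


def gaussian_kernel_adjacency(dist_mx, normalized_k=0.1):
    """Build a sparse weighted adjacency matrix from pairwise distances.

    dist_mx[i, j] is the distance from sensor i to sensor j, with np.inf
    marking pairs that have no known route (an assumption of this sketch).
    """
    dist_mx = np.asarray(dist_mx, dtype=np.float64)
    # Kernel width: standard deviation of the finite distances.
    finite = dist_mx[np.isfinite(dist_mx)]
    sigma = finite.std()
    if sigma == 0:
        raise ValueError("distances have zero spread; cannot set kernel width")
    adj_mx = np.exp(-np.square(dist_mx / sigma))  # inf distance -> weight 0
    adj_mx[adj_mx < normalized_k] = 0.0  # drop weak edges to sparsify
    return adj_mx


# Toy example: three sensors, distances in meters; 0 and 2 are unconnected.
dist = np.array([
    [0.0, 500.0, np.inf],
    [500.0, 0.0, 4000.0],
    [np.inf, 4000.0, 0.0],
])
adj = gaussian_kernel_adjacency(dist, normalized_k=0.1)
```

In this toy case the nearby pair keeps a strong edge, while the distant pair and the unconnected pair are both zeroed out, which is the sparsifying effect a larger threshold would amplify.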
More details are being added ...
If you find this repository useful in your research, please cite the following paper:
@inproceedings{li2018dcrnn_traffic,
title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
booktitle={International Conference on Learning Representations (ICLR '18)},
year={2018}
}