https://github.com/felixykliu/NAOMI
Tip revision: 7c61e07394029e502506d0dbc44fe1d622199cc9 authored by yuqirose on 25 October 2019, 22:25:24 UTC
Update README.md
# NAOMI
Code for the NeurIPS 2019 paper [NAOMI: Non-Autoregressive Multiresolution Sequence Imputation](https://arxiv.org/abs/1901.10946).
The code is written with PyTorch v0.4.1 (Python 3.6.5). The billiards data can be downloaded [here](https://drive.google.com/open?id=17Ov4nwshLbn13w8qLuH8LNvzXzMTcjJt); the basketball data is available from [STATS](https://www.stats.com/data-science/).
## To train the model:
First start a visdom server, then adjust the hyperparameters in `train_model.sh` and run the script.
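
Since training expects visdom to be running first, a quick pre-flight check can help. The sketch below is not part of this repository. It assumes visdom was started with `python -m visdom.server` and is listening on its default port 8097 on `localhost`. The function name `visdom_is_reachable` is hypothetical.

```python
import socket


def visdom_is_reachable(host="localhost", port=8097, timeout=1.0):
    """Return True if something accepts TCP connections at host:port.

    Assumption: 8097 is the port the visdom server listens on by default.
    This only checks that the port is open, not that the listener is visdom.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host could not be resolved.
        return False


if __name__ == "__main__":
    if visdom_is_reachable():
        print("visdom port is open; train_model.sh can be started")
    else:
        print("nothing is listening on port 8097; start visdom first")
```

Running this before `train_model.sh` gives an early warning if the server was not started.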
## Detailed explanations of hyperparameters:
- `--model`: "NAOMI" or "SingleRes"
- `--task`: "basketball" or "billiard"
- `--y_dim`: 10 for basketball and 2 for billiard
- `--rnn_dim` and `--n_layers`: GRU cell size and number of layers for all models, including the forward and backward RNNs
- `--dec1_dim` to `--dec16_dim`: for NAOMI, these values are the dimensions of the different decoders. For SingleRes, only `--dec1_dim` is used for the decoder.
- `--pre_start_lr`: initial learning rate for supervised pretraining
- `--pretrain`: number of supervised pretraining epochs
- `--highest`: largest step size for the NAOMI decoders; should be a power of two (2^n)
- `--discrim_rnn_dim` and `--discrim_layers`: discriminator RNN size and number of layers
- `--policy_learning_rate`: learning rate for the generator in adversarial training
- `--discrim_learning_rate`: learning rate for the discriminator in adversarial training
- `--pretrain_disc_iter`: number of iterations to pretrain the discriminator
- `--max_iter_num`: number of adversarial training iterations
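
The constraints in the list above can be checked before launching a run. The following is a minimal sketch, not the repository's actual argument parser: the flag names come from the list, but the types, defaults, and helper names (`is_power_of_two`, `decoder_step_sizes`, `build_parser`, `check_args`) are illustrative assumptions. It also assumes that the decoders `--dec1_dim` to `--dec16_dim` correspond to step sizes 1, 2, 4, 8, and 16.

```python
import argparse

# y_dim values stated in the hyperparameter list.
Y_DIM_BY_TASK = {"basketball": 10, "billiard": 2}


def is_power_of_two(n):
    """Return True if n is a positive power of two (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def decoder_step_sizes(highest):
    """List the power-of-two step sizes 1, 2, 4, ..., highest."""
    if not is_power_of_two(highest):
        raise ValueError(f"--highest should be 2^n, got {highest}")
    sizes = []
    step = 1
    while step <= highest:
        sizes.append(step)
        step *= 2
    return sizes


def build_parser():
    # Hypothetical parser: only a subset of the flags, with assumed defaults.
    parser = argparse.ArgumentParser(description="Illustrative NAOMI flags")
    parser.add_argument("--model", choices=["NAOMI", "SingleRes"], default="NAOMI")
    parser.add_argument("--task", choices=["basketball", "billiard"], default="basketball")
    parser.add_argument("--y_dim", type=int, default=10)
    parser.add_argument("--highest", type=int, default=8)
    return parser


def check_args(args):
    """Raise ValueError if the flags contradict the documented constraints."""
    expected = Y_DIM_BY_TASK[args.task]
    if args.y_dim != expected:
        raise ValueError(f"--y_dim should be {expected} for task {args.task}")
    # --highest is documented for the NAOMI decoders only.
    if args.model == "NAOMI" and not is_power_of_two(args.highest):
        raise ValueError(f"--highest should be 2^n, got {args.highest}")


if __name__ == "__main__":
    args = build_parser().parse_args(["--task", "billiard", "--y_dim", "2", "--highest", "16"])
    check_args(args)
    print(decoder_step_sizes(args.highest))
```

With `--highest 16`, the sketch lists five step sizes, one per decoder dimension flag from `--dec1_dim` to `--dec16_dim`.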
## Citation
If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following paper:
```
@inproceedings{liu2019naomi,
title={NAOMI: Non-Autoregressive Multiresolution Sequence Imputation},
author={Liu, Yukai and Yu, Rose and Zheng, Stephan and Zhan, Eric and Yue, Yisong},
booktitle={Advances in Neural Information Processing Systems (NeurIPS '19)},
year={2019}
}
```