diff --git a/README.md b/README.md index 75bf62ff..9a678c52 100644 --- a/README.md +++ b/README.md @@ -40,24 +40,26 @@ Our model consists of three key components: Generator (G), Pooling Module (PM) a ## Setup -All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4. +All code was developed and tested on Ubuntu 22.04 with Python 3.10 and torch -You can setup a virtual environment to run the code like this: +You can setup a virtual conda environment to run the code like this: ```bash -python3 -m venv env # Create a virtual environment -source env/bin/activate # Activate virtual environment -pip install -r requirements.txt # Install dependencies -echo $PWD > env/lib/python3.5/site-packages/sgan.pth # Add current directory to python path +conda create -n test python=3.10 -y # Create a virtual environment +conda activate test # Activate virtual environment # Work for a while ... -deactivate # Exit virtual environment +conda deactivate # Exit virtual environment ``` -## Pretrained Models -You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models: +## clone repo and download files -- `sgan-models/_.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20V-20 in Table 1. -- `sgan-p-models/_.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20VP-20 in Table 1. +```bash +git clone https://github.com/bharath5673/Social-GAN.git +cd Social-GAN +pip install -r requirements.txt # Install dependencies +sh scripts/download_data.sh +sh scripts/download_models.sh +``` Please refer to [Model Zoo](MODEL_ZOO.md) for results. @@ -65,9 +67,14 @@ Please refer to [Model Zoo](MODEL_ZOO.md) for results. You can use the script `scripts/evaluate_model.py` to easily run any of the pretrained models on any of the datsets. For example you can replicate the Table 1 results for all datasets for SGAN-20V-20 like this: ```bash -python scripts/evaluate_model.py \ - --model_path models/sgan-models +cd scripts +sh run_eval.sh ``` ## Training new models + +```bash +cd scripts +sh run_traj.sh +``` Instructions for training new models can be [found here](TRAINING.md). diff --git a/requirements.txt b/requirements.txt index fb971f9a..1b1c82ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -attrdict==2.0.0 -numpy==1.14.5 -Pillow==6.2.0 -pkg-resources==0.0.0 -six==1.11.0 -torch==0.4.0 -torchvision==0.2.1 +scripts/attrdict-2.0.1-py2.py3-none-any.whl +numpy +Pillow +six +torch +torchvision diff --git a/scripts/attrdict-2.0.1-py2.py3-none-any.whl b/scripts/attrdict-2.0.1-py2.py3-none-any.whl new file mode 100644 index 00000000..4dd9de1d Binary files /dev/null and b/scripts/attrdict-2.0.1-py2.py3-none-any.whl differ diff --git a/scripts/evaluate_model.py b/scripts/evaluate_model.py index c791b906..a6a175a2 100644 --- a/scripts/evaluate_model.py +++ b/scripts/evaluate_model.py @@ -4,6 +4,9 @@ from attrdict import AttrDict +import sys +sys.path.append('../') + from sgan.data.loader import data_loader from sgan.models import TrajectoryGenerator from sgan.losses import displacement_error, final_displacement_error diff --git a/scripts/run_eval.sh b/scripts/run_eval.sh new file mode 100644 index 00000000..fa120628 --- /dev/null +++ b/scripts/run_eval.sh @@ -0,0 +1 @@ +python3 evaluate_model.py --model_path ../models/sgan-models \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 61b60e02..afa7da81 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -11,6 +11,9 @@ import torch.nn as nn import torch.optim as optim + +import sys +sys.path.append('../') from sgan.data.loader import data_loader from sgan.losses import gan_g_loss, gan_d_loss, l2_loss from sgan.losses import displacement_error, final_displacement_error @@ -109,7 +112,7 @@ def get_dtypes(args): def main(args): - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num + # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num train_path = get_dset_path(args.dataset_name, 'train') val_path = get_dset_path(args.dataset_name, 'val') diff --git a/sgan/__pycache__/losses.cpython-310.pyc b/sgan/__pycache__/losses.cpython-310.pyc new file mode 100644 index 00000000..6223bbc8 Binary files /dev/null and b/sgan/__pycache__/losses.cpython-310.pyc differ diff --git a/sgan/__pycache__/models.cpython-310.pyc b/sgan/__pycache__/models.cpython-310.pyc new file mode 100644 index 00000000..277a7a83 Binary files /dev/null and b/sgan/__pycache__/models.cpython-310.pyc differ diff --git a/sgan/__pycache__/utils.cpython-310.pyc b/sgan/__pycache__/utils.cpython-310.pyc new file mode 100644 index 00000000..2fe7c837 Binary files /dev/null and b/sgan/__pycache__/utils.cpython-310.pyc differ diff --git a/sgan/data/__pycache__/__init__.cpython-310.pyc b/sgan/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..0c902644 Binary files /dev/null and b/sgan/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/sgan/data/__pycache__/loader.cpython-310.pyc b/sgan/data/__pycache__/loader.cpython-310.pyc new file mode 100644 index 00000000..6a46c19b Binary files /dev/null and b/sgan/data/__pycache__/loader.cpython-310.pyc differ diff --git a/sgan/data/__pycache__/trajectories.cpython-310.pyc b/sgan/data/__pycache__/trajectories.cpython-310.pyc new file mode 100644 index 00000000..4f5e75d5 Binary files /dev/null and b/sgan/data/__pycache__/trajectories.cpython-310.pyc differ diff --git a/sgan/models.py b/sgan/models.py index ca0efd67..e7fd0e48 100644 --- a/sgan/models.py +++ b/sgan/models.py @@ -60,7 +60,7 @@ def forward(self, obs_traj): """ # Encode observed Trajectory batch = obs_traj.size(1) - obs_traj_embedding = self.spatial_embedding(obs_traj.view(-1, 2)) + obs_traj_embedding = self.spatial_embedding(obs_traj.reshape(-1, 2)) obs_traj_embedding = obs_traj_embedding.view( -1, batch, self.embedding_dim )