diff --git a/gscan_metaseq2seq/util/dataset.py b/gscan_metaseq2seq/util/dataset.py index ee054ad..4e92721 100644 --- a/gscan_metaseq2seq/util/dataset.py +++ b/gscan_metaseq2seq/util/dataset.py @@ -1,7 +1,7 @@ import numpy as np import torch from torch.utils.data import Dataset, IterableDataset - +import random from .padding import pad_to, recursive_pad_array @@ -69,3 +69,34 @@ def __getitem__(self, i): self.indices = torch.randperm(len(self.dataset)) return self.dataset[self.indices[i]] + + +class AddRandomNoiseDataset(Dataset): + def __init__(self, dataset, ACTION2IDX, prob=0.01): + super().__init__() + self.dataset = dataset + self.ACTION2IDX = ACTION2IDX + self.prob = prob + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, i): + item = self.dataset[i] + + rand = random.uniform(0, 1) + + def add_random_noise(instruction): + if len(np.where(instruction == self.ACTION2IDX["walk"])[0]) > 0: + walk_index = random.choice(np.where(instruction == self.ACTION2IDX["walk"])[0]) + instruction[walk_index] = self.ACTION2IDX["turn left"] + + return instruction + + return instruction + + if rand <= self.prob: + return item[0], add_random_noise(item[1]), item[2] + + return item + diff --git a/gscan_metaseq2seq/util/logging.py b/gscan_metaseq2seq/util/logging.py index a9219d6..369a93c 100644 --- a/gscan_metaseq2seq/util/logging.py +++ b/gscan_metaseq2seq/util/logging.py @@ -2,7 +2,7 @@ import os from pytorch_lightning.loggers import CSVLogger -from pytorch_lightning.loggers.logger import rank_zero_experiment +from pytorch_lightning.loggers.base import rank_zero_experiment def get_most_recent_version(experiment_dir): diff --git a/requirements.txt b/requirements.txt index 428f491..f570d9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,69 @@ -matplotlib>=3.3.3 -seaborn>=0.9.0 -h5py>=3.7.0 -pandas>=1.1.5 -positional-encodings>=4.0 -pytorch-lightning>=1.1.7 -rliable>=1.0.0 -scikit-learn>=0.24.1 -scipy>=1.7.0 -torch>=1.8.0 -torchvision>=0.11.2 -tqdm>=4.57.0 -torchmetrics>=0.4.0 +absl-py==1.4.0 +aiohttp==3.8.3 +aiosignal==1.3.1 +arch==5.3.0 +async-timeout==4.0.2 +asynctest==0.13.0 +attrs==22.2.0 +cachetools==5.3.0 +certifi==2022.12.7 +charset-normalizer==2.1.1 +cloudpickle==2.2.1 +cycler==0.11.0 +fonttools==4.38.0 +frozenlist==1.3.3 +fsspec==2023.1.0 +future==0.18.3 +google-auth==2.16.0 +google-auth-oauthlib==0.4.6 +grpcio==1.51.1 +gym==0.26.2 +gym-notices==0.0.8 +h5py==3.8.0 +idna==3.4 +importlib-metadata==6.0.0 +joblib==1.2.0 +kiwisolver==1.4.4 +Markdown==3.4.1 +MarkupSafe==2.1.2 +matplotlib==3.5.3 +multidict==6.0.4 +numpy==1.21.6 +oauthlib==3.2.2 +packaging==23.0 +pandas==1.3.5 +patsy==0.5.3 +Pillow==9.4.0 +positional-encodings==6.0.1 +property-cached==1.6.4 +protobuf==3.20.3 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pyDeprecate==0.3.1 +pyparsing==3.0.9 +python-dateutil==2.8.2 +pytorch-lightning==1.5.10 +pytz==2022.7.1 +PyYAML==6.0 +requests==2.28.2 +requests-oauthlib==1.3.1 +rliable==1.0.8 +rsa==4.9 +scikit-learn==1.0.2 +scipy==1.7.3 +seaborn==0.12.2 +six==1.16.0 +statsmodels==0.13.5 +tensorboard==2.11.2 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +torch==1.7.1 +torchmetrics==0.10.3 +torchvision==0.8.2 +tqdm==4.64.1 +typing_extensions==4.4.0 +urllib3==1.26.14 +Werkzeug==2.2.2 +yarl==1.8.2 +zipp==3.12.0 diff --git a/scripts/train_transformer.py b/scripts/train_transformer.py index 8cc0161..d01fd1a 100644 --- a/scripts/train_transformer.py +++ b/scripts/train_transformer.py @@ -12,7 +12,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from gscan_metaseq2seq.models.embedding import BOWEmbedding -from gscan_metaseq2seq.util.dataset import PaddingDataset, ReshuffleOnIndexZeroDataset +from gscan_metaseq2seq.util.dataset import PaddingDataset, ReshuffleOnIndexZeroDataset, AddRandomNoiseDataset from gscan_metaseq2seq.util.load_data import load_data from gscan_metaseq2seq.util.logging import LoadableCSVLogger from gscan_metaseq2seq.util.scheduler import transformer_optimizer_config @@ -51,7 +51,6 @@ def __init__( nhead=nhead, dim_feedforward=embedding_dim * 4, dropout=dropout_p, - norm_first=norm_first, ), num_layers=nlayers, ) @@ -117,7 +116,6 @@ def __init__( dim_feedforward=hidden_size * 4, dropout=dropout_p, nhead=nhead, - norm_first=norm_first, ), num_layers=nlayers, ) @@ -419,6 +417,8 @@ def main(): args.train_demonstrations, args.valid_demonstrations_directory, args.dictionary ) + train_demonstrations = AddRandomNoiseDataset(train_demonstrations, ACTION2IDX) + IDX2WORD = {i: w for w, i in WORD2IDX.items()} IDX2ACTION = {i: w for w, i in ACTION2IDX.items()}