Skip to content

Commit

Permalink
Human Pose additions
Browse files Browse the repository at this point in the history
  • Loading branch information
KardelenCeren committed May 18, 2023
1 parent b320d91 commit d10cadd
Show file tree
Hide file tree
Showing 10 changed files with 687 additions and 59 deletions.
17 changes: 17 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

src/data/MNIST/raw/train-labels-idx1-ubyte
*.gz
src/data/MNIST/raw/train-images-idx3-ubyte
*.iml
*.xml
*.pyc
*.meta
src/data/cifar-10-batches-py/data_batch_1
src/data/cifar-10-batches-py/data_batch_3
src/data/cifar-10-batches-py/data_batch_2
src/data/cifar-10-batches-py/data_batch_5
*.html
src/data/cifar-10-batches-py/data_batch_4
src/data/cifar-10-batches-py/test_batch
src/data/MNIST/raw/t10k-images-idx3-ubyte
src/data/MNIST/raw/t10k-labels-idx1-ubyte
3 changes: 3 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 4 additions & 12 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,9 @@ Base dataset = CIFAR10
Incremental dataset = MNIST

### "Base":
Test with base classes:70.28%
Test with incr. classes:3.16%

### "Freeze":
Test with base classes:70.28%
Test with incr. classes:3.16%

### "AddRegularization":
Test with base classes:15.44%
Test with incr. classes:93.2%
Test with base classes:67.27%
Test with incr. classes:14.6%

### "LearningWithoutForgetting":
Test with base classes:23.31%
Test with incr. classes:94.85%
Test with base classes:33.12%
Test with incr. classes:94.14%
19 changes: 12 additions & 7 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.data.dataset import Subset

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.data.dataset import Subset


class IncrementalData:
def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", batch_size=128, incremental_class=0,
Expand Down Expand Up @@ -38,7 +45,7 @@ def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", bat
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

incr_train_transform = transforms.Compose( # transforms MNIST into CIFAR-sized images
incr_train_transform = transforms.Compose(
[transforms.Resize(32),
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),
Expand All @@ -63,6 +70,9 @@ def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", bat
download=True, transform=val_test_transform
)

targets_train = torch.tensor(self.train_dataset.targets)
targets_test = torch.tensor(self.test_dataset.targets)

if not choose_incr_from_dataset:
self.incr_train_dataset = self.incr_dataset(
root=data_root, train=True,
Expand Down Expand Up @@ -98,9 +108,6 @@ def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", bat
base_classes = [x for x in range(self.class_size) if x != self.incremental_class]
incremental_classes = [self.incremental_class] # more classes, more steps?

targets_train = torch.tensor(self.train_dataset.targets)
targets_test = torch.tensor(self.test_dataset.targets)

base_train_idx = 0
base_test_idx = 0
for base_class in base_classes:
Expand All @@ -114,7 +121,6 @@ def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", bat
incr_test_idx += targets_test == incr_class

# define dataloaders
# with only the base classes
self.base_train_loader = DataLoader(
Subset(self.train_dataset, np.where(base_train_idx == 1)[0]),
batch_size=batch_size, num_workers=2
Expand All @@ -135,9 +141,8 @@ def __init__(self, dataset_name="CIFAR10", incremental_dataset_name="MNIST", bat
Subset(self.test_dataset, np.where(incr_test_idx == 1)[0]),
batch_size=batch_size, num_workers=2
)

# with all classes available
self.all_test_loader = DataLoader(
self.test_dataset,
batch_size=batch_size, num_workers=2
)
)
Loading

0 comments on commit d10cadd

Please sign in to comment.