From 228c8f37b90cdc5d18dee8b757bcefa47d9ab37d Mon Sep 17 00:00:00 2001 From: nehaprakriya <81734507+nehaprakriya@users.noreply.github.com> Date: Sun, 27 Nov 2022 15:16:59 -0800 Subject: [PATCH] GPU Code --- resnet.py | 4 +- resnet_quant.py | 4 +- resnet_quant_gpu.py | 470 +++++++++++++++++++++++++++++++++ run_resnet20_partition.sh | 33 +++ train_resnet.py | 534 ++++++++++++++++++++++++-------------- util.py | 445 ++++++++++++------------------- 6 files changed, 1018 insertions(+), 472 deletions(-) create mode 100644 resnet_quant_gpu.py create mode 100644 run_resnet20_partition.sh diff --git a/resnet.py b/resnet.py index c8b656d..0292226 100644 --- a/resnet.py +++ b/resnet.py @@ -117,8 +117,8 @@ def forward(self, x): return out -def resnet20(): - return ResNet(BasicBlock, [3, 3, 3]) +def resnet20(num_classes=10): + return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes) def resnet32(): diff --git a/resnet_quant.py b/resnet_quant.py index baa2914..ad9b2cf 100644 --- a/resnet_quant.py +++ b/resnet_quant.py @@ -124,6 +124,6 @@ def forward(self, x): out = self.dequant(out) return out -def resnet20(): - return ResNet(BasicBlock, [3, 3, 3]) +def resnet20(num_classes=10): + return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes) \ No newline at end of file diff --git a/resnet_quant_gpu.py b/resnet_quant_gpu.py new file mode 100644 index 0000000..973d246 --- /dev/null +++ b/resnet_quant_gpu.py @@ -0,0 +1,470 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import torch +from torch import Tensor +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional +from pytorch_quantization import quant_modules +from pytorch_quantization import nn as quant_nn + +__all__ = [ + 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2' +] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1, + quantize: bool = False) -> nn.Conv2d: + """3x3 convolution with padding""" + if quantize: + return quant_nn.QuantConv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + else: + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, quantize: bool = False) -> nn.Conv2d: + """1x1 convolution""" + if quantize: + return quant_nn.QuantConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + else: + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + shortcut: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + quantize: bool = False) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.shortcut layers shortcut the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride, quantize=quantize) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, quantize=quantize) + self.bn2 = norm_layer(planes) + self.shortcut = shortcut + self.stride = stride + self._quantize = quantize + if self._quantize: + self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input) + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.shortcut is not None: + identity = self.shortcut(x) + + if self._quantize: + out += self.residual_quantizer(identity) + else: + out += identity + out = self.relu(out) + + return out 
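+
+# Calibration sketch (illustrative, not exercised by this patch): when quantize=True,
+# the TensorQuantizer modules created above need an `amax` range before quantized
+# inference gives meaningful results. Assuming a model built from these blocks and a
+# representative `calib_loader` (both hypothetical names), the usual
+# pytorch_quantization workflow is roughly:
+#
+#   for _, module in model.named_modules():
+#       if isinstance(module, quant_nn.TensorQuantizer):
+#           module.disable_quant(); module.enable_calib()    # collect statistics
+#   with torch.no_grad():
+#       for images, _ in calib_loader:
+#           model(images)
+#   for _, module in model.named_modules():
+#       if isinstance(module, quant_nn.TensorQuantizer):
+#           module.load_calib_amax(); module.enable_quant(); module.disable_calib()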
+ + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + shortcut: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + quantize: bool = False) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.shortcut layers shortcut the input when stride != 1 + self.conv1 = conv1x1(inplanes, width, quantize=quantize) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation, quantize=quantize) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion, quantize=quantize) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.shortcut = shortcut + self.stride = stride + self._quantize = quantize + if self._quantize: + self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input) + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.shortcut is not None: + identity = self.shortcut(x) + + if self._quantize: + out += self.residual_quantizer(identity) + else: + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + quantize: bool = False, + num_classes: int = 200, + zero_init_residual: bool = False, + cifar: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + super(ResNet, self).__init__() + self._quantize = quantize + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + if cifar: + kernel = 3 + stride = 1 + padding = 1 + else: + kernel = 7 + stride = 2 + padding = 3 + + if quantize: + self.conv1 = quant_nn.QuantConv2d(3, + self.inplanes, + kernel_size=kernel, + stride=stride, + padding=padding, + bias=False) + else: + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=kernel, stride=stride, padding=padding, bias=False) + + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], quantize=quantize) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + quantize=quantize) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + quantize=quantize) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + quantize=quantize) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if quantize: + self.linear = quant_nn.QuantLinear(512 * block.expansion, num_classes) + else: + self.linear = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + quantize: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + shortcut = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + shortcut = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride, quantize=quantize) + # norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, shortcut, self.groups, self.base_width, previous_dilation, + norm_layer, self._quantize)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + quantize=quantize)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, + quantize: bool, **kwargs: Any) -> ResNet: + model = ResNet(block, layers, quantize, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, quantize, **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block.
The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) \ No newline at end of file diff --git a/run_resnet20_partition.sh b/run_resnet20_partition.sh new file mode 100644 index 0000000..2ac1600 --- /dev/null +++ b/run_resnet20_partition.sh @@ -0,0 +1,33 @@ +# python train_resnet.py -s 0.1 -w -b 512 -g --smtk 0 --gpu 7 --start-subset 0 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.2 -w -b 512 -g --smtk 0 --gpu 6 --start-subset 0 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.3 -w -b 512 -g --smtk 0 --gpu 5 --start-subset 0 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.4 -w -b 512 -g --smtk 0 --gpu 4 --start-subset 0 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.5 -w -b 512 -g --smtk 0 --gpu 3 --start-subset 0 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.5 -w -b 512 -g --smtk 0 --gpu 2 --start-subset 30 --subset_schedule step --partition & +# python train_resnet.py -s 0.1 -w -b 512 -g --smtk 0 --gpu 1 --start-subset 50 --subset_schedule cnt --partition & +# python train_resnet.py -s 0.02 -w -b 512 -g --smtk 0 --gpu 3 --start-subset 60 --subset_schedule cnt --partition & + +# python train_resnet.py -s 0.5 -w -b 512 -g --smtk 0 --gpu 6 --start-subset 30 --subset_schedule step & +# python train_resnet.py -s 0.1 -w -b 512 -g --smtk 0 --gpu 7 --start-subset 50 --subset_schedule cnt & +# python train_resnet.py -s 0.02 -w -b 512 -g --smtk 0 --gpu 3 --start-subset 60 --subset_schedule cnt & + +# python train_resnet.py -s 0.5 -w -b 512 -g --smtk 0 --gpu 0 --start-subset 30 --subset_schedule step --dataset cifar100 & +# python train_resnet.py -s 0.1 -w -b 512 -g --smtk 0 --gpu 1 --start-subset 50 --subset_schedule cnt --dataset cifar100 & +# python train_resnet.py -s 0.02 -w -b 512 -g --smtk 0 --gpu 2 --start-subset 60 --subset_schedule cnt --dataset cifar100 & + +# python train_resnet.py -s 0.05 -w -b 512 -g --smtk 0 --gpu 0 --start-subset 200 --subset_schedule reduce --dataset cifar100 & +# python train_resnet.py -s 0.02 -w -b 512 -g --smtk 0 --gpu 1 --start-subset 200 --subset_schedule reduce --dataset cifar100 & +# python train_resnet.py -s 0.01 -w -b 512 -g --smtk 0 --gpu 2 --start-subset 200 --subset_schedule reduce --dataset cifar100 & + +# python train_resnet.py -s 0.05 -w -b 512 -g --smtk 0 --gpu 0 
--start-subset 200 --subset_schedule reduce --dataset cifar100 --lr_schedule reduce & +# python train_resnet.py -s 0.02 -w -b 512 -g --smtk 0 --gpu 1 --start-subset 200 --subset_schedule reduce --dataset cifar100 --lr_schedule reduce & +# python train_resnet.py -s 0.01 -w -b 512 -g --smtk 0 --gpu 2 --start-subset 200 --subset_schedule reduce --dataset cifar100 --lr_schedule reduce & + +python train_resnet.py -s 0.1 -w -b 512 --smtk 0 --gpu 7 --start-subset 0 --subset_schedule cnt --partition & +python train_resnet.py -s 0.2 -w -b 512 --smtk 0 --gpu 6 --start-subset 0 --subset_schedule cnt --partition & +python train_resnet.py -s 0.3 -w -b 512 --smtk 0 --gpu 5 --start-subset 0 --subset_schedule cnt --partition & +python train_resnet.py -s 0.4 -w -b 512 --smtk 0 --gpu 4 --start-subset 0 --subset_schedule cnt --partition & +python train_resnet.py -s 0.5 -w -b 512 --smtk 0 --gpu 3 --start-subset 0 --subset_schedule cnt --partition & +python train_resnet.py -s 0.5 -w -b 512 --smtk 0 --gpu 2 --start-subset 30 --subset_schedule step --partition & +python train_resnet.py -s 0.1 -w -b 512 --smtk 0 --gpu 1 --start-subset 50 --subset_schedule cnt --partition & +python train_resnet.py -s 0.02 -w -b 512 --smtk 0 --gpu 0 --start-subset 60 --subset_schedule cnt --partition & \ No newline at end of file
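
# Usage note (editorial sketch): each uncommented line above launches one background
# training run pinned to a single GPU. Per train_resnet.py's argument parser: -s is
# the subset fraction, -b the mini-batch size, --gpu the CUDA device id, --smtk the
# submodular-solver selector, --start-subset the epoch at which subset selection
# begins, --subset_schedule the subset-size schedule (cnt/step/reduce), and
# --partition splits the selected subset across mini-batches. The -w flag is assumed
# here to be the warm-start switch (its definition lies outside this diff). A single
# foreground run would look like:
#
#   python train_resnet.py -s 0.1 -w -b 512 --smtk 0 --gpu 0 \
#       --start-subset 0 --subset_schedule cnt --partition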
diff --git a/train_resnet.py b/train_resnet.py index 8119fc8..e4c516c 100644 --- a/train_resnet.py +++ b/train_resnet.py @@ -2,6 +2,7 @@ import os import time import numpy as np +import math import torch import torch.nn as nn @@ -9,9 +10,7 @@ import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data -import torchvision.transforms as transforms -import torchvision.datasets as datasets -# import resnet_icml as resnet +import matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader import util @@ -20,15 +19,22 @@ from resnet import resnet20 as target_resnet20 from resnet_quant import resnet20 as quant_resnet20 +from resnet_quant_gpu import resnet18, resnet50 + +import datetime +from torch.utils.tensorboard import SummaryWriter # ignore all future warnings simplefilter(action='ignore', category=FutureWarning) np.seterr(all='ignore') parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch') -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20', #'resnet56', # +parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', #'resnet56', # help='model architecture: ' + - ' (default: resnet32)') + ' (default: resnet18)') +parser.add_argument('--data_dir', default='~/data') +parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'cinic10', 'svhn', 'tinyimagenet', 'imagenet'], + help='dataset: ' + ' (default: cifar10)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=200, type=int, metavar='N', @@ -41,7 +47,7 @@ metavar='LR', help='initial learning rate') parser.add_argument('--momentum', '-m', type=float, metavar='M', default=0.9, help='momentum') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, +parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('--print-freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)') @@ -55,19 +61,20 @@ help='use half-precision(16-bit) ') parser.add_argument('--save-dir', dest='save_dir', help='The directory used to save the trained models', - default='save_temp', type=str) + default='outputs', type=str) parser.add_argument('--save-every', dest='save_every', help='Saves checkpoints at every specified number of epochs', type=int, default=300) # default=10) -parser.add_argument('--gpu', default='7', type=str, help='The GPU to be used') +parser.add_argument('--gpu', default='0', type=str, help='The GPU to be used') parser.add_argument('--greedy', '-g', dest='greedy', action='store_true', default=False, help='greedy ordering') +parser.add_argument('--uniform_weight', action='store_true', default=False, help='no sample reweighting') parser.add_argument('--subset_size', '-s', dest='subset_size', type=float, help='size of the subset', default=1.0) parser.add_argument('--random_subset_size', '-rs', type=float, help='size of the subset', default=1.0) parser.add_argument('--st_grd', '-stg', type=float, help='stochastic greedy', default=0) parser.add_argument('--smtk', type=int, help='smtk', default=1) parser.add_argument('--ig', type=str, help='ig method', default='sgd', choices=['sgd', 'adam', 'adagrad']) parser.add_argument('--lr_schedule', '-lrs', type=str, help='learning rate schedule', default='mile', - choices=['mile', 'exp', 'cnt', 'step', 'cosine']) + choices=['mile', 'exp', 'cnt', 'step', 'cosine', 'reduce']) parser.add_argument('--gamma', type=float, default=-1, help='learning rate decay parameter') parser.add_argument('--lag', type=int, help='update lags', default=1) parser.add_argument('--runs', type=int, help='num runs', default=1) @@ -75,27 +82,58 @@ parser.add_argument('--cluster_features', '-cf', dest='cluster_features', action='store_true', help='cluster_features') parser.add_argument('--cluster_all', '-ca', dest='cluster_all', action='store_true', help='cluster_all') parser.add_argument('--start-subset', '-st', default=0, type=int, metavar='N', help='start subset selection') +parser.add_argument('--drop_learned', action='store_true', help='drop learned examples') +parser.add_argument('--watch_interval', default=5, type=int, help='number of recent epochs used to decide whether an example has been learned') +parser.add_argument('--drop_interval', default=20, type=int, help='how often (in epochs) to drop examples that have been learned') +parser.add_argument('--drop_thresh', default=2, type=float, help='loss threshold') parser.add_argument('--save_subset', dest='save_subset', action='store_true', help='save_subset') +parser.add_argument('--save_stats', action='store_true', help='save forgetting scores and losses') +parser.add_argument('--partition', dest='partition', action='store_true', help='partition the dataset by the number of mini-batches') +parser.add_argument('--subset_schedule', type=str, help='subset size schedule', default='cnt', + choices=['cnt', 'step', 'reduce']) -TRAIN_NUM = 50000 -CLASS_NUM = 10 - -print("hello") -def main(subset_size=.1, greedy=0): - print("hello") - global args, best_prec1 +def main(args, subset_size=.1, greedy=0): + global best_prec1 args = parser.parse_args() - # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu print(f'--------- subset_size: {subset_size}, method: {args.ig}, moment: {args.momentum}, ' f'lr_schedule: {args.lr_schedule}, greedy: {greedy}, stoch: {args.st_grd}, rs: {args.random_subset_size} ---------------') - print(args.lr_schedule) + + grd = 'grd_w' if args.greedy else f'rand_rsize_{args.random_subset_size}' + grd += f'_st_{args.st_grd}' if args.st_grd > 0 else '' + grd += f'_warm' if args.warm_start > 0 else '' + grd += f'_feature' if
args.cluster_features else '' + grd += f'_ca' if args.cluster_all else '' + grd += f'_uniform' if args.uniform_weight else '' + grd += f'_partition' if args.partition else '' + grd += f'_dropbelow{args.drop_thresh}_every{args.drop_interval}epochs_watch{args.watch_interval}epochs' if args.drop_learned else '' + folder = f'./{args.save_dir}/{args.dataset}' + save_path = f'{folder}/{args.ig}_moment_{args.momentum}_{args.arch}_{args.subset_size}_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}_{args.subset_schedule}size' + today = datetime.datetime.now() + timestamp = today.strftime("%m-%d-%Y-%H:%M:%S") + args.save_dir = f'{save_path}_{timestamp}' + # Check the save_dir exists or not - if not os.path.exists(args.save_dir): - os.makedirs(args.save_dir) + os.makedirs(args.save_dir) + os.makedirs(os.path.join(args.save_dir, 'images')) + args.writer = SummaryWriter(args.save_dir) - model = target_resnet20() + if args.dataset == 'cifar100': + args.class_num = 100 + elif args.dataset == 'imagenet': + args.class_num = 1000 + elif args.dataset == 'tinyimagenet': + args.class_num = 200 + else: + args.class_num = 10 + + if args.arch == 'resnet20': + model = target_resnet20(num_classes=args.class_num) + elif args.arch == 'resnet50': + model = torch.nn.DataParallel(resnet50(num_classes=args.class_num, cifar=True)) + else: + model = resnet18(num_classes=args.class_num, cifar=True) device='cuda' model.to(device) @@ -113,62 +151,38 @@ def main(subset_size=.1, greedy=0): print("=> no checkpoint found at '{}'".format(args.resume)) cudnn.benchmark = True - - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_loader__ = torch.utils.data.DataLoader( - datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32, 4), - transforms.ToTensor(), - normalize, - ]), download=True), - batch_size=args.batch_size, shuffle=True, - num_workers=args.workers, pin_memory=True) - class IndexedDataset(Dataset): - def __init__(self): - self.cifar10 = datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32, 4), - transforms.ToTensor(), - normalize, - ]), download=True) + def __init__(self, args): + self.dataset = util.get_dataset(args) def __getitem__(self, index): - data, target = self.cifar10[index] - # Your transformations here (or set it in CIFAR10) + data, target = self.dataset[index] return data, target, index def __len__(self): - return len(self.cifar10) + return len(self.dataset) - indexed_dataset = IndexedDataset() + indexed_dataset = IndexedDataset(args) indexed_loader = DataLoader( indexed_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) val_loader = torch.utils.data.DataLoader( - datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - normalize, - ])), + util.get_dataset(args, train=False), batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) train_val_loader = torch.utils.data.DataLoader( - datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ - transforms.ToTensor(), - normalize, - ])), + util.get_dataset(args, train=True, train_transform=False), batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) train_criterion = nn.CrossEntropyLoss(reduction='none').cuda() # (Note) val_criterion = nn.CrossEntropyLoss().cuda() + 
args.train_num = len(indexed_dataset) + if args.half: model.half() train_criterion.half() @@ -185,7 +199,7 @@ def __len__(self): times_selected = np.zeros((runs, len(indexed_loader.dataset))) if args.save_subset: - B = int(args.subset_size * TRAIN_NUM) + B = int(args.subset_size * args.train_num) selected_ndx = np.zeros((runs, epochs, B)) selected_wgt = np.zeros((runs, epochs, B)) @@ -198,22 +212,34 @@ def __len__(self): print(f'lr schedule: {args.lr_schedule}, epochs: {args.epochs}') print(f'lr: {lr}, b: {b}') + order = np.arange(0, args.train_num) + targets = np.array(indexed_dataset.dataset.targets) for run in range(runs): best_prec1_all, best_loss_all, prec1 = 0, 1e10, 0 + forgets = np.zeros(args.train_num) + learned = np.zeros(args.train_num) + watch = np.zeros((args.watch_interval, args.train_num)) if subset_size < 1: # initialize a random subset - B = int(args.random_subset_size * TRAIN_NUM) - order = np.arange(0, TRAIN_NUM) + B = int(args.random_subset_size * args.train_num) + order = np.arange(0, args.train_num) np.random.shuffle(order) order = order[:B] - print(f'Random init subset size: {args.random_subset_size}% = {B}') + print(f'Random init subset size: {args.random_subset_size*100}% = {B}') - model=target_resnet20() + if args.arch == 'resnet20': + model = target_resnet20(num_classes=args.class_num) + elif args.arch == 'resnet50': + model = torch.nn.DataParallel(resnet50(num_classes=args.class_num, cifar=False)) + else: + if args.dataset == 'tinyimagenet': + model = resnet18(num_classes=args.class_num, cifar=False) + else: + model = resnet18(num_classes=args.class_num, cifar=True) model.cuda() - q_model = quant_resnet20() - q_model.to('cpu') + best_prec1, best_loss = 0, 1e10 if args.ig == 'adam': print('using adam') @@ -238,6 +264,8 @@ def __len__(self): elif args.lr_schedule == 'cosine': # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) + elif args.lr_schedule == 'reduce': + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True) else: # constant lr lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs, gamma=1.0) @@ -260,101 +288,211 @@ def __len__(self): for epoch in range(args.start_epoch, args.epochs): + curr_lr = optimizer.param_groups[0]['lr'] + # train for one epoch - print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) + print('current lr {:.5e}'.format(curr_lr)) + + corrects = np.zeros(args.train_num) + losses = np.zeros(args.train_num) + + if args.drop_learned and (epoch > 0): + if (epoch % args.drop_interval == 0) and (len(order) > 1000): + order_ = np.where(np.sum(watch>args.drop_thresh, axis=0)>0)[0] + if len(order_) > 1000: + order = order_ + subset_size = 1 / args.watch_interval + elif epoch < args.start_subset: + subset_size = 1 + elif args.subset_schedule == 'step': + if epoch < 75: + subset_size = args.subset_size + elif epoch == 75: + subset_size = 0.1 + elif epoch == 100: + subset_size = 0.01 + else: + subset_size = args.subset_size + + B = int(subset_size * len(order)) + print(f'Training size at epoch {epoch}: {subset_size*100}% = {B}') + + if args.partition and (subset_size < 1) and (epoch >= args.start_subset): + # random partition the dataset + partition = int(math.ceil(B / args.batch_size)) + B = min(args.batch_size, int(subset_size * len(order))) + else: + partition = 1 ############################# weight = None - if subset_size >= 1 or epoch < 
args.start_subset: - print('Training on all the data') - train_loader = indexed_loader - - elif subset_size < 1 and \ - (epoch % (args.lag + args.start_subset) == 0 or epoch == args.start_subset): - B = int(subset_size * TRAIN_NUM) - if greedy == 0: - # order = np.arange(0, TRAIN_NUM) - np.random.shuffle(order) - subset = order[:B] - weights = np.zeros(len(indexed_loader.dataset)) - weights[subset] = np.ones(B) - print(f'Selecting {B} element from the pre-selected random subset of size: {len(subset)}') - else: # Note: warm start - if args.cluster_features: - print(f'Selecting {B} elements greedily from features') - data = datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32, 4), - transforms.ToTensor(), - normalize, - ]), download=True) - preds, labels = np.reshape(data.data, (len(data.targets), -1)), data.targets - else: - print(f'Selecting {B} elements greedily from predictions') - torch.save(model.state_dict(), 'cifar10_target.pt') - print('Size (MB):', os.path.getsize("cifar10_target.pt")/1e6) - loaded_dict_enc = torch.load('cifar10_target.pt', map_location='cpu') - q_model = quant_resnet20() - q_model.to('cpu') - q_model.load_state_dict(loaded_dict_enc) + for i in range(partition): + print(f'Training on partition {i+1}/{partition}') + + if subset_size >= 1 or epoch < args.start_subset: + print('Training on all the data') + train_loader = indexed_loader + times_selected[run][order] += 1 + + if args.save_stats or args.drop_learned: + preds, labels = predictions(args, indexed_loader, model) + corrects = np.equal(np.argmax(preds, axis=1), labels) + losses = train_criterion(torch.from_numpy(preds), torch.from_numpy(labels).long()).numpy() + else: + if (epoch % args.lag == 0): + q_model_path = os.path.join(args.save_dir, f'{args.dataset}_target.pt') + if args.arch == 'resnet50': + torch.save(model.module.state_dict(), q_model_path) + else: + torch.save(model.state_dict(), q_model_path) + print('Size (MB):', os.path.getsize(q_model_path)/1e6) + loaded_dict_enc = torch.load(q_model_path, map_location='cpu') + if args.arch == 'resnet20': + q_model = quant_resnet20(num_classes=args.class_num) + q_model.load_state_dict(loaded_dict_enc) + q_model.to('cpu') + q_model.qconfig = torch.quantization.get_default_qconfig('fbgemm') + torch.quantization.prepare(q_model, inplace=True) + q_model.eval() + torch.quantization.convert(q_model, inplace=True) + else: + if args.arch == 'resnet50': + q_model = resnet50(num_classes=args.class_num, cifar=False, quantize=True) + else: + if args.dataset == 'tinyimagenet': + q_model = resnet18(num_classes=args.class_num, cifar=False, quantize=True) + else: + q_model = resnet18(num_classes=args.class_num, cifar=True, quantize=True) + q_model.load_state_dict(loaded_dict_enc) + q_model.cuda() + q_model.eval() print("loaded state dict") - q_model.qconfig = torch.quantization.get_default_qconfig('fbgemm') - torch.quantization.prepare(q_model, inplace=True) - q_model.eval() - torch.quantization.convert(q_model, inplace=True) - torch.save(q_model.state_dict(), 'cifar10_target.pt') - print('Size (MB):', os.path.getsize("cifar10_target.pt")/1e6) - preds, labels = quantization_predictions(indexed_loader, q_model) - preds -= np.eye(CLASS_NUM)[labels] - if epoch<=60: - B = 50000 - # elif 302*int(1./subset_size))] = 2*int(1./subset_size) + fig = plt.figure() + plt.hist(plt_weights, bins=np.arange(np.amax(plt_weights)), edgecolor='black') + args.writer.add_figure('cluster_weights', fig, epoch) + 
plt.savefig(os.path.join(args.save_dir, f'images/weights_epoch{epoch}.png')) + plt.close() + + weights = np.zeros(args.train_num) + subset_weight = subset_weight / np.sum(subset_weight) * len(subset_weight) + if args.save_subset: + selected_ndx[run, epoch], selected_wgt[run, epoch] = subset, subset_weight + + weights[subset] = subset_weight + weight = torch.from_numpy(weights).float().cuda() + + print(f'FL time: {ordering_time:.3f}, Sim time: {similarity_time:.3f}') + grd_time[run, epoch], sim_time[run, epoch] = ordering_time, similarity_time + + times_selected[run][subset] += 1 + print(f'{np.sum(times_selected[run] == 0) / len(times_selected[run]) * 100:.3f} % not selected yet') + not_selected[run, epoch] = np.sum(times_selected[run] == 0) / len(times_selected[run]) * 100 + indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset) + if args.partition: + train_loader = DataLoader( + indexed_subset, + batch_size=len(subset), shuffle=True, + num_workers=args.workers, pin_memory=True) + else: + train_loader = DataLoader( + indexed_subset, + batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True) else: - B = 1000 - print(B) - fl_labels = np.zeros(np.shape(labels), dtype=int) if args.cluster_all else labels - subset, subset_weight, _, _, ordering_time, similarity_time = util.get_orders_and_weights( - B, preds, 'euclidean', smtk=args.smtk, no=0, y=fl_labels, stoch_greedy=args.st_grd, - equal_num=True) - - weights = np.zeros(len(indexed_loader.dataset)) - weights[subset] = np.ones(len(subset)) - subset_weight = subset_weight / np.sum(subset_weight) * len(subset_weight) - if args.save_subset: - selected_ndx[run, epoch], selected_wgt[run, epoch] = subset, subset_weight - - weights[subset] = subset_weight - weight = torch.from_numpy(weights).float().cuda() - weight = torch.tensor(weights).cuda() - np.random.shuffle(subset) - print(f'FL time: {ordering_time:.3f}, Sim time: {similarity_time:.3f}') - grd_time[run, epoch], sim_time[run, epoch] = ordering_time, similarity_time - - times_selected[run][subset] += 1 - print(f'{np.sum(times_selected[run] == 0) / len(times_selected[run]) * 100:.3f} % not selected yet') - not_selected[run, epoch] = np.sum(times_selected[run] == 0) / len(times_selected[run]) * 100 - indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset) - train_loader = DataLoader( - indexed_subset, - batch_size=args.batch_size, shuffle=True, - num_workers=args.workers, pin_memory=True) - else: - print('Using the previous subset') - not_selected[run, epoch] = not_selected[run, epoch - 1] - print(f'{not_selected[run, epoch]:.3f} % not selected yet') - ############################# + print('Using the previous subset') + not_selected[run, epoch] = not_selected[run, epoch - 1] + times_selected[run][subset] += 1 + print(f'{not_selected[run, epoch]:.3f} % not selected yet') + ############################# + + prec1, loss, data_time_batch, train_time_batch = train( + train_loader, model, epoch, train_criterion, optimizer, weight) - data_time[run, epoch], train_time[run, epoch] = train( - train_loader, model, train_criterion, optimizer, epoch, weight) + data_time[run, epoch] += data_time_batch + train_time[run, epoch] += train_time_batch - lr_scheduler_f.step() + args.writer.add_scalar('train/3.train_size', int(len(order)*subset_size), epoch) + args.writer.add_scalar('train/4.train_frac', np.sum(times_selected[run])/args.train_num/(epoch+1), epoch) + + # evaluate on validation set + prec1, loss = validate(train_val_loader, model, 
val_criterion) + args.writer.add_scalar('train/1.train_loss', loss, epoch) + args.writer.add_scalar('train/2.train_acc', prec1, epoch) # evaluate on validation set prec1, loss = validate(val_loader, model, val_criterion) + if args.lr_schedule == 'reduce': + lr_scheduler_f.step(loss) + else: + lr_scheduler_f.step() + + args.writer.add_scalar('val/1.val_loss', loss, epoch) + args.writer.add_scalar('val/2.val_acc', prec1, epoch) + # remember best prec@1 and save checkpoint is_best = prec1 > best_prec1 # best_run = run if is_best else best_run @@ -364,12 +502,38 @@ def __len__(self): best_prec1_all = best_prec1 test_acc[run, epoch], test_loss[run, epoch] = prec1, loss + args.writer.add_scalar('test/1.test_loss', loss, epoch) + args.writer.add_scalar('test/2.test_acc', prec1, epoch) + ta, tl = validate(train_val_loader, model, val_criterion) # best_run_loss = run if tl < best_loss else best_run_loss best_loss = min(tl, best_loss) best_loss_all = min(best_loss_all, best_loss) train_acc[run, epoch], train_loss[run, epoch] = ta, tl + if args.save_stats or args.drop_learned: + watch[epoch%args.watch_interval] = losses + if epoch > 0: + forgets[learned > corrects] += 1 + learned = corrects + + if (((epoch + 1) % 5) == 0) and args.save_stats: + np.save(file=os.path.join(args.save_dir, f'forget_epoch{epoch}.npy'), arr=forgets) + fig = plt.figure() + plt.hist(forgets, bins=np.arange(np.amax(forgets)+1), edgecolor='black') + args.writer.add_figure('forgetting_scores', fig, epoch) + plt.hist(forgets, bins=np.arange(np.amax(forgets)), edgecolor='black') + plt.savefig(os.path.join(args.save_dir, f'images/forgetting_scores_epoch{epoch}.png')) + plt.close() + + np.save(file=os.path.join(args.save_dir, f'loss_epoch{epoch}.npy'), arr=losses) + fig = plt.figure() + plt.hist(losses, edgecolor='black') + args.writer.add_figure('example_losses', fig, epoch) + plt.hist(losses, edgecolor='black') + plt.savefig(os.path.join(args.save_dir, f'images/example_losses_epoch{epoch}.png')) + plt.close() + if epoch > 0 and epoch % args.save_every == 0: save_checkpoint({ 'epoch': epoch + 1, @@ -389,31 +553,22 @@ def __len__(self): f'best_g: {best_gs[run]:.3f}, best_b: {best_bs[run]:.3f}, ' f'not selected %:{not_selected[run][epoch]}') - grd = 'grd_w' if args.greedy else f'rand_rsize_{args.random_subset_size}' - grd += f'_st_{args.st_grd}' if args.st_grd > 0 else '' - grd += f'_warm' if args.warm_start > 0 else '' - grd += f'_feature' if args.cluster_features else '' - grd += f'_ca' if args.cluster_all else '' - folder = f'/home/nehaprakriya/quant/resnet20/' + save_path = f'{args.save_dir}/results' if args.save_subset: print( - f'Saving the results to {folder}_{args.ig}_moment_{args.momentum}_{args.arch}_{subset_size}' - f'_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}_subset') + f'Saving the results to {save_path}_subset') - np.savez(f'{folder}_{args.ig}_moment_{args.momentum}_{args.arch}_{subset_size}' - f'_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}_subset', + np.savez(f'{save_path}_subset', train_loss=train_loss, test_acc=test_acc, train_acc=train_acc, test_loss=test_loss, data_time=data_time, train_time=train_time, grd_time=grd_time, sim_time=sim_time, best_g=best_gs, best_b=best_bs, not_selected=not_selected, times_selected=times_selected, subset=selected_ndx, weights=selected_wgt) else: print( - f'Saving the results to {folder}_{args.ig}_moment_{args.momentum}_{args.arch}_{subset_size}' - f'_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}') + f'Saving the 
results to {save_path}') - np.savez(f'{folder}_{args.ig}_moment_{args.momentum}_{args.arch}_{subset_size}' - f'_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}', + np.savez(save_path, train_loss=train_loss, test_acc=test_acc, train_acc=train_acc, test_loss=test_loss, data_time=data_time, train_time=train_time, grd_time=grd_time, sim_time=sim_time, best_g=best_gs, best_b=best_bs, not_selected=not_selected, @@ -425,12 +580,12 @@ def __len__(self): -def train(train_loader, model, criterion, optimizer, epoch, weight=None): +def train(train_loader, model, epoch, criterion, optimizer, weight=None): """ Run one train epoch """ if weight is None: - weight = torch.ones(TRAIN_NUM).cuda() + weight = torch.ones(len(train_loader.dataset)).cuda() batch_time = AverageMeter() data_time = AverageMeter() @@ -455,7 +610,9 @@ def train(train_loader, model, criterion, optimizer, epoch, weight=None): # compute output output = model(input_var) loss = criterion(output, target_var) - loss = (loss).mean() # (Note) + # print(weight[idx.long()]) + # loss = loss * weight[idx.long()] + loss = loss.mean() # (Note) # compute gradient and do SGD step optimizer.zero_grad() @@ -474,15 +631,15 @@ def train(train_loader, model, criterion, optimizer, epoch, weight=None): batch_time.update(time.time() - end) end = time.time() - # if i % args.print_freq == 0: - # print('Epoch: [{0}][{1}/{2}]\t' - # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - # 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - # 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( - # epoch, i, len(train_loader), batch_time=batch_time, - # data_time=data_time, loss=losses, top1=top1)) - return data_time.sum, batch_time.sum + if i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1)) + return top1.avg, losses.avg, data_time.sum, batch_time.sum def validate(val_loader, model, criterion): @@ -559,22 +716,8 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count -# add a function for quant predictions -def quant_predictions(loader, model): - model.eval() - preds=numpy.zeros(TRAIN_NUM, CLASS_NUM) - labels=numpy.zeros(TRAIN_NUM, dtype=torch.int) - with torch.no_grad(): - for i, (input, target, idx) in enumerate(loader): - output = model(input) - preds[idx, :] = nn.Softmax(dim=1)(output) - return preds - - - - -def predictions(loader, model): +def predictions(args, loader, model): """ Get predictions """ @@ -583,8 +726,8 @@ def predictions(loader, model): # switch to evaluate mode model.eval() - preds = torch.zeros(TRAIN_NUM, CLASS_NUM).cuda() - labels = torch.zeros(TRAIN_NUM, dtype=torch.int) + preds = torch.zeros(args.train_num, args.class_num).cuda() + labels = torch.zeros(args.train_num, dtype=torch.int) end = time.time() with torch.no_grad(): for i, (input, target, idx) in enumerate(loader): @@ -600,18 +743,18 @@ def predictions(loader, model): batch_time.update(time.time() - end) end = time.time() - # if i % args.print_freq == 0: - # print('Predict: [{0}/{1}]\t' - # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})' - # .format(i, len(loader), batch_time=batch_time)) + if i % args.print_freq == 0: + print('Predict: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} 
({batch_time.avg:.3f})' + .format(i, len(loader), batch_time=batch_time)) return preds.cpu().data.numpy(), labels.cpu().data.numpy() -def quantization_predictions(loader, model): +def quantization_predictions(args, loader, model): model.to('cpu') model.eval() - preds = np.zeros((TRAIN_NUM, CLASS_NUM)) - labels = np.zeros(TRAIN_NUM) + preds = np.zeros((args.train_num, args.class_num)) + labels = np.zeros(args.train_num) labels=labels.astype('int32') for i, (input, target, idx) in enumerate(loader): preds[idx, :] = nn.Softmax(dim=1)(model(input)) @@ -637,5 +780,6 @@ def accuracy(output, target, topk=(1,)): if __name__ == '__main__': args = parser.parse_args() - main(subset_size=args.subset_size, greedy=args.greedy) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + main(args, subset_size=args.subset_size, greedy=args.greedy) diff --git a/util.py b/util.py index c97df20..83d7a05 100644 --- a/util.py +++ b/util.py @@ -1,30 +1,13 @@ -import itertools import os import subprocess import time import gc - -from nearpy import Engine -from nearpy.distances import EuclideanDistance -from nearpy.filters import NearestFilter -from nearpy.hashes import RandomBinaryProjections -import matplotlib.pyplot as plt import numpy as np from lazy_greedy import FacilityLocation, lazy_greedy_heap -import scipy.spatial -# from eucl_dist.cpu_dist import dist -# from eucl_dist.gpu_dist import dist as gdist - - -from multiprocessing.dummy import Pool as ThreadPool -from itertools import repeat -import sklearn - -# from lazy_greedy import FacilityLocation, lazy_greedy, lazy_greedy_heap -# from set_cover import SetCover - -from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets -from tensorflow.examples.tutorials.mnist import input_data +from sklearn.metrics import pairwise_distances +from submodlib.functions.facilityLocation import FacilityLocationFunction +import torchvision.transforms as transforms +import torchvision SEED = 100 EPS = 1E-8 @@ -63,15 +46,6 @@ def load_dataset(dataset, dataset_dir): # feautres path = os.path.join('grad_features.npy') X = np.load(path) # shape [50000, 1000], type float16 - elif dataset == 'mnist': - mnist = input_data.read_data_sets('/tmp') - X_train = np.vstack([mnist.train.images, mnist.validation.images]) - y_train = np.hstack([mnist.train.labels, mnist.validation.labels]) - X_test = mnist.test.images - y_test = mnist.test.labels - X_train = X_train.astype(np.float32) / 255 - X_test = X_test.astype(np.float32) / 255 - return X_train, y_train, X_test, y_test else: num, dim, name = 0, 0, '' @@ -117,7 +91,7 @@ def similarity(X, metric): ''' # print(f'Computing similarity for {metric}...', flush=True) start = time.time() - dists = sklearn.metrics.pairwise_distances(X, metric=metric, n_jobs=1) + dists = pairwise_distances(X, metric=metric, n_jobs=1) # dists = gdist(X, X, optimize_level=0, output='cpu') elapsed = time.time() - start @@ -132,103 +106,6 @@ def similarity(X, metric): return S, elapsed -def greedy_merge(X, y, B, part_num, metric, smtk=0, stoch_greedy=False): - N = len(X) - indices = list(range(N)) - # np.random.shuffle(indices) - part_size = int(np.ceil(N / part_num)) - part_indices = [indices[slice(i * part_size, min((i + 1) * part_size, N))] for i in range(part_num)] - print(f'GreeDi with {part_num} parts, finding {B} elements...', flush=True) - - # pool = ThreadPool(part_num) - # order_mg_all, cluster_sizes_all, _, _, ordering_time, similarity_time = zip(*pool.map( - # lambda p: get_orders_and_weights( - # int(B / 2), X[part_indices[p], :], metric, p + 
1, stoch_greedy, y[part_indices[p]]), np.arange(part_num))) - # pool.terminate() - - order_mg_all, cluster_sizes_all, _, _, ordering_time, similarity_time, F_val = zip(*map( - lambda p: get_orders_and_weights( - int(B / 2), X[part_indices[p], :], metric, p + 1, stoch_greedy, y[part_indices[p]]), np.arange(part_num))) - - # Returns the number of objects it has collected and deallocated - collected = gc.collect() - print(f'Garbage collector: collected {collected}') - - # order_mg_all = np.zeros((part_num, B)) - # cluster_sizes_all = np.zeros((part_num, B)) - # ordering_time = np.zeros(part_num) - # similarity_time = np.zeros(part_num) - # for p in range(part_num): - # order_mg_all[p,:], cluster_sizes_all[p,:], _, _, ordering_time[p], similarity_time[p] = get_orders_and_weights( - # B, X[part_indices[p], :], metric, p, stoch_greedy, y[part_indices[p]]) - order_mg_all = np.array(order_mg_all, dtype=np.int32) - cluster_sizes_all = np.array(cluster_sizes_all, dtype=np.float32) # / part_num (not needed) - order_mg = order_mg_all.flatten(order='F') - weights_mg = cluster_sizes_all.flatten(order='F') - print(f'GreeDi stage 1: found {len(order_mg)} elements in: {np.max(ordering_time)} sec', flush=True) - - # order_mg, weights_mg, order_sz, weights_sz, ordering_time, similarity_time - order, weights, order_sz, weights_sz, ordering_time_merge, similarity_time_merge = get_orders_and_weights( - B, X[order_mg, :], metric, smtk, 0, stoch_greedy, y[order_mg], weights_mg) - print(weights) - total_ordering_time = np.max(ordering_time) + ordering_time_merge - total_similarity_time = np.max(similarity_time) + similarity_time_merge - print(f'GreeDi stage 2: found {len(order)} elements in: {total_ordering_time + total_similarity_time} sec', - flush=True) - vals = order, weights, order_sz, weights_sz, total_ordering_time, total_similarity_time - return vals - - -def greedi(X, y, B, part_num, metric, smtk=0, stoch_greedy=False, seed=-1): - N = len(X) - indices = list(range(N)) - if seed != -1: - np.random.seed(seed) - np.random.shuffle(indices) # Note: random shuffling - part_size = int(np.ceil(N / part_num)) - part_indices = [indices[slice(i * part_size, min((i + 1) * part_size, N))] for i in range(part_num)] - print(f'GreeDi with {part_num} parts, finding {B} elements...', flush=True) - - # pool = ThreadPool(part_num) - # order_mg_all, cluster_sizes_all, _, _, ordering_time, similarity_time = zip(*pool.map( - # lambda p: get_orders_and_weights( - # B, X[part_indices[p], :], metric, p + 1, stoch_greedy, y[part_indices[p]]), np.arange(part_num))) - # pool.terminate() - # Returns the number of objects it has collected and deallocated - # collected = gc.collect() - # print(f'Garbage collector: collected {collected}') - order_mg_all, cluster_sizes_all, _, _, ordering_time, similarity_time = zip(*map( - lambda p: get_orders_and_weights( - B, X[part_indices[p], :], metric, p + 1, stoch_greedy, y[part_indices[p]]), np.arange(part_num))) - gc.collect() - - order_mg_all = np.array(order_mg_all, dtype=np.int32) - for c in np.arange(part_num): - order_mg_all[c] = np.array(part_indices[c])[order_mg_all[c]] - # order_mg_all = np.zeros((part_num, B)) - # cluster_sizes_all = np.zeros((part_num, B)) - # ordering_time = np.zeros(part_num) - # similarity_time = np.zeros(part_num) - # for p in range(part_num): - # order_mg_all[p,:], cluster_sizes_all[p,:], _, _, ordering_time[p], similarity_time[p] = get_orders_and_weights( - # B, X[part_indices[p], :], metric, p, stoch_greedy, y[part_indices[p]]) - cluster_sizes_all = 
np.array(cluster_sizes_all, dtype=np.float32) # / part_num (not needed) - order_mg = order_mg_all.flatten(order='F') - weights_mg = cluster_sizes_all.flatten(order='F') - print(f'GreeDi stage 1: found {len(order_mg)} elements in: {np.max(ordering_time)} sec', flush=True) - - # order_mg, weights_mg, order_sz, weights_sz, ordering_time, similarity_time - order, weights, order_sz, weights_sz, ordering_time_merge, similarity_time_merge = get_orders_and_weights( - B, X[order_mg,:], metric, smtk, 0, stoch_greedy, y[order_mg], weights_mg) - print(weights) - order = order_mg[order] - total_ordering_time = np.max(ordering_time) + ordering_time_merge - total_similarity_time = np.max(similarity_time) + similarity_time_merge - print(f'GreeDi stage 2: found {len(order)} elements in: {total_ordering_time + total_similarity_time} sec', flush=True) - vals = order, weights, order_sz, weights_sz, total_ordering_time, total_similarity_time - return vals - - def get_facility_location_submodular_order(S, B, c, smtk=0, no=0, stoch_greedy=0, weights=None): ''' Args @@ -269,163 +146,66 @@ def get_facility_location_submodular_order(S, B, c, smtk=0, no=0, stoch_greedy=0 greedy_time = time.time() - start F_val = 0 - order = np.asarray(order, dtype=np.int64) - sz = np.zeros(B, dtype=np.float64) - for i in range(N): - if weights is None: - sz[np.argmax(S[i, order])] += 1 - else: - sz[np.argmax(S[i, order])] += weights[i] + order = np.asarray(order, dtype=np.int64) + sz = np.zeros(B, dtype=np.float64) + for i in range(N): + if weights is None: + sz[np.argmax(S[i, order])] += 1 + else: + sz[np.argmax(S[i, order])] += weights[i] # print('time (sec) for computing facility location:', greedy_time, flush=True) collected = gc.collect() return order, sz, greedy_time, F_val -def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None): +# def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None): +# class_indices = np.where(y == c)[0] +# # print(class_indices) +# # print(X) +# print(f'Selecting from {len(class_indices)} examples in class {c}') +# S, S_time = similarity(X[class_indices], metric=metric) +# order, cluster_sz, greedy_time, F_val = get_facility_location_submodular_order( +# S, num_per_class, c, smtk, no, stoch_greedy, weights) +# return class_indices[order], cluster_sz, greedy_time, S_time + +def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None, mode='dense', num_n=128): class_indices = np.where(y == c)[0] - print(c) - print(class_indices) - print(len(class_indices)) - S, S_time = similarity(X[class_indices], metric=metric) - order, cluster_sz, greedy_time, F_val = get_facility_location_submodular_order( - S, num_per_class, c, smtk, no, stoch_greedy, weights) - return class_indices[order], cluster_sz, greedy_time, S_time + X = X[class_indices] + N = X.shape[0] + if mode == 'dense': + num_n = None -def save_all_orders_and_weights(folder, X, metric='l2', stoch_greedy=False, y=None, equal_num=False, outdir='.'): - N = X.shape[0] - if y is None: - y = np.zeros(N, dtype=np.int32) # assign every point to the same class - classes = np.unique(y) - C = len(classes) # number of classes - # assert np.array_equal(classes, np.arange(C)) - # assert B % C == 0 - class_nums = [sum(y == c) for c in classes] - print(class_nums) - class_indices = [np.where(y == c)[0] for c in classes] - - tmp_path = '/tmp' - no, smtk = 2, 2 - - def greedy(B, c): - print('Computing facility location submodular order...') - 
-def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None):
+# def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None):
+#     class_indices = np.where(y == c)[0]
+#     # print(class_indices)
+#     # print(X)
+#     print(f'Selecting from {len(class_indices)} examples in class {c}')
+#     S, S_time = similarity(X[class_indices], metric=metric)
+#     order, cluster_sz, greedy_time, F_val = get_facility_location_submodular_order(
+#         S, num_per_class, c, smtk, no, stoch_greedy, weights)
+#     return class_indices[order], cluster_sz, greedy_time, S_time
+
+def faciliy_location_order(c, X, y, metric, num_per_class, smtk, no, stoch_greedy, weights=None, mode='dense', num_n=128):
     class_indices = np.where(y == c)[0]
-    print(c)
-    print(class_indices)
-    print(len(class_indices))
-    S, S_time = similarity(X[class_indices], metric=metric)
-    order, cluster_sz, greedy_time, F_val = get_facility_location_submodular_order(
-        S, num_per_class, c, smtk, no, stoch_greedy, weights)
-    return class_indices[order], cluster_sz, greedy_time, S_time
+    X = X[class_indices]
+    N = X.shape[0]
+
+    if mode == 'dense':
+        num_n = None
+
-def save_all_orders_and_weights(folder, X, metric='l2', stoch_greedy=False, y=None, equal_num=False, outdir='.'):
-    N = X.shape[0]
-    if y is None:
-        y = np.zeros(N, dtype=np.int32)  # assign every point to the same class
-    classes = np.unique(y)
-    C = len(classes)  # number of classes
-    # assert np.array_equal(classes, np.arange(C))
-    # assert B % C == 0
-    class_nums = [sum(y == c) for c in classes]
-    print(class_nums)
-    class_indices = [np.where(y == c)[0] for c in classes]
-
-    tmp_path = '/tmp'
-    no, smtk = 2, 2
-
-    def greedy(B, c):
-        print('Computing facility location submodular order...')
-        print(f'Calculating ordering with SMTK... part size: {class_nums[c]}, B: {B}', flush=True)
-        command = f'/tmp/{no}/smtk-master{smtk}/build/smraiz -sumsize {B} \
-            -flnpy {tmp_path}/{no}/{smtk}-{c}.npy -pnpv -porder -ptime'
-        if stoch_greedy:
-            command += f' -stochastic-greedy -sg-epsilon {.9}'
-
-        p = subprocess.check_output(command.split())
-        s = p.decode("utf-8")
-        str, end = ['([', ',])']
-        order = s[s.find(str) + len(str):s.rfind(end)].split(',')
-        order = np.asarray(order, dtype=np.int64)
-        greedy_time = float(s[s.find('CPU') + 4: s.find('s (User')])
-        print(f'FL greedy time: {greedy_time}', flush=True)
-        str = 'f(Solution) = '
-        F_val = float(s[s.find(str) + len(str): s.find('Summary Solution') - 1])
-        print(f'===========> f(Solution) = {F_val}')
-        print('time (sec) for computing facility location:', greedy_time, flush=True)
-        return order, greedy_time, F_val
-
-    def get_subset_sizes(B, equal_num):
-        if equal_num:
-            # class_nums = [sum(y == c) for c in classes]
-            num_per_class = int(np.ceil(B / C)) * np.ones(len(classes), dtype=np.int32)
-            minority = class_nums < np.ceil(B / C)
-            if sum(minority) > 0:
-                extra = sum([max(0, np.ceil(B / C) - class_nums[c]) for c in classes])
-                for c in classes[~minority]:
-                    num_per_class[c] += int(np.ceil(extra / sum(minority)))
-        else:
-            num_per_class = np.int32(np.ceil(np.divide([sum(y == i) for i in classes], N) * B))
+    start = time.time()
+    obj = FacilityLocationFunction(n=len(X), mode=mode, data=X, metric=metric, num_neighbors=num_n)
+    S_time = time.time() - start
-        return num_per_class
+    start = time.time()
+    greedyList = obj.maximize(
+        budget=num_per_class,
+        optimizer="LazyGreedy",
+        stopIfZeroGain=False,
+        stopIfNegativeGain=False,
+        verbose=False,
+    )
+    order = list(map(lambda x: x[0], greedyList))
+    sz = list(map(lambda x: x[1], greedyList))
+    greedy_time = time.time() - start
+
+    S = obj.sijs
+    order = np.asarray(order, dtype=np.int64)
+    sz = np.zeros(num_per_class, dtype=np.float64)
-    def merge_orders(order_mg_all, weights_mg_all, equal_num):
-        order_mg, weights_mg = [], []
-        if equal_num:
-            props = np.rint([len(order_mg_all[i]) for i in range(len(order_mg_all))])
-        else:
-            # merging imbalanced classes
-            class_ratios = np.divide([np.sum(y == i) for i in classes], N)
-            props = np.rint(class_ratios / np.min(class_ratios))
-            print(f'Selecting with ratios {np.array(class_ratios)}')
-            print(f'Class proportions {np.array(props)}')
-
-        order_mg_all = np.array(order_mg_all)
-        weights_mg_all = np.array(weights_mg_all)
-        for i in range(int(np.rint(np.max([len(order_mg_all[c]) / props[c] for c in classes])))):
-            for c in classes:
-                ndx = slice(i * int(props[c]), int(min(len(order_mg_all[c]), (i + 1) * props[c])))
-                order_mg = np.append(order_mg, order_mg_all[c][ndx])
-                weights_mg = np.append(weights_mg, weights_mg_all[c][ndx])
-        order_mg = np.array(order_mg, dtype=np.int32)
-        weights_mg = np.array(weights_mg, dtype=np.float)
-        return order_mg, weights_mg
-
-    def calculate_weights(order, c):
-        weight = np.zeros(len(order), dtype=np.float64)
-        center = np.argmax(D[str(c)][:, order], axis=1)
-        for i in range(len(order)):
-            weight[i] = np.sum(center == i)
-        return weight
-
-    D, m = {}, 0
-    similarity_times, max_similarity = [], []
-    for c in classes:
-        print(f'Computing distances for class {c}...')
-        time.sleep(.1)
-        start = time.time()
-        if metric in ['', 'l2', 'l1']:
-            dists = sklearn.metrics.pairwise_distances(X[class_indices[c]], metric=metric, n_jobs=1)
-        else:
-            p = float(metric)
-            dim = class_nums[c]
-            dists = np.zeros((dim, dim))
-            for i in range(dim):
-                dists[i,:] = np.power(np.sum(np.power(np.abs(X[class_indices[c][i]] - X[class_indices[c]]), p), axis=1), 1./p)
-                # for j in range(i+1, dim):
-                #     dists[i,j] = np.power(np.sum(np.power(np.abs(X[class_indices[c][i]] - X[class_indices[c][j]]), p)), 1./p)
-            # dists[np.triu_indices(dim, 1)] = d
-            # dists = dists.T + dists
-        similarity_times.append(time.time() - start)
-        print(f'similarity times: {similarity_times}')
-        print('Computing max')
-        m = np.max(dists)
-        print(f'max: {m}')
-        S = m - dists
-        np.save(f'{tmp_path}/{no}/{smtk}-{c}', S)
-        D[str(c)] = S
-        max_similarity.append(m)
-
-    # Ordering all the data with greedy
-    print(f'Greedy: selecting {class_nums} elements')
-    # order_in_class, greedy_times, F_vals = zip(*map(lambda c: greedy(class_nums[c], c), classes))
-    # order_all = [class_indices[c][order_in_class[c]] for c in classes]
-
-    for subset_size in [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
-        # for subset_size in [0.9, 1.0]:
-        B = int(N * subset_size)
-        num_per_class = get_subset_sizes(B, equal_num)
-
-        # Note: for marginal gains
-        order_in_class, greedy_times, F_vals = zip(*map(lambda c: greedy(num_per_class[c], c), classes))
-        order_all = [class_indices[c][order_in_class[c]] for c in classes]
-        #####
-
-        weights = [calculate_weights(order_in_class[c][:num_per_class[c]], c) for c in classes]
-        order_subset = [order_all[c][:num_per_class[c]] for c in classes]
-        order_merge, weights_merge = merge_orders(order_subset, weights, equal_num)
-        F_vals = np.divide(F_vals, class_nums)
-
-        folder = '/tmp/covtype'
-        print(f'saving to {folder}_{subset_size}_{metric}_w.npz')
-        np.savez(f'{folder}_{subset_size}_{metric}_w', order=order_merge, weight=weights_merge,
-                 order_time=greedy_times, similarity_time=similarity_times, F_vals=F_vals, max_dist=m)
-    # end for on subset sizes
-    # return vals
+    for i in range(N):
+        if np.max(S[i, order]) <= 0:
+            continue
+        if weights is None:
+            sz[np.argmax(S[i, order])] += 1
+        else:
+            sz[np.argmax(S[i, order])] += weights[i]
+    sz[np.where(sz==0)] = 1
+
+    return class_indices[order], sz, greedy_time, S_time
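The rewritten faciliy_location_order above delegates both kernel construction and lazy-greedy maximization to submodlib. Assuming submodlib is installed (pip install submodlib), the same calls can be exercised standalone; the feature matrix below is random and purely illustrative:

import numpy as np
from submodlib import FacilityLocationFunction

X = np.random.rand(100, 16)  # 100 points with 16-d features (toy data)
obj = FacilityLocationFunction(n=100, mode='dense', data=X, metric='euclidean')
greedyList = obj.maximize(budget=10, optimizer='LazyGreedy',
                          stopIfZeroGain=False, stopIfNegativeGain=False,
                          verbose=False)
order = [idx for idx, gain in greedyList]   # selected indices, in greedy order
gains = [gain for idx, gain in greedyList]  # marginal gain of each pick
print(order, np.round(gains, 3))

In 'dense' mode submodlib materializes the full n x n similarity kernel (exposed as obj.sijs, which the function above reuses to recompute cluster sizes), so the num_n neighbor count is dropped there; 'sparse' mode instead keeps only num_n nearest neighbors per point.
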
 
 def get_orders_and_weights(B, X, metric, smtk, no=0, stoch_greedy=0, y=None, weights=None, equal_num=False, outdir='.'):
     '''
@@ -526,3 +306,122 @@ def get_orders_and_weights(B, X, metric, smtk, no=0, stoch_greedy=0, y=None, wei
     weights_sz = []  # cluster_sizes_all[rows_selector, cluster_order].flatten(order='F')
     vals = order_mg, weights_mg, order_sz, weights_sz, ordering_time, similarity_time
     return vals
+
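The context lines above show get_orders_and_weights returning the merged selection order, the per-element weights, and the ordering/similarity times. A hypothetical call site, with the 10% budget chosen only for illustration:

# Select a weighted 10% coreset from features X with labels y.
B = int(0.1 * len(X))
order, weights, _, _, t_order, t_sim = get_orders_and_weights(
    B, X, 'euclidean', smtk=0, y=y)
print(f'selected {len(order)} points in {t_order + t_sim:.1f} sec')
coreset, coreset_weights = X[order], weights
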
+def get_dataset(args, train=True, train_transform=True):
+    if args.dataset in ['cifar10', 'cifar100', 'mnist', 'svhn']:
+        if args.dataset == 'cifar10':
+            mean = (0.4914, 0.4822, 0.4465)
+            std = (0.2023, 0.1994, 0.2010)
+        elif args.dataset == 'cifar100':
+            mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
+            std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
+        elif args.dataset == 'svhn':
+            mean = (0.4376821, 0.4437697, 0.47280442)
+            std = (0.19803012, 0.20101562, 0.19703614)
+        elif args.dataset == 'mnist':
+            mean = (0.1307,)
+            std = (0.3081,)
+        else:
+            raise NotImplementedError
+
+        if train and train_transform:
+            transform = transforms.Compose([
+                transforms.RandomCrop(32, padding=4),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(mean, std),
+            ])
+        else:
+            transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize(mean, std),
+            ])
+
+        if args.dataset == 'svhn':
+            if train:
+                dataset = torchvision.datasets.SVHN(
+                    root=args.data_dir, split='train',
+                    transform=transform, download=True)
+            else:
+                dataset = torchvision.datasets.SVHN(
+                    root=args.data_dir, split='test',
+                    transform=transform, download=True)
+            dataset.targets = dataset.labels
+        else:
+            dataset = torchvision.datasets.__dict__[args.dataset.upper()](
+                root=args.data_dir, train=train,
+                transform=transform, download=True)
+    elif args.dataset == 'cinic10':
+        mean = [0.47889522, 0.47227842, 0.43047404]
+        std = [0.24205776, 0.23828046, 0.25874835]
+        if train:
+            path = args.data_dir + '/cinic-10/train'
+        else:
+            path = args.data_dir + '/cinic-10/test'
+        if train and train_transform:
+            transform = transforms.Compose([
+                transforms.RandomCrop(32, padding=4),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(mean, std),
+            ])
+        else:
+            transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize(mean, std),
+            ])
+        dataset = torchvision.datasets.ImageFolder(path,
+                                                   transform=transform)
+    else:
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
+        if train:
+            transform = transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ])
+        else:
+            transform = transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ])
+        if args.dataset == 'imagenet':
+            import torch
+            from datasets import load_dataset
+
+            class HFWrapper(torch.utils.data.Dataset):
+                def __init__(self, dataset, transform=None):
+                    self.dataset = dataset
+                    self.transform = transform
+                    self.targets = dataset['label']
+
+                def __getitem__(self, index):
+                    batch = self.dataset[index]
+                    data, target = batch['image'], batch['label']
+                    data = data.convert("RGB")
+
+                    if self.transform is not None:
+                        data = self.transform(data)
+                    return data, target
+
+                def __len__(self):
+                    return len(self.dataset)
+
+            if train:
+                dataset = load_dataset("imagenet-1k", use_auth_token=True, cache_dir=args.data_dir, split="train")
+            else:
+                dataset = load_dataset("imagenet-1k", use_auth_token=True, cache_dir=args.data_dir, split="test")
+            dataset = HFWrapper(dataset, transform)
+
+        elif args.dataset == 'tinyimagenet':
+            if train:
+                data_dir = os.path.join(args.data_dir, 'tiny-imagenet-200/train')
+            else:
+                data_dir = os.path.join(args.data_dir, 'tiny-imagenet-200/val')
+
+            dataset = torchvision.datasets.ImageFolder(data_dir, transform=transform)
+
+    return dataset
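get_dataset only reads args.dataset and args.data_dir, so it can be driven without the full training CLI; a minimal usage sketch (the Namespace fields follow the code above, everything else is illustrative):

import argparse
import torch

args = argparse.Namespace(dataset='cifar10', data_dir='./data')
train_set = get_dataset(args, train=True)                          # augmented
test_set = get_dataset(args, train=False, train_transform=False)   # eval-only
loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                     shuffle=True, num_workers=4)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])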