diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9413667 --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Results +results +*.log +*.npz +image_classification/results.tar.gz +image_classification/results +image_classification/results.tsv + +# IDE & OS +.idea +.DS_Store + +# Documents +*.pdf +*.png +*.jpg +*.pptx + +# Python +*.pyc +__pycache__ + +# VIM +*.swp + +# Build +actnn/build +actnn/dist +actnn/actnn.egg-info +actnn/actnn/cpp_extension/*.so diff --git a/README.md b/README.md new file mode 100644 index 0000000..e6eca5c --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# ActNN : Activation Compressed Training + +## Install +- Requirements +``` +torch>=1.7.1 +torchvision>=0.8.2 +``` + +- Build +```bash +cd actnn +pip install -v -e . +``` + +## Usage +[mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) is an example on using ActNN for models from torchvision. + +### Basic Usage +- Step1: Convert the model to use ActNN's layers. +```python +import actnn +model = actnn.QModule(model) +``` + +- Step2: Configure the optimization level +ActNN provides several optimization levels to control the trade-off between memory saving and computational overhead. +You can set the optimization level by +```python +# available choices are ["L0", "L1", "L2", "L3", "L4", "L5"] +actnn.set_optimization_level("L3") +``` +See [set_optimization_level](actnn/actnn/conf.py) for more details. + +### Advanced Features +- (Optional) Change the data loader +If you want to use per-sample gradient information for adaptive quantization, +you have to update the dataloader to return sample indices. +You can see `train_loader` in [mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) for example. +In addition, you have to update the configurations. +```python +from actnn import config, QScheme +config.use_gradient = True +QScheme.num_samples = 1300000 # the size of training set +``` +You can find sample code in the above script. + + +## Image Classification +See [image_classification](image_classification/) + +## Sementic Segmentation +Will be added later. + +## Benchmark Memory Usage and Training Speed +See [mem_speed_benchmark](mem_speed_benchmark/) + diff --git a/actnn/actnn/__init__.py b/actnn/actnn/__init__.py new file mode 100644 index 0000000..f0ba854 --- /dev/null +++ b/actnn/actnn/__init__.py @@ -0,0 +1,10 @@ +from . import dataloader +from . import ops +from .conf import config, set_optimization_level +from .dataloader import DataLoader +from .layers import QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d, \ + QBatchNorm1d, QBatchNorm2d, QBatchNorm3d, QLinear, QReLU, QSyncBatchNorm, QMaxPool2d +from .module import QModule +from .qscheme import QScheme +from .qbnscheme import QBNScheme +from .utils import get_memory_usage, compute_tensor_bytes, exp_recorder diff --git a/actnn/actnn/_utils/__init__.py b/actnn/actnn/_utils/__init__.py new file mode 100644 index 0000000..54590f9 --- /dev/null +++ b/actnn/actnn/_utils/__init__.py @@ -0,0 +1,45 @@ +r"""Utility classes & functions for data loading. Code in this folder is mostly +used by ../dataloder.py. + +A lot of multiprocessing is used in data loading, which only supports running +functions defined in global environment (py2 can't serialize static methods). +Therefore, for code tidiness we put these functions into different files in this +folder. +""" + +import sys +import atexit + +# old private location of the ExceptionWrapper that some users rely on: +from torch._utils import ExceptionWrapper + + +IS_WINDOWS = sys.platform == "win32" + + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + + +python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. + +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" + + +def _set_python_exit_flag(): + global python_exit_status + python_exit_status = True + +atexit.register(_set_python_exit_flag) + + +from . import worker, signal_handling, pin_memory, collate, fetch diff --git a/actnn/actnn/_utils/collate.py b/actnn/actnn/_utils/collate.py new file mode 100644 index 0000000..afe5ade --- /dev/null +++ b/actnn/actnn/_utils/collate.py @@ -0,0 +1,86 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to +collate samples fetched from dataset into Tensor(s). + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import torch +import re +from torch._six import container_abcs, string_classes, int_classes + +np_str_obj_array_pattern = re.compile(r'[SaUO]') + + +def default_convert(data): + r"""Converts each NumPy array data field into a tensor""" + elem_type = type(data) + if isinstance(data, torch.Tensor): + return data + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + # array of string classes and object + if elem_type.__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(data.dtype.str) is not None: + return data + return torch.as_tensor(data) + elif isinstance(data, container_abcs.Mapping): + return {key: default_convert(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple + return elem_type(*(default_convert(d) for d in data)) + elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes): + return [default_convert(d) for d in data] + else: + return data + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + + +def default_collate(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return default_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int_classes): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, container_abcs.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(default_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, container_abcs.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/actnn/actnn/_utils/fetch.py b/actnn/actnn/_utils/fetch.py new file mode 100644 index 0000000..de5c690 --- /dev/null +++ b/actnn/actnn/_utils/fetch.py @@ -0,0 +1,47 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch +data from an iterable-style or map-style dataset. This logic is shared in both +single- and multi-processing data loading. +""" + + +class _BaseDatasetFetcher(object): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + self.dataset = dataset + self.auto_collation = auto_collation + self.collate_fn = collate_fn + self.drop_last = drop_last + + def fetch(self, possibly_batched_index): + raise NotImplementedError() + + +class _IterableDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) + self.dataset_iter = iter(dataset) + + def fetch(self, possibly_batched_index): + if self.auto_collation: + data = [] + for _ in possibly_batched_index: + try: + data.append(next(self.dataset_iter)) + except StopIteration: + break + if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)): + raise StopIteration + else: + data = next(self.dataset_iter) + return self.collate_fn(data) + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) + + def fetch(self, possibly_batched_index): + if self.auto_collation: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) diff --git a/actnn/actnn/_utils/pin_memory.py b/actnn/actnn/_utils/pin_memory.py new file mode 100644 index 0000000..055c3cb --- /dev/null +++ b/actnn/actnn/_utils/pin_memory.py @@ -0,0 +1,59 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put +fetched tensors into pinned memory. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import torch +from torch._six import queue, container_abcs, string_classes +from . import MP_STATUS_CHECK_INTERVAL +from torch._utils import ExceptionWrapper + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event): + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. + torch.set_num_threads(1) + + torch.cuda.set_device(device_id) + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + while not done_event.is_set(): + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + idx, data = r + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): + try: + data = pin_memory(data) + except Exception: + data = ExceptionWrapper( + where="in pin memory thread for device {}".format(device_id)) + r = (idx, data) + while not done_event.is_set(): + try: + out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) + break + except queue.Full: + continue + del r # save memory + + +def pin_memory(data): + if isinstance(data, torch.Tensor): + return data.pin_memory() + elif isinstance(data, string_classes): + return data + elif isinstance(data, container_abcs.Mapping): + return {k: pin_memory(sample) for k, sample in data.items()} + elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple + return type(data)(*(pin_memory(sample) for sample in data)) + elif isinstance(data, container_abcs.Sequence): + return [pin_memory(sample) for sample in data] + elif hasattr(data, "pin_memory"): + return data.pin_memory() + else: + return data diff --git a/actnn/actnn/_utils/signal_handling.py b/actnn/actnn/_utils/signal_handling.py new file mode 100644 index 0000000..4e57c06 --- /dev/null +++ b/actnn/actnn/_utils/signal_handling.py @@ -0,0 +1,71 @@ +r""""Signal handling for multiprocessing data loading. + +NOTE [ Signal handling in multiprocessing data loading ] + +In cases like DataLoader, if a worker process dies due to bus error/segfault +or just hang, the main process will hang waiting for data. This is difficult +to avoid on PyTorch side as it can be caused by limited shm, or other +libraries users call in the workers. In this file and `DataLoader.cpp`, we make +our best effort to provide some error message to users when such unfortunate +events happen. + +When a _BaseDataLoaderIter starts worker processes, their pids are registered in a +defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ] +via `_set_worker_pids`. + +When an error happens in a worker process, the main process received a SIGCHLD, +and Python will eventually call the handler registered below +(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails` +call checks all registered worker pids and raise proper error message to +prevent main process from hanging waiting for data from worker. + +Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, +`_set_worker_signal_handlers` is called to register critical signal handlers +(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error +message to stderr before triggering the default handler. So a message will also +be printed from the worker process when it is killed by such signals. + +See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of +this signal handling design and other mechanism we implement to make our +multiprocessing data loading robust to errors. +""" + +import signal +import threading +from . import IS_WINDOWS + +# Some of the following imported functions are not used in this file, but are to +# be used `_utils.signal_handling.XXXXX`. +from torch._C import _set_worker_pids, _remove_worker_pids # noqa: F401 +from torch._C import _error_if_any_worker_fails, _set_worker_signal_handlers # noqa: F401 + +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if IS_WINDOWS: + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True diff --git a/actnn/actnn/_utils/worker.py b/actnn/actnn/_utils/worker.py new file mode 100644 index 0000000..fc5f081 --- /dev/null +++ b/actnn/actnn/_utils/worker.py @@ -0,0 +1,206 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import torch +import random +import os +from collections import namedtuple +from torch._six import queue +from torch._utils import ExceptionWrapper +from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import DWORD, BOOL, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead +else: + class ManagerWatchdog(object): # type: ignore[no-redef] + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + +_worker_info = None + + +class WorkerInfo(object): + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__)) + return super(WorkerInfo, self).__setattr__(key, val) + + def __repr__(self): + items = [] + for k in self.__keys: + items.append('{}={}'.format(k, getattr(self, k))) + return '{}({})'.format(self.__class__.__name__, ', '.join(items)) + + +def get_worker_info(): + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code (e.g., NumPy). + """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" +_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration', ['worker_id']) + + +def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, + auto_collation, collate_fn, drop_last, seed, init_fn, worker_id, + num_workers): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + + global _worker_info + _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, + seed=seed, dataset=dataset) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) + except Exception: + init_exception = ExceptionWrapper( + where="in DataLoader worker process {}".format(worker_id)) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. + iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) + except Exception as e: + if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable: + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper( + where="in DataLoader worker process {}".format(worker_id)) + data_queue.put((idx, (index, data))) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/actnn/actnn/conf.py b/actnn/actnn/conf.py new file mode 100644 index 0000000..7ec819d --- /dev/null +++ b/actnn/actnn/conf.py @@ -0,0 +1,65 @@ +import ast +import os + +def set_optimization_level(level): + if level == 'L0': # Do nothing + config.compress_activation = False + config.adaptive_conv_scheme = config.adaptive_bn_scheme = False + elif level == 'L1': # 4-bit conv + 32-bit bn + config.activation_compression_bits = [4] + config.adaptive_conv_scheme = config.adaptive_bn_scheme = False + config.enable_quantized_bn = False + elif level == 'L2': # 4-bit + config.activation_compression_bits = [4] + config.adaptive_conv_scheme = config.adaptive_bn_scheme = False + elif level == 'L3': # 2-bit + pass + elif level == 'L3.1': # 2-bit + light system optimization + pass + config.cudnn_benchmark_conv2d = False + config.empty_cache_threshold = 0.2 + config.pipeline_threshold = 3 * 1024**3 + elif level == 'L4': # 2-bit + swap + pass + config.swap = True + elif level == 'L5': # 2-bit + swap + defragmentation + config.swap = True + os.environ['PYTORCH_CACHE_THRESHOLD'] = '256000000' + elif level == 'swap': + config.swap = True + config.compress_activation = False + else: + raise ValueError("Invalid level: " + level) + +class QuantizationConfig: + def __init__(self): + self.compress_activation = True + self.activation_compression_bits = [2, 8, 8] + self.pergroup = True + self.perlayer = True + self.initial_bits = 8 + self.stochastic = True + self.training = True + self.group_size = 256 + self.use_gradient = False + self.adaptive_conv_scheme = True + self.adaptive_bn_scheme = True + self.simulate = False + self.enable_quantized_bn = True + + # Memory management flag + self.empty_cache_threshold = None + self.pipeline_threshold = None + self.cudnn_benchmark_conv2d = True + self.swap = False + + # Debug related flag + self.debug_memory_model = ast.literal_eval(os.environ.get('DEBUG_MEM', "False")) + self.debug_speed = ast.literal_eval(os.environ.get('DEBUG_SPEED', "False")) + self.debug_memory_op_forward = False + self.debug_memory_op_backward = False + self.debug_remove_bn = False + self.debug_remove_relu = False + +config = QuantizationConfig() + diff --git a/actnn/actnn/cpp_extension/backward_func.cc b/actnn/actnn/cpp_extension/backward_func.cc new file mode 100644 index 0000000..88dc657 --- /dev/null +++ b/actnn/actnn/cpp_extension/backward_func.cc @@ -0,0 +1,91 @@ +/* + * Use pytorch c++ extension to export c++ functions to python + */ + +#include +#include +#include + +namespace at { +namespace native { + +// Copied from +// https://github.com/pytorch/pytorch/blob/8deb4fe809ca956276e8d6edaa184de7118be58f/aten/src/ATen/native/layer_norm.h#L11 +std::tuple prepare_layer_norm_inputs( + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } + + const int axis = input_ndim - normalized_ndim; + const int64_t M = std::accumulate( + input_shape.cbegin(), + input_shape.cbegin() + axis, + 1LL, + std::multiplies()); + const int64_t N = std::accumulate( + input_shape.cbegin() + axis, + input_shape.cend(), + 1LL, + std::multiplies()); + + const auto& X = input.is_contiguous() ? input : input.contiguous(); + const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); + const auto& beta = bias.is_contiguous() ? bias : bias.contiguous(); + + return std::make_tuple(X, gamma, beta, M, N); +} + + +} // namespace native +} // namespace at + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cudnn_convolution_backward", &at::cudnn_convolution_backward); + m.def("cudnn_convolution_transpose_backward", &at::cudnn_convolution_transpose_backward); + m.def("prepare_layer_norm_inputs", &at::native::prepare_layer_norm_inputs); + m.def("layer_norm_cuda", &at::native::layer_norm_cuda); + m.def("layer_norm_backward_cuda", &at::native::layer_norm_backward_cuda); + m.def("cudnn_batch_norm", &at::native::cudnn_batch_norm); + m.def("cudnn_batch_norm_backward", &at::native::cudnn_batch_norm_backward); + m.def("native_batch_norm", &at::native_batch_norm); + m.def("native_batch_norm_backward", &at::native_batch_norm_backward); +} diff --git a/actnn/actnn/cpp_extension/calc_precision.cc b/actnn/actnn/cpp_extension/calc_precision.cc new file mode 100644 index 0000000..7b17727 --- /dev/null +++ b/actnn/actnn/cpp_extension/calc_precision.cc @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include + +// Greedy algorithm +torch::Tensor calc_precision(torch::Tensor b, torch::Tensor C, torch::Tensor w, double target) { + TORCH_CHECK(b.device().is_cpu(), "b must be a CPU tensor!"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous!"); + TORCH_CHECK(C.device().is_cpu(), "C must be a CPU tensor!"); + TORCH_CHECK(C.is_contiguous(), "C must be contiguous!"); + TORCH_CHECK(w.device().is_cpu(), "w must be a CPU tensor!"); + TORCH_CHECK(w.is_contiguous(), "w must be contiguous!"); + + // min \sum_i C_i / (2^b_i - 1)^2, s.t., \sum_i b_i = N b + std::priority_queue> q; + + auto *b_data = b.data_ptr(); + auto *C_data = C.data_ptr(); + auto *w_data = w.data_ptr(); + + auto get_obj = [&](float C, int b) { + int coeff_1 = ((1 << b) - 1) * ((1 << b) - 1); + int coeff_2 = ((1 << (b-1)) - 1) * ((1 << (b-1)) - 1); + return C * (1.0 / coeff_1 - 1.0 / coeff_2); // negative + }; + + int N = b.size(0); + double b_sum = 0; + for (int i = 0; i < N; i++) { + auto delta = get_obj(C_data[i], b_data[i]) / w_data[i]; + q.push(std::make_pair(delta, i)); + b_sum += b_data[i] * w_data[i]; + } + + while (b_sum > target) { // Pick up the smallest increment (largest decrement) + assert(!q.empty()); + auto i = q.top().second; + q.pop(); + b_data[i] -= 1; + b_sum -= w_data[i]; + if (b_data[i] > 1) { + auto delta = get_obj(C_data[i], b_data[i]) / w_data[i]; + q.push(std::make_pair(delta, i)); + } + } + return b; +} + +struct State { + float obj; + int p, b; +}; + +// Dynamic programming +std::pair calc_precision_dp(torch::Tensor A, torch::Tensor C, int max_b, int target, int states) { + using namespace std; + + TORCH_CHECK(A.device().is_cpu(), "A must be a CPU tensor!"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous!"); + TORCH_CHECK(C.device().is_cpu(), "C must be a CPU tensor!"); + TORCH_CHECK(C.is_contiguous(), "C must be contiguous!"); + // min \sum_i (1-p_i)/p_i A_i + C_i / (p_i B_i^2), + // s.t. \sum_i p_i b_i = N b + // where B_i = 2^b_i - 1, and p_i takes ``states'' discrete states. + + // We solve with dynamic programming, where + // f[i, b] is the minimum objective function using the first i terms and b/s bits, + // where i \in [0, N] and b \in [0, N * b * states] + // the time complexity is O(N^2bs) states and O(bs) transitions + // O((Nbs)^2) in total, where N=128, b=2, and s=2, (Nbs)^2 = 262144 + + int N = A.size(0); + auto *A_data = A.data_ptr(); + auto *C_data = C.data_ptr(); + int total_states = target * N * states; + + // Initialize + std::vector> f(N+1); + for (auto &v: f) { + v.resize(total_states + 1); + for (auto &state: v) + state.obj = 1e20; + } + f[0][0].obj = 0; +// cout << "Initialized " << total_states << endl; + + for (int i = 1; i <= N; i++) { + // Moving from f[i-1] to f[i] + for (int b = 0; b < total_states; b++) { + auto &old_state = f[i-1][b]; + + for (int b0 = 1; b0 <= max_b; b0++) + for (int p = 1; p <= states; p++) + if (b + b0 * p <= total_states) { + auto &new_state = f[i][b + b0 * p]; + float p0 = (float)p / states; + float B = (1< 0; i--) { + auto &state = f[i][current_state]; + b_vec[i-1] = state.b; + p_vec[i-1] = (float)state.p / states; + current_state -= state.b * state.p; + } + TORCH_CHECK(current_state==0, "DP Failed: no path to initial state!"); + + return std::make_pair(b_vec, p_vec); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("calc_precision", &calc_precision, "calc_precision"); + m.def("calc_precision_dp", &calc_precision_dp, "calc_precision_dp"); +} diff --git a/actnn/actnn/cpp_extension/ext_common.h b/actnn/actnn/cpp_extension/ext_common.h new file mode 100644 index 0000000..79d3eca --- /dev/null +++ b/actnn/actnn/cpp_extension/ext_common.h @@ -0,0 +1,28 @@ +// Helper for type check +#define CHECK_CUDA_TENSOR_DIM_TYPE(name, n_dim, type) \ + TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ + TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ + TORCH_CHECK(name.dim() == n_dim, "The dimension of " #name " is not correct!"); \ + TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ + +// Helper for type check +#define CHECK_CUDA_TENSOR_TYPE(name, type) \ + TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ + TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ + TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ + +// Helper for type check +#define CHECK_CUDA_TENSOR_FLOAT(name) \ + TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ + TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ + TORCH_CHECK(name.dtype() == torch::kFloat32 || name.dtype() == torch::kFloat16, \ + "The type of " #name " is not correct!"); \ + +// Helper for type check +#define CHECK_CUDA_TENSOR_DIM_FLOAT(name, n_dim) \ + TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ + TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ + TORCH_CHECK(name.dim() == n_dim, "The dimension of " #name " is not correct!"); \ + TORCH_CHECK(name.dtype() == torch::kFloat32 || name.dtype() == torch::kFloat16, \ + "The type of " #name " is not correct!"); \ + diff --git a/actnn/actnn/cpp_extension/minimax.cc b/actnn/actnn/cpp_extension/minimax.cc new file mode 100644 index 0000000..33d3b17 --- /dev/null +++ b/actnn/actnn/cpp_extension/minimax.cc @@ -0,0 +1,16 @@ +#include + +#include "ext_common.h" + +std::pair minimax_cuda(torch::Tensor data); + +std::pair minimax(torch::Tensor data) { + CHECK_CUDA_TENSOR_FLOAT(data); + + return minimax_cuda(data); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("minimax", &minimax); +} diff --git a/actnn/actnn/cpp_extension/minimax_cuda_kernel.cu b/actnn/actnn/cpp_extension/minimax_cuda_kernel.cu new file mode 100644 index 0000000..8d0ba0e --- /dev/null +++ b/actnn/actnn/cpp_extension/minimax_cuda_kernel.cu @@ -0,0 +1,89 @@ +#include + +#include +#include + +using torch::Tensor; + + +__device__ __inline__ c10::Half __shfl_down_sync(const unsigned mask, const c10::Half var, + const unsigned int delta, const int width) { + __half var_ = var; + return __shfl_down_sync(mask, var_, delta, width); +} + + +__device__ __inline__ c10::Half __shfl_sync(const unsigned mask, const c10::Half var, + const unsigned int delta, const int width) { + __half var_ = var; + return __shfl_sync(mask, var_, delta, width); +} + + +template +__global__ void minimax_cuda_kernel(const scalar_t* __restrict__ data, + scalar_t* __restrict__ min, + scalar_t* __restrict__ max, + int N, + int D) { + scalar_t max_val, min_val; + max_val = -1e30; + min_val = 1e30; + + for (int k1_outer = 0; k1_outer < D / 32; ++k1_outer) { + max_val = std::max(max_val, data[blockIdx.x * D + k1_outer * 32 + threadIdx.x]); + min_val = std::min(min_val, data[blockIdx.x * D + k1_outer * 32 + threadIdx.x]); + } + + unsigned int mask; + scalar_t max_val_t, min_val_t; + mask = __activemask(); + + max_val_t = __shfl_down_sync(mask, max_val, 16, 32); + max_val = std::max(max_val, max_val_t); + max_val_t = __shfl_down_sync(mask, max_val, 8, 32); + max_val = std::max(max_val, max_val_t); + max_val_t = __shfl_down_sync(mask, max_val, 4, 32); + max_val = std::max(max_val, max_val_t); + max_val_t = __shfl_down_sync(mask, max_val, 2, 32); + max_val = std::max(max_val, max_val_t); + max_val_t = __shfl_down_sync(mask, max_val, 1, 32); + max_val = std::max(max_val, max_val_t); + max_val = __shfl_sync(mask, max_val, 0, 32); + max[blockIdx.x] = max_val; + + min_val_t = __shfl_down_sync(mask, min_val, 16, 32); + min_val = std::min(min_val, min_val_t); + min_val_t = __shfl_down_sync(mask, min_val, 8, 32); + min_val = std::min(min_val, min_val_t); + min_val_t = __shfl_down_sync(mask, min_val, 4, 32); + min_val = std::min(min_val, min_val_t); + min_val_t = __shfl_down_sync(mask, min_val, 2, 32); + min_val = std::min(min_val, min_val_t); + min_val_t = __shfl_down_sync(mask, min_val, 1, 32); + min_val = std::min(min_val, min_val_t); + min_val = __shfl_sync(mask, min_val, 0, 32); + min[blockIdx.x] = min_val; +} + + +std::pair minimax_cuda(torch::Tensor data) { + int N = data.size(0); + int D = data.size(1); + + auto options = torch::TensorOptions().dtype(data.dtype()).device(data.device()); + Tensor min = torch::empty({N,}, options); + Tensor max = torch::empty({N,}, options); + + int blocks = N; + int threads = 32; + TORCH_CHECK(D % 32 == 0 && D > 32); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "minimax_cuda", ([&] { + minimax_cuda_kernel<<>>( + data.data_ptr(), min.data_ptr(), max.data_ptr(), + N, D); + })); + + return std::make_pair(min, max); +} diff --git a/actnn/actnn/cpp_extension/quantization.cc b/actnn/actnn/cpp_extension/quantization.cc new file mode 100644 index 0000000..93ccbb1 --- /dev/null +++ b/actnn/actnn/cpp_extension/quantization.cc @@ -0,0 +1,177 @@ +/* + * Cuda operators for quantization and mixed-precision packing + */ + +#include +#include + +#include "ext_common.h" + +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using torch::autograd::tensor_list; +using torch::Tensor; +using torch::IntArrayRef; + +// Declarations for functions in ext_quantization_cuda_kernel.cu +// Pack and unpack +std::pair pack_mixed_precision_cuda( + Tensor data, Tensor min, Tensor max, Tensor bits, bool stochastic); +Tensor unpack_mixed_precision_cuda( + Tensor data, Tensor bits, Tensor scale, Tensor min, int N, int num_groups, int group_size); +std::pair pack_single_precision_cuda( + Tensor data, Tensor min, Tensor max, int bits, bool stochastic); +Tensor unpack_single_precision_cuda( + Tensor data, int bits, Tensor scale, Tensor min, int N, int num_groups, int group_size); + +// ActQuantizedReLU +std::pair act_quantized_relu_forward_cuda(Tensor data); +Tensor act_quantized_relu_backward_cuda(Tensor grad_output, Tensor mask); + +// ActQuantizedMaxPool2d +std::pair act_quantized_max_pool2d_forward_cuda(Tensor input, + IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, + bool ceil_mode, bool return_indices); +Tensor act_quantized_max_pool2d_backward_cuda(Tensor grad_output, Tensor max_indices, + IntArrayRef input_shape, + IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, + bool ceil_mode, bool return_indices); + + +// Pack/Unpack mixed precision +std::pair pack_mixed_precision(Tensor data, + Tensor min, + Tensor max, + Tensor bits, + bool stochastic) { + CHECK_CUDA_TENSOR_DIM_FLOAT(data, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(min, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(max, 3); + CHECK_CUDA_TENSOR_DIM_TYPE(bits, 1, torch::kInt32); + + return pack_mixed_precision_cuda(data, min, max, bits, stochastic); +} + +Tensor unpack_mixed_precision(Tensor data, + Tensor bits, + Tensor scale, + Tensor min, + int N, + int num_groups, + int group_size) { + CHECK_CUDA_TENSOR_DIM_TYPE(data, 1, torch::kInt32); + CHECK_CUDA_TENSOR_DIM_TYPE(bits, 1, torch::kInt32); + CHECK_CUDA_TENSOR_DIM_FLOAT(scale, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(min, 3); + + return unpack_mixed_precision_cuda(data, bits, scale, min, + N, num_groups, group_size); +} + + +// Pack/Unpack single precision +std::pair pack_single_precision(Tensor data, + Tensor min, + Tensor max, + int bits, + bool stochastic) { + CHECK_CUDA_TENSOR_DIM_FLOAT(data, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(min, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(max, 3); + + return pack_single_precision_cuda(data, min, max, bits, stochastic); +} + +Tensor unpack_single_precision(Tensor data, + int bits, + Tensor scale, + Tensor min, + int N, + int num_groups, + int group_size) { + CHECK_CUDA_TENSOR_DIM_TYPE(data, 1, torch::kInt8); + CHECK_CUDA_TENSOR_DIM_FLOAT(scale, 3); + CHECK_CUDA_TENSOR_DIM_FLOAT(min, 3); + + return unpack_single_precision_cuda(data, bits, scale, min, + N, num_groups, group_size); +} + + +// Activation quantized relu: use compressed bit stream to store activation +class ActQuantizedReLU : public Function { + public: + static Tensor forward(AutogradContext *ctx, Tensor input) { + Tensor output, mask; + std::tie(output, mask) = act_quantized_relu_forward_cuda(input); + ctx->save_for_backward({mask}); + return output; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + return {act_quantized_relu_backward_cuda(grad_outputs[0], saved[0])}; + } +}; + +Tensor act_quantized_relu(Tensor input) { + CHECK_CUDA_TENSOR_FLOAT(input); + return ActQuantizedReLU::apply(input); +} + + +// Activation quantized max_pool2d: use compressed bit stream to store activation +class ActQuantizedMaxPool2d : public Function { + public: + static Tensor forward(AutogradContext *ctx, Tensor input, IntArrayRef kernel_size, + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, bool return_indices) { + TORCH_CHECK(kernel_size.size() == 2); + TORCH_CHECK(stride.size() == 2); + TORCH_CHECK(padding.size() == 2); + TORCH_CHECK(dilation.size() == 2); + TORCH_CHECK(ceil_mode == false); + TORCH_CHECK(return_indices == false); + TORCH_CHECK(kernel_size[0] * kernel_size[1] < 16); + + Tensor output, max_indices; + std::tie(output, max_indices) = act_quantized_max_pool2d_forward_cuda(input, kernel_size, stride, padding, + dilation, ceil_mode, return_indices); + ctx->save_for_backward({max_indices}); + ctx->saved_data["input_shape"] = input.sizes(); + ctx->saved_data["kernel_size"] = kernel_size; + ctx->saved_data["stride"] = stride; + ctx->saved_data["padding"] = padding; + ctx->saved_data["dilation"] = dilation; + ctx->saved_data["ceil_mode"] = ceil_mode; + ctx->saved_data["return_indices"] = return_indices; + return output; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + return {act_quantized_max_pool2d_backward_cuda( + grad_outputs[0], saved[0], + IntArrayRef(ctx->saved_data["input_shape"].toIntVector()), + IntArrayRef(ctx->saved_data["kernel_size"].toIntVector()), + IntArrayRef(ctx->saved_data["stride"].toIntVector()), + IntArrayRef(ctx->saved_data["padding"].toIntVector()), + IntArrayRef(ctx->saved_data["dilation"].toIntVector()), + ctx->saved_data["ceil_mode"].toBool(),ctx->saved_data["return_indices"].toBool()), + Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor()}; + } +}; + +Tensor act_quantized_max_pool2d(Tensor input, IntArrayRef kernel_size, + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, bool return_indices) { + CHECK_CUDA_TENSOR_FLOAT(input); + return ActQuantizedMaxPool2d::apply(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("pack_mixed_precision", &pack_mixed_precision); + m.def("unpack_mixed_precision", &unpack_mixed_precision); + m.def("pack_single_precision", &pack_single_precision); + m.def("unpack_single_precision", &unpack_single_precision); + m.def("act_quantized_relu", &act_quantized_relu); + m.def("act_quantized_max_pool2d", &act_quantized_max_pool2d); +} diff --git a/actnn/actnn/cpp_extension/quantization_cuda_kernel.cu b/actnn/actnn/cpp_extension/quantization_cuda_kernel.cu new file mode 100644 index 0000000..44492dc --- /dev/null +++ b/actnn/actnn/cpp_extension/quantization_cuda_kernel.cu @@ -0,0 +1,613 @@ +/* + * Cuda kernels for quantization and mixed-precision packing + */ + +#include +#include +#include + +#include +#include +#include + + +using torch::IntArrayRef; +using torch::Tensor; + +/****************************************/ +/****** Pack/Unpack Mixed Precision *****/ +/****************************************/ +template +__global__ void compute_scale_mixed_precision_kernel(const int32_t* __restrict__ bits, + const scalar_t* __restrict__ min, + const scalar_t* __restrict__ max, + scalar_t* __restrict__ scale, + int N, + int num_groups) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < N * num_groups) { + scale[id] = ((scalar_t)((1 << bits[id / num_groups]) - 1)) / (max[id] - min[id] + 2e-6); + } +} + + +template +__global__ void pack_mixed_precision_kernel(const int32_t* __restrict__ bits, + const int32_t* __restrict__ prefix_sum, + const scalar_t* __restrict__ data, + const scalar_t* __restrict__ scale, + const scalar_t* __restrict__ min, + int32_t* __restrict__ packed, + std::pair seeds, + int N, + int num_groups, + int group_size) { + extern __shared__ int packed_shared[]; + + const int n = blockIdx.y; + const int group_id = blockIdx.x; + const int d = threadIdx.x; + const int id = (n * num_groups + group_id) * group_size + d; + const int shared_len = group_size * bits[n] / (sizeof(int32_t) * 8); + + if (threadIdx.x * 2 < shared_len) { + reinterpret_cast(packed_shared)[threadIdx.x] = make_int2(0, 0); + } + + curandStatePhilox4_32_10_t state; + curand_init(seeds.first, id, seeds.second, &state); + const float noise = curand_uniform(&state); + + const int val = __float2int_rn(fmax((data[id] - min[n * num_groups + group_id]) * scale[n * num_groups + group_id] + noise - 0.5, 0.0f)); + const int offset = d * bits[n]; + + __syncthreads(); + for (int i = 0; i < bits[n]; i++) { + atomicOr(packed_shared + (offset + i) % shared_len, (1 & (val >> i)) << ((offset + i) / shared_len)); + } + __syncthreads(); + + if (threadIdx.x * 2 < shared_len) { + const int64_t global_offset = \ + ((int64_t)(n == 0 ? 0 : prefix_sum[n-1]) * num_groups * group_size + bits[n] * group_id * group_size) / (sizeof(int32_t) * 8); + reinterpret_cast(packed)[global_offset/2 + threadIdx.x] = \ + reinterpret_cast(packed_shared)[threadIdx.x]; + } +} + +// Pack float16/32 data into int32 bit stream +std::pair pack_mixed_precision_cuda(Tensor data, + Tensor min, + Tensor max, + Tensor bits, + bool stochastic) { + int N = data.size(0); + int num_groups = data.size(1); + int group_size = data.size(2); + + int bits_per_int = sizeof(int32_t) * 8; + + // Compute total bits + Tensor prefix_sum = torch::cumsum(bits, 0, torch::kInt32); + int64_t total_bits = ((int64_t) prefix_sum[-1].item()) * num_groups * group_size; + auto options = torch::TensorOptions().dtype(torch::kInt32).device(data.device()); + Tensor packed = torch::empty({(total_bits + bits_per_int - 1) / bits_per_int,}, options); + + // Compute scale + options = torch::TensorOptions().dtype(data.dtype()).device(data.device()); + Tensor scale = torch::empty({N, num_groups, 1}, options); + int threads = 256; + int blocks = (N * num_groups + threads - 1) / threads; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(scale.scalar_type(), "compute_scale_mixed_precision", ([&] { + compute_scale_mixed_precision_kernel<<>>( + bits.data_ptr(), min.data_ptr(), max.data_ptr(), + scale.data_ptr(), N, num_groups); + })); + + // Random number generator + auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); + std::pair rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_engine_inputs(threads); + } + TORCH_CHECK(stochastic); + + // Pack + int max_bit = torch::max(bits).item(); + dim3 block_dim(num_groups, N, 1); + dim3 thread_dim(group_size, 1, 1); + TORCH_CHECK(group_size % bits_per_int == 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_mixed_precision", ([&] { + pack_mixed_precision_kernel<<>>( + bits.data_ptr(), prefix_sum.data_ptr(), + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + packed.data_ptr(), + rng_engine_inputs, + N, num_groups, group_size); + })); + + return std::make_pair(packed, scale); +} + +// Unpack int32 bit stream to float16/32 data +template +__global__ void unpack_mixed_precision_kernel(const int32_t* __restrict__ bits, + const int32_t* __restrict__ prefix_sum, + const int32_t* __restrict__ data, + const scalar_t* __restrict__ scale, + const scalar_t* __restrict__ min, + scalar_t* __restrict__ unpacked, + int N, + int num_groups, + int group_size) { + const int n = blockIdx.y; + const int group_id = blockIdx.x; + const int d = threadIdx.x; + const int id = (n * num_groups + group_id) * group_size + d; + const int shared_len = group_size * bits[n] / 32; + + const int64_t global_offset = \ + ((int64_t)(n == 0 ? 0 : prefix_sum[n-1]) * num_groups * group_size + bits[n] * group_id * group_size) / 32; + const int block_offset = d * bits[n]; + + int val = 0; + for (int i = 0; i < bits[n]; i++) { + val |= (1 & (data[global_offset + (block_offset + i) % shared_len] >> ((block_offset + i) / shared_len))) << i; + } + + unpacked[id] = ((scalar_t)val) / scale[n * num_groups + group_id] + min[n * num_groups + group_id]; +} + +// Unpack int32 bit stream to float16/32 data +Tensor unpack_mixed_precision_cuda(Tensor data, + Tensor bits, + Tensor scale, + Tensor min, + int N, + int num_groups, + int group_size) { + Tensor prefix_sum = torch::cumsum(bits, 0, torch::kInt32); + + auto options = torch::TensorOptions().dtype(scale.dtype()).device(data.device()); + Tensor unpacked = torch::empty({N, num_groups, group_size}, options); + + dim3 block_dim(num_groups, N, 1); + dim3 thread_dim(group_size, 1, 1); + TORCH_CHECK(group_size % 32 == 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(scale.scalar_type(), "unpack_mixed_precision", ([&] { + unpack_mixed_precision_kernel<<>>( + bits.data_ptr(), prefix_sum.data_ptr(), + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + unpacked.data_ptr(), + N, num_groups, group_size); + })); + + return unpacked; +} + +/****************************************/ +/***** Pack/Unpack Single Precision *****/ +/****************************************/ +template +__global__ void compute_scale_single_precision_kernel(int32_t bits, + const scalar_t* __restrict__ min, + const scalar_t* __restrict__ max, + scalar_t* __restrict__ scale, + int N, + int num_groups) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < N * num_groups) { + scale[id] = ((scalar_t)((1 << bits) - 1)) / (max[id] - min[id] + 2e-6); + } +} + +// Pack float16/32 data into int8 bit stream +template +__global__ void pack_single_precision_kernel(int32_t bits, + const scalar_t* __restrict__ data, + const scalar_t* __restrict__ scale, + const scalar_t* __restrict__ min, + int8_t* __restrict__ packed, + std::pair seeds, + int N, + int num_groups, + int group_size) { + const int no = blockIdx.y; + const int group_id = blockIdx.x; + const int d = threadIdx.x; + const int work_per_thread = 8 / bits; + const int64_t global_thread_id = (int64_t)(no * num_groups + group_id) * group_size + d; + + curandStatePhilox4_32_10_t state; + curand_init(seeds.first, global_thread_id, seeds.second, &state); + + uint8_t local_packed = 0; + for (int ni = 0; ni < work_per_thread; ni++) { + const int n = no * work_per_thread + ni; + + if (boundary_check && n >= N) { break; } + + const int64_t id = (int64_t)(n * num_groups + group_id) * group_size + d; + const float noise = curand_uniform(&state); + const int32_t val = __float2int_rn(fmax((data[id] - min[n * num_groups + group_id]) * scale[n * num_groups + group_id] + noise - 0.5, 0.0f)); + local_packed |= (val << (ni * bits)); + } + + packed[global_thread_id] = local_packed; +} + +// Pack float16/32 data into int8 bit stream +std::pair pack_single_precision_cuda(Tensor data, + Tensor min, + Tensor max, + int bits, + bool stochastic) { + int N = data.size(0); + int num_groups = data.size(1); + int group_size = data.size(2); + + // Compute total bits + int work_per_thread = 8 / bits; + TORCH_CHECK(8 % bits == 0); + + int N_round = N + (work_per_thread - N % work_per_thread) % work_per_thread; + int64_t total_bits = ((int64_t)bits) * (N_round * num_groups * group_size); + auto options = torch::TensorOptions().dtype(torch::kInt8).device(data.device()); + Tensor packed = torch::empty({(total_bits + 8) / 8,}, options); + + // Compute scale + options = torch::TensorOptions().dtype(data.dtype()).device(data.device()); + Tensor scale = torch::empty({N, num_groups, 1}, options); + int threads = 256; + int blocks = (N * num_groups + threads - 1) / threads; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(scale.scalar_type(), "compute_scale_single_precision", ([&] { + compute_scale_single_precision_kernel<<>>( + bits, min.data_ptr(), max.data_ptr(), + scale.data_ptr(), N, num_groups); + })); + + // Random number generator + auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); + std::pair rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_engine_inputs(threads * work_per_thread); + } + TORCH_CHECK(stochastic); + + // Pack + dim3 block_dim(num_groups, (N + work_per_thread - 1) / work_per_thread, 1); + dim3 thread_dim(group_size, 1, 1); + + if (N % work_per_thread == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_single_precision", ([&] { + pack_single_precision_kernel<<>>( + bits, + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + packed.data_ptr(), + rng_engine_inputs, + N, num_groups, group_size); + })); + } else { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_single_precision", ([&] { + pack_single_precision_kernel<<>>( + bits, + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + packed.data_ptr(), + rng_engine_inputs, + N, num_groups, group_size); + })); + } + + return std::make_pair(packed, scale); +} + +// Unpack int32 bit stream to float16/32 data +template +__global__ void unpack_single_precision_kernel(int32_t bits, + const int8_t* __restrict__ data, + const scalar_t* __restrict__ scale, + const scalar_t* __restrict__ min, + scalar_t* __restrict__ unpacked, + int N, + int num_groups, + int group_size) { + const int no = blockIdx.y; + const int group_id = blockIdx.x; + const int d = threadIdx.x; + const int64_t global_thread_id = (int64_t)(no * num_groups + group_id) * group_size + d; + + int work_per_thread = 8 / bits; + + uint8_t local_packed = data[global_thread_id]; + int mask = ((1 << bits) - 1); + for (int ni = 0; ni < work_per_thread; ni++) { + const int n = no * work_per_thread + ni; + + if (boundary_check && n >= N) { break; } + + const int val = (local_packed >> (ni * bits)) & mask; + const int64_t id = (int64_t)(n * num_groups + group_id) * group_size + d; + unpacked[id] = ((scalar_t)val) / scale[n * num_groups + group_id] + min[n * num_groups + group_id]; + } +} + +// Unpack int32 bit stream to float16/32 data +Tensor unpack_single_precision_cuda(Tensor data, + int bits, + Tensor scale, + Tensor min, + int N, + int num_groups, + int group_size) { + auto options = torch::TensorOptions().dtype(scale.dtype()).device(data.device()); + Tensor unpacked = torch::empty({N, num_groups, group_size}, options); + + int work_per_thread = 8 / bits; + TORCH_CHECK(8 % bits == 0); + + // Unpack + dim3 block_dim(num_groups, (N + work_per_thread - 1) / work_per_thread, 1); + dim3 thread_dim(group_size, 1, 1); + + if (N % work_per_thread == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(scale.scalar_type(), "unpack_single_precision", ([&] { + unpack_single_precision_kernel<<>>( + bits, + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + unpacked.data_ptr(), + N, num_groups, group_size); + })); + } else { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(scale.scalar_type(), "unpack_single_precision", ([&] { + unpack_single_precision_kernel<<>>( + bits, + data.data_ptr(), + scale.data_ptr(), min.data_ptr(), + unpacked.data_ptr(), + N, num_groups, group_size); + })); + } + + return unpacked; +} + + +/****************************************/ +/********** Act Quantized ReLU **********/ +/****************************************/ +#define ACT_QUANTIZED_RELU_NUM_THREADS 512 +// Unpack int32 bit stream to float16/32 data +template +__global__ void act_quantized_relu_forward_kernel(const scalar_t* __restrict__ data, + int32_t* __restrict__ mask, + scalar_t* __restrict__ output, + int N, + int mask_len) { + const int id = blockIdx.x * blockDim.x + threadIdx.x; + const int global_offset = blockIdx.x * blockDim.x / (sizeof(int32_t) * 8); + const int shared_len = ACT_QUANTIZED_RELU_NUM_THREADS / (sizeof(int32_t) * 8); + __shared__ int mask_shared[ACT_QUANTIZED_RELU_NUM_THREADS / (sizeof(int32_t) * 8)]; + + if (threadIdx.x * 2 < shared_len) { + reinterpret_cast(mask_shared)[threadIdx.x] = make_int2(0, 0); + } + + if (id < N) { + bool bit = data[id] > 0; + if (bit) { + output[id] = data[id]; + } else { + output[id] = 0.0; + } + + __syncthreads(); + atomicOr(mask_shared + threadIdx.x % shared_len, bit << (threadIdx.x / shared_len)); + __syncthreads(); + } + + if (threadIdx.x * 2 < shared_len) { + reinterpret_cast(mask)[global_offset / 2 + threadIdx.x] = reinterpret_cast(mask_shared)[threadIdx.x]; + } +} + +std::pair act_quantized_relu_forward_cuda(Tensor data) { + int n_elements = 1; + for (size_t i = 0; i < data.dim(); ++i) { + n_elements *= data.size(i); + } + + auto options = torch::TensorOptions().dtype(torch::kInt32).device(data.device()); + int mask_len = (n_elements + sizeof(int32_t) * 8 - 1) / (sizeof(int32_t) * 8); + Tensor mask = torch::empty({mask_len}, options); + Tensor output = torch::empty_like(data); + + int threads = ACT_QUANTIZED_RELU_NUM_THREADS; + int blocks = (n_elements + threads - 1) / threads; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "act_quantized_relu_forward", ([&] { + act_quantized_relu_forward_kernel<<>>( + data.data_ptr(), mask.data_ptr(), output.data_ptr(), + n_elements, mask_len); + })); + + return std::make_pair(output, mask); +} + +template +__global__ void act_quantized_relu_backward_kernel(const scalar_t* __restrict__ grad_output, + int32_t* __restrict__ mask, + scalar_t* __restrict__ grad_input, + int N) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + const int global_offset = blockIdx.x * blockDim.x / (sizeof(int32_t) * 8); + const int shared_len = ACT_QUANTIZED_RELU_NUM_THREADS / (sizeof(int32_t) * 8); + + if (id < N) { + bool bit = (mask[global_offset + threadIdx.x % shared_len] >> (threadIdx.x / shared_len)) & 1; + if (bit) { + grad_input[id] = grad_output[id]; + } else { + grad_input[id] = 0.0; + } + } +} + + +Tensor act_quantized_relu_backward_cuda(Tensor grad_output, Tensor mask) { + int n_elements = 1; + for (size_t i = 0; i < grad_output.dim(); ++i) { + n_elements *= grad_output.size(i); + } + + int threads = ACT_QUANTIZED_RELU_NUM_THREADS; + int blocks = (n_elements + threads - 1) / threads; + + Tensor grad_input = torch::empty_like(grad_output); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_output.scalar_type(), "act_quantized_relu_backward", ([&] { + act_quantized_relu_backward_kernel<<>>( + grad_output.data_ptr(), mask.data_ptr(), grad_input.data_ptr(), + n_elements); + })); + + return grad_input; +} + + +/****************************************/ +/******** Act Quantized MaxPool2d *******/ +/****************************************/ +#define ACT_QUANTIZED_MAX_POOL2D_NUM_THREADS 256 +template +__global__ void act_quantized_max_pool2d_forward_kernel(const scalar_t* __restrict__ input, + scalar_t* __restrict__ output, + int8_t* __restrict__ max_indices, + int n_elements, + int N, int C, int H, int W, int H_out, int W_out, + int KH, int KW, int SH, int SW, int PH, int PW, + int DH, int DW) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < n_elements) { + int nc = id / (H_out * W_out); + int h = id / W_out % H_out; + int w = id % W_out; + + int h_base = h * SH - PH; + int h_start = std::max(h_base, 0); + int h_end = std::min(h_base + KH, H); + int w_base = w * SW - PW; + int w_start = std::max(w_base, 0); + int w_end = std::min(w_base + KW, W); + + scalar_t v = -1e10; + int8_t index; + for (int i = h_start; i < h_end; i++) { + for (int j = w_start; j < w_end; j++) { + if (input[nc * (H * W) + i * W + j] > v) { + v = input[nc * (H * W) + i * W + j]; + index = (i - h_base) * KW + j - w_base; + } + } + } + + output[id] = v; + max_indices[id] = index; + } +} + +std::pair act_quantized_max_pool2d_forward_cuda(Tensor input, + IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, + bool ceil_mode, bool return_indices) { + int N = input.size(0); + int C = input.size(1); + int H = input.size(2); + int W = input.size(3); + int H_out = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1; + int W_out = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1; + auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); + Tensor output = torch::empty({N, C, H_out, W_out}, options); + options = torch::TensorOptions().dtype(torch::kInt8).device(input.device()); + Tensor max_indices = torch::empty({N, C, H_out, W_out}, options); + + int threads = ACT_QUANTIZED_MAX_POOL2D_NUM_THREADS; + int n_elements = N * C * H_out * W_out; + int blocks = (n_elements + threads - 1) / threads; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "act_quantized_max_pool2d_forward", ([&] { + act_quantized_max_pool2d_forward_kernel<<>>( + input.data_ptr(), output.data_ptr(), max_indices.data_ptr(), n_elements, + N, C, H, W, H_out, W_out, kernel_size[0], kernel_size[1], stride[0], stride[1], + padding[0], padding[1], dilation[0], dilation[1]); + })); + + return std::make_pair(output, max_indices); +} + +template +__global__ void act_quantized_max_pool2d_backward_kernel(const scalar_t* __restrict__ grad_output, + int8_t* __restrict__ max_indices, + scalar_t* __restrict__ grad_input, + int n_elements, + int N, int C, int H, int W, int H_out, int W_out, + int KH, int KW, int SH, int SW, int PH, int PW, + int DH, int DW) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < n_elements) { + int nc = id / (H_out * W_out); + int h = id / W_out % H_out; + int w = id % W_out; + + int h_base = h * SH - PH; + int w_base = w * SW - PW; + int8_t index = max_indices[id]; + int h_offset = index / KW; + int w_offset = index % KW; + + atomicAdd(grad_input + (nc * H * W) + (h_base + h_offset) * W + (w_base + w_offset), grad_output[id]); + } +} + +Tensor act_quantized_max_pool2d_backward_cuda(Tensor grad_output, Tensor max_indices, + IntArrayRef input_shape, + IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, + bool ceil_mode, bool return_indices) { + auto options = torch::TensorOptions().dtype(grad_output.dtype()).device(grad_output.device()); + Tensor grad_input = torch::zeros(input_shape, options); + + int N = grad_output.size(0); + int C = grad_output.size(1); + int H_out = grad_output.size(2); + int W_out = grad_output.size(3); + int H = input_shape[2]; + int W = input_shape[3]; + + int threads = ACT_QUANTIZED_MAX_POOL2D_NUM_THREADS; + int n_elements = N * C * H_out * W_out; + int blocks = (n_elements + threads - 1) / threads; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_output.scalar_type(), "act_quantized_max_pool2d_backward", ([&] { + act_quantized_max_pool2d_backward_kernel<<>>( + grad_output.data_ptr(), max_indices.data_ptr(), grad_input.data_ptr(), + n_elements, + N, C, H, W, H_out, W_out, kernel_size[0], kernel_size[1], stride[0], stride[1], + padding[0], padding[1], dilation[0], dilation[1]); + })); + + return grad_input; +} diff --git a/actnn/actnn/dataloader.py b/actnn/actnn/dataloader.py new file mode 100644 index 0000000..a212db0 --- /dev/null +++ b/actnn/actnn/dataloader.py @@ -0,0 +1,963 @@ +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + +import threading +import itertools +import warnings + +import multiprocessing as python_multiprocessing +import torch +import torch.multiprocessing as multiprocessing +from torch._utils import ExceptionWrapper +from torch._six import queue, string_classes + +from torch.utils.data import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler +from . import _utils + + +get_worker_info = _utils.worker.get_worker_info + +# This function used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate = _utils.collate.default_collate + + +class _DatasetKind(object): + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) + else: + return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + Used as sampler for :class:`~torch.utils.data.IterableDataset`. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self): + super(_InfiniteConstantSampler, self).__init__(None) + + def __iter__(self): + while True: + yield None + + +class DataLoader(object): + r""" + Data loader. Combines a dataset and a sampler, and provides an iterable over + the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, :attr:`shuffle` must be ``False``. + batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of + indices at a time. Mutually exclusive with :attr:`batch_size`, + :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + ``len(dataset)`` (if implemented) is returned instead, regardless + of multi-process loading configurations, because PyTorch trust + user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. See `Dataset Types`_ for more + details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_. + """ + + __initialized = False + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, collate_fn=None, + pin_memory=False, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None): + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError('num_workers option should be non-negative; ' + 'use num_workers=0 to disable multiprocessing.') + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + self.dataset = dataset + self.num_workers = num_workers + self.pin_memory = pin_memory + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and `IterableDataset` ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if shuffle is not False: + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "shuffle option, but got shuffle={}".format(shuffle)) + elif sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "sampler option, but got sampler={}".format(sampler)) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "batch_sampler option, but got batch_sampler={}".format(batch_sampler)) + else: + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError('sampler option is mutually exclusive with ' + 'shuffle') + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler option is mutually exclusive ' + 'with batch_size, shuffle, sampler, and ' + 'drop_last') + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if shuffle or drop_last: + raise ValueError('batch_size=None option disables auto-batching ' + 'and is mutually exclusive with ' + 'shuffle, and drop_last') + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.__initialized = True + self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ] + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if not multiprocessing._supports_context: + raise ValueError('multiprocessing_context relies on Python >= 3.4, with ' + 'support for different start methods') + + if isinstance(multiprocessing_context, string_classes): + valid_start_methods = multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + ('multiprocessing_context option ' + 'should specify a valid start method in {}, but got ' + 'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context)) + multiprocessing_context = multiprocessing.get_context(multiprocessing_context) + + if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext): + raise ValueError(('multiprocessing_context option should be a valid context ' + 'object or a string specifying the start method, but got ' + 'multiprocessing_context={}').format(multiprocessing_context)) + else: + raise ValueError(('multiprocessing_context can only be used with ' + 'multi-process loading (num_workers > 0), but got ' + 'num_workers={}').format(self.num_workers)) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'): + raise ValueError('{} attribute should not be set after {} is ' + 'initialized'.format(attr, self.__class__.__name__)) + + super(DataLoader, self).__setattr__(attr, val) + + def __iter__(self): + # return _SingleProcessDataLoaderIter(self) + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + return _MultiProcessingDataLoaderIter(self) + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self): + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + length = self._IterableDataset_len_called = len(self.dataset) + return length + else: + return len(self._index_sampler) + + +class _BaseDataLoaderIter(object): + def __init__(self, loader): + self.loader = loader + self._dataset = loader.dataset + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + self._pin_memory = loader.pin_memory and torch.cuda.is_available() + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = torch.empty((), dtype=torch.int64).random_().item() + self._num_yielded = 0 + + def __iter__(self): + return self + + def _next_index(self): + return next(self._sampler_iter) + + def _next_data(self): + raise NotImplementedError + + def __next__(self): + data = self._next_data() + self._num_yielded += 1 + if self._dataset_kind == _DatasetKind.Iterable and \ + self._IterableDataset_len_called is not None and \ + self._num_yielded > self._IterableDataset_len_called: + warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, + self._num_yielded) + if self._num_workers > 0: + warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") + warnings.warn(warn_msg) + return data + + next = __next__ # Python 2 compatibility + + def __len__(self): + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super(_SingleProcessDataLoaderIter, self).__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data) + return index, data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may alreay be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed. (Hooks freeing those + # resources are registered at importing the Python core libraries at + # the top of this file.) So in `__del__`, we check if + # `_utils.python_exit_status` is set or `None` (freed), and perform + # no-op if so. + # + # Another problem with `__del__` is also related to the library cleanup + # calls. When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # Another choice is to just shutdown workers with logic in 1 above + # whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # exits. + # + # However, in case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. Therefore, + # for both `worker_result_queue` (worker -> main process/pin_memory_thread) + # and each `index_queue` (main process -> worker), we use + # `q.cancel_join_thread()` in sender process before any `q.put` to + # prevent this automatic join. + # + # Moreover, having all queues called `cancel_join_thread` makes + # implementing graceful shutdown logic in `__del__` much easier. + # It won't need to get from any queue, which would also need to be + # guarded by periodic status checks. + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does + # a blocking `put` if the queue is full. So there is no above + # problem, but we do need to wrap the `put` in a loop that breaks + # not only upon success, but also when the main process stops + # reading, i.e., is shutting down. + # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `index_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. + # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timeing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super(_MultiProcessingDataLoaderIter, self).__init__(loader) + + assert self._num_workers > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + self._worker_result_queue = multiprocessing_context.Queue() + self._worker_pids_set = False + self._shutdown = False + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + self._workers_status = [] + for i in range(self._num_workers): + index_queue = multiprocessing_context.Queue() + # index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=(self._dataset_kind, self._dataset, index_queue, + self._worker_result_queue, self._workers_done_event, + self._auto_collation, self._collate_fn, self._drop_last, + self._base_seed + i, self._worker_init_fn, i, self._num_workers)) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + self._workers_status.append(True) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + self._data_queue = queue.Queue() + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=(self._worker_result_queue, self._data_queue, + torch.cuda.current_device(), + self._pin_memory_thread_done_event)) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue + + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + + # prime the prefetch loop + for _ in range(2 * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._shutdown_worker(worker_id) + if len(failed_workers) > 0: + pids_str = ', '.join(str(w.pid) for w in failed_workers) + raise RuntimeError('DataLoader qworker (pid(s) {}) exited unexpectedly'.format(pids_str)) + if isinstance(e, queue.Empty): + return (False, None) + raise + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. + if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout)) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError('Pin memory thread exited unexpectedly') + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return data[0], self._process_data(data[1]) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + self._shutdown_worker(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + # store out-of-order samples + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return data[0], self._process_data(data[1]) + + def _try_put_index(self): + assert self._tasks_outstanding < 2 * self._num_workers + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _shutdown_worker(self, worker_id): + # Mark a worker as having finished its work and dead, e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. + + assert self._workers_status[worker_id] + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. + q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joinning is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + self._workers_status[worker_id] = False + + def _shutdown_workers(self): + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + python_exit_status = _utils.python_exit_status + if python_exit_status is True or python_exit_status is None: + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, '_pin_memory_thread'): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + if self._workers_status[worker_id]: + self._shutdown_worker(worker_id) + for w in self._workers: + w.join() + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. + # + # FIXME: Unfortunately, for Windows, we are missing a worker + # error detection mechanism here in this function, as it + # doesn't provide a SIGCHLD handler. + if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + + def __del__(self): + self._shutdown_workers() diff --git a/actnn/actnn/layers.py b/actnn/actnn/layers.py new file mode 100644 index 0000000..e6e34c0 --- /dev/null +++ b/actnn/actnn/layers.py @@ -0,0 +1,502 @@ +# The code is compatible with PyTorch 1.6/1.7 +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed +from torch import Tensor +from torch.nn.modules.pooling import _size_2_t, _single, _pair, _triple, _MaxPoolNd, _AvgPoolNd + +from actnn.qscheme import QScheme +from actnn.qbnscheme import QBNScheme +from actnn.conf import config +from actnn.ops import linear, batch_norm, conv1d, conv2d, conv3d, sync_batch_norm +from actnn.ops import conv_transpose1d, conv_transpose2d, conv_transpose3d +import actnn.cpp_extension.quantization as ext_quantization + + +class QConv1d(nn.Conv1d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', group=0): + super(QConv1d, self).__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size + else: + num_locations = kernel_size[0] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input): + if config.training: + if self.padding_mode != 'zeros': + return conv1d.apply(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.weight, self.bias, self.stride, + _single(0), self.dilation, self.groups, self.scheme) + return conv1d.apply(input, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.scheme) + else: + return super(QConv1d, self).forward(input) + + +class QConv2d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', group=0): + super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size ** 2 + else: + num_locations = kernel_size[0] * kernel_size[1] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input): + if config.training: + if self.padding_mode != 'zeros': + return conv2d.apply(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.weight, self.bias, self.stride, + _pair(0), self.dilation, self.groups, self.scheme) + return conv2d.apply(input, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.scheme) + else: + return super(QConv2d, self).forward(input) + + +class QConv3d(nn.Conv3d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', group=0): + super(QConv3d, self).__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size ** 3 + else: + num_locations = kernel_size[0] * kernel_size[1] * kernel_size[2] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input): + if config.training: + if self.padding_mode != 'zeros': + return conv3d.apply(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.weight, self.bias, self.stride, + _triple(0), self.dilation, self.groups, self.scheme) + return conv3d.apply(input, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.scheme) + else: + return super(QConv3d, self).forward(input) + + +class QConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, output_padding=0, groups=1, + bias=True, dilation=1, padding_mode='zeros', group=0): + super(QConvTranspose1d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, output_padding, groups, bias, dilation, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size + else: + num_locations = kernel_size[0] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input, output_size=None): + if config.training: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') + + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore + + return conv_transpose1d.apply( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation, self.scheme) + else: + return super(QConvTranspose1d, self).forward(input, output_size) + + +class QConvTranspose2d(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, output_padding=0, groups=1, + bias=True, dilation=1, padding_mode='zeros', group=0): + super(QConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, output_padding, groups, bias, dilation, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size ** 2 + else: + num_locations = kernel_size[0] * kernel_size[1] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input, output_size=None): + if config.training: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore + + return conv_transpose2d.apply( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation, self.scheme) + else: + return super(QConvTranspose2d, self).forward(input, output_size) + + +class QConvTranspose3d(nn.ConvTranspose3d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, output_padding=0, groups=1, + bias=True, dilation=1, padding_mode='zeros', group=0): + super(QConvTranspose3d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, output_padding, groups, bias, dilation, padding_mode) + if isinstance(kernel_size, int): + num_locations = kernel_size ** 3 + else: + num_locations = kernel_size[0] * kernel_size[1] * kernel_size[2] + + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, num_locations=num_locations, group=group, depthwise_groups=groups) + else: + self.scheme = None + + def forward(self, input, output_size=None): + if config.training: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') + + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore + + return conv_transpose3d.apply( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation, self.scheme) + else: + return super(QConvTranspose3d, self).forward(input, output_size) + + +class QLinear(nn.Linear): + num_layers = 0 + + def __init__(self, input_features, output_features, bias=True, group=0): + super(QLinear, self).__init__(input_features, output_features, bias) + if config.adaptive_conv_scheme: + self.scheme = QScheme(self, group=group) + else: + self.scheme = None + + def forward(self, input): + if config.training: + return linear.apply(input, self.weight, self.bias, self.scheme) + else: + return super(QLinear, self).forward(input) + + +class QBatchNorm1d(nn.BatchNorm1d): + def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, group=0): + super(QBatchNorm1d, self).__init__(num_features, eps, momentum, affine, track_running_stats) + if config.adaptive_bn_scheme: + self.scheme = QBNScheme(group=group) + else: + self.scheme = None + + def forward(self, input): + if not config.training: + return super(QBatchNorm1d, self).forward(input) + + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + """ Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return batch_norm.apply( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, self.bias, bn_training, exponential_average_factor, self.eps, self.scheme) + + +class QBatchNorm2d(nn.BatchNorm2d): + def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, group=0): + super(QBatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats) + if config.adaptive_bn_scheme: + self.scheme = QBNScheme(group=group) + else: + self.scheme = None + + def forward(self, input): + if not config.training: + return super(QBatchNorm2d, self).forward(input) + + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + """ Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return batch_norm.apply( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, self.bias, bn_training, exponential_average_factor, self.eps, self.scheme) + + +class QBatchNorm3d(nn.BatchNorm3d): + def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, group=0): + super(QBatchNorm3d, self).__init__(num_features, eps, momentum, affine, track_running_stats) + if config.adaptive_bn_scheme: + self.scheme = QBNScheme(group=group) + else: + self.scheme = None + + def forward(self, input): + if not config.training: + return super(QBatchNorm3d, self).forward(input) + + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + """ Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return batch_norm.apply( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, self.bias, bn_training, exponential_average_factor, self.eps, self.scheme) + + +class QReLU(nn.Module): + def __init__(self, inplace=False): + super().__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return ext_quantization.act_quantized_relu(input) + + +class QSyncBatchNorm(nn.SyncBatchNorm): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group=None, + group=0 + ) -> None: + super(QSyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats, process_group) + if config.adaptive_bn_scheme: + self.scheme = QBNScheme(group=group) + else: + self.scheme = None + + def forward(self, input): + # currently only GPU input is supported + if not input.is_cuda: + raise ValueError('SyncBatchNorm expected input tensor to be on GPU') + + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) + running_mean = self.running_mean if not self.training or self.track_running_stats else None + running_var = self.running_var if not self.training or self.track_running_stats else None + + need_sync = bn_training + if need_sync: + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return batch_norm().apply( + input, running_mean, running_var, self.weight, self.bias, + bn_training, exponential_average_factor, self.eps, self.scheme) + else: + if not self.ddp_gpu_size: + raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel') + + assert bn_training + return sync_batch_norm().apply( + input, self.weight, self.bias, running_mean, running_var, + self.eps, exponential_average_factor, process_group, world_size, self.scheme) + + +class QMaxPool2d(_MaxPoolNd): + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + dilation: _size_2_t + + def __init__(self, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + + def forward(self, input): + return ext_quantization.act_quantized_max_pool2d( + input, self.kernel_size, self.stride, + self.padding, self.dilation, self.ceil_mode, + self.return_indices) + + +class QAvgPool2d(_AvgPoolNd): + __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + ceil_mode: bool + count_include_pad: bool + + def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, + ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: bool = None) -> None: + super().__init__() + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride if (stride is not None) else kernel_size) + self.padding = _pair(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + #return F.avg_pool2d(input, self.kernel_size, self.stride, + # self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) + # TODO: implement cuda kernel for this + return ext_quantization.act_quantized_max_pool2d( + input, self.kernel_size, self.stride, + self.padding, (1, 1), self.ceil_mode, + False) + diff --git a/actnn/actnn/module.py b/actnn/actnn/module.py new file mode 100644 index 0000000..c7496c7 --- /dev/null +++ b/actnn/actnn/module.py @@ -0,0 +1,99 @@ +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch import Tensor, device, dtype + +from actnn.layers import QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d, \ + QBatchNorm1d, QBatchNorm2d, QBatchNorm3d, QSyncBatchNorm, \ + QReLU, QLinear, QMaxPool2d, QAvgPool2d +from actnn.conf import config + + +class QModule(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + QModule.convert_layers(model) + + @staticmethod + def convert_layers(module): + for name, child in module.named_children(): + # Do not convert layers that are already quantized + if isinstance(child, (QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d, + QBatchNorm1d, QBatchNorm2d, QBatchNorm3d, QSyncBatchNorm, + QReLU, QLinear, QMaxPool2d, QAvgPool2d)): + continue + + if isinstance(child, nn.Conv1d): + setattr(module, name, QConv1d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.dilation, + child.groups, child.bias is not None, child.padding_mode)) + elif isinstance(child, nn.Conv2d): + setattr(module, name, QConv2d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.dilation, + child.groups, child.bias is not None, child.padding_mode)) + elif isinstance(child, nn.Conv3d): + setattr(module, name, QConv3d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.dilation, + child.groups, child.bias is not None, child.padding_mode)) + elif isinstance(child, nn.ConvTranspose1d): + setattr(module, name, QConvTranspose1d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.output_padding, + child.groups, child.bias, child.dilation, child.padding_mode)) + elif isinstance(child, nn.ConvTranspose2d): + setattr(module, name, QConvTranspose2d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.output_padding, + child.groups, child.bias, child.dilation, child.padding_mode)) + elif isinstance(child, nn.ConvTranspose3d): + setattr(module, name, QConvTranspose3d(child.in_channels, child.out_channels, + child.kernel_size, child.stride, child.padding, child.output_padding, + child.groups, child.bias, child.dilation, child.padding_mode)) + elif isinstance(child, nn.BatchNorm1d) and config.enable_quantized_bn: + setattr(module, name, QBatchNorm1d(child.num_features, child.eps, child.momentum, + child.affine, child.track_running_stats)) + elif isinstance(child, nn.BatchNorm2d) and config.enable_quantized_bn: + setattr(module, name, QBatchNorm2d(child.num_features, child.eps, child.momentum, + child.affine, child.track_running_stats)) + elif isinstance(child, nn.BatchNorm3d) and config.enable_quantized_bn: + setattr(module, name, QBatchNorm3d(child.num_features, child.eps, child.momentum, + child.affine, child.track_running_stats)) + elif isinstance(child, nn.Linear): + setattr(module, name, QLinear(child.in_features, child.out_features, + child.bias is not None)) + elif isinstance(child, nn.ReLU): + setattr(module, name, QReLU()) + elif isinstance(child, nn.MaxPool2d): + setattr(module, name, QMaxPool2d(child.kernel_size, child.stride, + child.padding, child.dilation, child.return_indices, child.ceil_mode)) + elif isinstance(child, nn.AvgPool2d): + setattr(module, name, QAvgPool2d(child.kernel_size, child.stride, child.padding, + child.ceil_mode, child.count_include_pad, child.divisor_override)) + else: + QModule.convert_layers(child) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def train(self, mode: bool = True): + config.training = mode + return super().train(mode) + + def eval(self): + config.training = False + return super().eval() + + def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], + strict: bool = True): + # remove the prefix "model." added by this wrapper + new_state_dict = OrderedDict([("model." + k, v) for k, v in state_dict.items()]) + return super().load_state_dict(new_state_dict, strict) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + ret = super().state_dict(destination, prefix, keep_vars) + + # remove the prefix "model." added by this wrapper + ret = OrderedDict([(k[6:], v) for k, v in ret.items()]) + return ret + diff --git a/actnn/actnn/ops.py b/actnn/actnn/ops.py new file mode 100644 index 0000000..a0e5e41 --- /dev/null +++ b/actnn/actnn/ops.py @@ -0,0 +1,618 @@ +from collections import namedtuple +import os +import time + +import numpy as np +import torch +from torch.autograd.function import Function +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.modules.utils import _single, _pair, _triple +from torch.utils.cpp_extension import load + +from actnn.conf import config +from actnn.utils import get_memory_usage, compute_tensor_bytes, empty_cache, swap_to_cpu +import actnn.cpp_extension.quantization as ext_quantization +import actnn.cpp_extension.minimax as ext_minimax +import actnn.cpp_extension.backward_func as ext_backward_func + +QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits']) + + +def quantize_and_pack(data, bits, mn, mx): + if config.simulate: + N = data.shape[0] + output = data # N, groups, group_dim + + if isinstance(bits, int): # Handle the case when config.adaptive_scheme is False + bits = torch.ones(N, dtype=torch.int32, device='cuda') * bits + + B = (2 ** bits - 1).view(N, 1, 1) + mn = mn - 1e-6 + mx = mx + 1e-6 + scale = B / (mx - mn) # N, groups, 1 + output = (output - mn) * scale + + if config.stochastic: + noise = output.new(output.shape).uniform_(-0.5, 0.5) + output.add_(noise) + + output = F.relu(output) + output = torch.min(output, B.float()).round_().int() + else: + # Pack to bitstream + if isinstance(bits, int): + pack_func = ext_quantization.pack_single_precision + else: + pack_func = ext_quantization.pack_mixed_precision + output, scale = pack_func(data, mn, mx, bits, config.stochastic) + if config.swap: + output = swap_to_cpu(output) + + return output, scale + + +def dequantize_and_unpack(data, shape, bits, scale, mn): + if config.simulate: + data = data / scale + mn + else: + if config.swap: + data = data.cuda(non_blocking=True) + + # Pad to group_size + N = shape[0] + num_features = int(np.prod(shape[1:])) + group_size = config.group_size + num_features = (num_features + (group_size - num_features % group_size) % group_size) + + # Unpack bitstream + if isinstance(bits, int): + unpack_func = ext_quantization.unpack_single_precision + else: + unpack_func = ext_quantization.unpack_mixed_precision + data = unpack_func(data, bits, scale, mn, N, num_features // group_size, group_size) + return data + + +def no_scheme_compute_quantization_bits(input): + N = input.shape[0] + D = input.shape[1] + input_flatten = input.view(N, -1) + num_features = input_flatten.shape[1] + num_pixels = num_features // D + + # Compute min, max by groups + if num_features % config.group_size != 0: + # Padding + new_num_features = (num_features // config.group_size + 1) * config.group_size + delta = new_num_features - num_features + input_flatten = torch.cat([input_flatten, + torch.zeros([N, delta], dtype=input.dtype, device=input.device)], 1) + + input_groups = input_flatten.view(-1, config.group_size) + mn, mx = ext_minimax.minimax(input_groups) + + b = config.activation_compression_bits[0] + return input_groups.view(N, -1, config.group_size), b, mn.view(N, -1, 1), mx.view(N, -1, 1) + + +def quantize_activation(input, scheme): + if not config.compress_activation: + if config.swap: + input = swap_to_cpu(input) + + return input, None, None, None + + N = input.shape[0] + if scheme: + input_groups, q_bits, q_min, mx = scheme.compute_quantization_bits(input) + else: + input_groups, q_bits, q_min, mx = no_scheme_compute_quantization_bits(input) + + q_input, q_scale = quantize_and_pack(input_groups, q_bits, q_min, mx) + + # TODO convert q_bits to int8 + if input.dtype == torch.float32: + return q_input, q_bits, q_scale.to(torch.bfloat16), q_min.to(torch.bfloat16) + else: + return q_input, q_bits, q_scale, q_min + + +def dequantize_activation(quantized, q_input_shape): + if not config.compress_activation: + ret = quantized[0] + if config.swap: + ret = ret.cuda(non_blocking=True) + return ret + + q_input, q_bits, q_scale, q_min = quantized + if q_scale.dtype == torch.bfloat16: + q_scale = q_scale.to(torch.float32) + q_min = q_min.to(torch.float32) + input = dequantize_and_unpack(q_input, q_input_shape, q_bits, q_scale, q_min) + + # Remove padding + N = q_input_shape[0] + num_features = np.prod(q_input_shape[1:]) + input = input.view(N, -1)[:, :num_features] + input = input.view(*q_input_shape) + return input + +conv2d_layer_ct = 0 +bn_layer_ct = 0 +total_act_mem = 0 + +class convnd(Function): + @staticmethod + def run_forward(n, forward_op, ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None): + # if not ctx.needs_input_grad[1]: + # assert not ctx.needs_input_grad[0] and not ctx.needs_input_grad[1] + # return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + quantized = quantize_activation(input, scheme) + + ctx.scheme = scheme + ctx.saved = quantized, weight, bias + ctx.other_args = (input.shape, stride, padding, dilation, groups) + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_forward: + global conv2d_layer_ct, total_act_mem + print("========== conv%dd forward %d ==========" % (d, conv2d_layer_ct)) + get_memory_usage(True) + conv2d_layer_ct += 1 + total_act_mem += compute_tensor_bytes(quantized) + print("Act mem: %.2f MB" % (total_act_mem / 1024 ** 2)) + + return forward_op(input, weight, bias, stride, padding, dilation, groups) + + @staticmethod + def run_backward(n, ctx, grad_output, bias_reduce_dims, aug): + # if not ctx.needs_input_grad[1]: + # assert not ctx.needs_input_grad[0] and not ctx.needs_input_grad[1] + # return None, None, None, None, None, None, None, None + if ctx.scheme: + ctx.scheme.set_scale(grad_output) + + q_input_shape, stride, padding, dilation, groups = ctx.other_args + padding = aug(padding) + stride = aug(stride) + dilation = aug(dilation) + + quantized, weight, bias = ctx.saved + input = dequantize_activation(quantized, q_input_shape) + del quantized, ctx.saved + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_backward: + global conv2d_layer_ct + print("========== conv%dd backward %d ==========" % (n, conv2d_layer_ct)) + get_memory_usage(True) + conv2d_layer_ct += 1 + print("WS: %.2f MB" % (compute_tensor_bytes([grad_output, input, input]) / 1024 ** 2)) + + use_pipeline = False + if config.pipeline_threshold: + ws_mem = compute_tensor_bytes([grad_output, input, input]) + if (ws_mem > config.pipeline_threshold and + ctx.needs_input_grad[1] and ctx.needs_input_grad[0]): + use_pipeline = True + + if use_pipeline: + micro_batch_size = (ws_mem + config.pipeline_threshold) // config.pipeline_threshold + raw_input = input + raw_grad_output = grad_output + input = torch.chunk(input, micro_batch_size) + grad_output = torch.chunk(grad_output, micro_batch_size) + grad_weight = None + + for i in range(micro_batch_size): + input[i][:], grad_weight_tmp = ext_backward_func.cudnn_convolution_backward( + input[i], grad_output[i], weight, padding, stride, dilation, groups, + config.cudnn_benchmark_conv2d, False, False, + [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + if grad_weight is None: + grad_weight = grad_weight_tmp + else: + grad_weight += grad_weight_tmp + grad_input = raw_input + grad_output = raw_grad_output + else: + grad_input, grad_weight = ext_backward_func.cudnn_convolution_backward( + input, grad_output, weight, padding, stride, dilation, groups, + config.cudnn_benchmark_conv2d, False, False, + [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(bias_reduce_dims) + else: + grad_bias = None + + if ctx.scheme: + ctx.scheme.if_allocate_perlayer() + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + +class conv1d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None): + return convnd.run_forward(1, F.conv1d, ctx, input, weight, bias, stride, padding, dilation, groups, scheme) + + @staticmethod + def backward(ctx, grad_output): + return convnd.run_backward(1, ctx, grad_output, [0, 2], _single) + + +class conv2d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None): + return convnd.run_forward(2, F.conv2d, ctx, input, weight, bias, stride, padding, dilation, groups, scheme) + + @staticmethod + def backward(ctx, grad_output): + return convnd.run_backward(2, ctx, grad_output, [0, 2, 3], _pair) + + +class conv3d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None): + return convnd.run_forward(3, F.conv3d, ctx, input, weight, bias, stride, padding, dilation, groups, scheme) + + @staticmethod + def backward(ctx, grad_output): + return convnd.run_backward(3, ctx, grad_output, [0, 2, 3, 4], _triple) + + +class conv_transposend(Function): + @staticmethod + def run_forward(n, forward_op, ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None): + quantized = quantize_activation(input, scheme) + + ctx.scheme = scheme + ctx.saved = quantized, weight, bias + ctx.other_args = (input.shape, stride, padding, output_padding, dilation, groups) + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_forward: + global conv2d_layer_ct, total_act_mem + print("========== conv%dd_transpose forward %d ==========" % (n, conv2d_layer_ct)) + get_memory_usage(True) + conv2d_layer_ct += 1 + total_act_mem += compute_tensor_bytes(quantized) + print("Act mem: %.2f MB" % (total_act_mem / 1024 ** 2)) + + return forward_op(input, weight, bias, stride, padding, output_padding, groups, dilation) + + @staticmethod + def run_backward(n, ctx, grad_output, bias_reduce_dims, aug): + if ctx.scheme: + ctx.scheme.set_scale(grad_output) + + q_input_shape, stride, padding, output_padding, dilation, groups = ctx.other_args + padding = aug(padding) + output_padding = aug(output_padding) + stride = aug(stride) + dilation = aug(dilation) + + quantized, weight, bias = ctx.saved + input = dequantize_activation(quantized, q_input_shape) + del quantized, ctx.saved + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_backward: + global conv2d_layer_ct + print("========== conv%dd_transpose backward %d ==========" % (n, conv2d_layer_ct)) + get_memory_usage(True) + conv2d_layer_ct += 1 + print("WS: %.2f MB" % (compute_tensor_bytes([grad_output, input, input]) / 1024 ** 2)) + + use_pipeline = False + if config.pipeline_threshold: + ws_mem = compute_tensor_bytes([grad_output, input, input]) + if (ws_mem > config.pipeline_threshold and + ctx.needs_input_grad[1] and ctx.needs_input_grad[0]): + use_pipeline = True + + if use_pipeline: + micro_batch_size = (ws_mem + config.pipeline_threshold) // config.pipeline_threshold + raw_input = input + raw_grad_output = grad_output + input = torch.chunk(input, micro_batch_size) + grad_output = torch.chunk(grad_output, micro_batch_size) + grad_weight = None + + for i in range(micro_batch_size): + input[i][:], grad_weight_tmp = ext_backward_func.cudnn_convolution_transpose_backward( + input[i], grad_output[i], weight, padding, output_padding, stride, dilation, groups, + config.cudnn_benchmark_conv2d, False, False, + [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + if grad_weight is None: + grad_weight = grad_weight_tmp + else: + grad_weight += grad_weight_tmp + grad_input = raw_input + grad_output = raw_grad_output + else: + grad_input, grad_weight = ext_backward_func.cudnn_convolution_transpose_backward( + input, grad_output, weight, padding, output_padding, stride, dilation, groups, + config.cudnn_benchmark_conv2d, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(bias_reduce_dims) + else: + grad_bias = None + + if ctx.scheme: + ctx.scheme.if_allocate_perlayer() + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class conv_transpose1d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None): + return conv_transposend.run_forward(1, F.conv_transpose1d, ctx, input, weight, bias, stride, + padding, output_padding, groups, dilation, scheme) + + @staticmethod + def backward(ctx, grad_output): + return conv_transposend.run_backward(1, ctx, grad_output, [0, 2], _single) + + +class conv_transpose2d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None): + return conv_transposend.run_forward(2, F.conv_transpose2d, ctx, input, weight, bias, stride, + padding, output_padding, groups, dilation, scheme) + + @staticmethod + def backward(ctx, grad_output): + return conv_transposend.run_backward(2, ctx, grad_output, [0, 2, 3], _pair) + + +class conv_transpose3d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None): + return conv_transposend.run_forward(3, F.conv_transpose3d, ctx, input, weight, bias, stride, + padding, output_padding, groups, dilation, scheme) + + @staticmethod + def backward(ctx, grad_output): + return conv_transposend.run_backward(3, ctx, grad_output, [0, 2, 3, 4], _triple) + + +class linear(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, scheme=None): + quantized = quantize_activation(input, scheme) + + empty_cache(config.empty_cache_threshold) + + ctx.scheme = scheme + ctx.saved = quantized, weight, bias + ctx.other_args = input.shape + + return F.linear(input, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + if ctx.scheme: + ctx.scheme.set_scale(grad_output) + + quantized, weight, bias = ctx.saved + q_input_shape = ctx.other_args + + input = dequantize_activation(quantized, q_input_shape) + del quantized, ctx.saved + + empty_cache(config.empty_cache_threshold) + + # TODO: the following implementation might not be optimal + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(input) + if bias is not None: + grad_bias = grad_output.sum(0) + else: + grad_bias = None + + if ctx.scheme: + ctx.scheme.if_allocate_perlayer() + return grad_input, grad_weight, grad_bias, None + + +class batch_norm(Function): + @staticmethod + def forward(ctx, input, running_mean, running_var, weight, bias, + training, exponential_average_factor, eps, scheme): + # if not ctx.needs_input_grad[3]: + # assert not ctx.needs_input_grad[0] and not ctx.needs_input_grad[4] + # return ext_backward_func.cudnn_batch_norm( + # input, weight, bias, running_mean, running_var, training, exponential_average_factor, eps)[0] + quantized = quantize_activation(input, scheme) + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_forward: + global bn_layer_ct, total_act_mem + print("========== bn forward %d ==========" % bn_layer_ct) + get_memory_usage(True) + bn_layer_ct += 1 + total_act_mem += compute_tensor_bytes(quantized) + print("Act mem: %.2f MB" % (total_act_mem / 1024 ** 2)) + + if training: + output, save_mean, save_var, reserve = ext_backward_func.cudnn_batch_norm( + input, weight, bias, running_mean, running_var, training, exponential_average_factor, eps) + else: + output, save_mean, save_var = ext_backward_func.native_batch_norm( + input, weight, bias, running_mean, running_var, training, exponential_average_factor, eps) + reserve = None + + ctx.scheme = scheme + ctx.other_args = input.shape + ctx.saved = (quantized, weight, running_mean, running_var, save_mean, save_var, training, eps, reserve) + + return output + + @staticmethod + def backward(ctx, grad_output): + # if not ctx.needs_input_grad[3]: + # assert not ctx.needs_input_grad[0] and not ctx.needs_input_grad[4] + # return None, None, None, None, None, None, None, None, None + quantized, weight, running_mean, running_var, save_mean, save_var, training, eps, reserve = ctx.saved + + q_input_shape = ctx.other_args + + input = dequantize_activation(quantized, q_input_shape) + del quantized, ctx.saved + + empty_cache(config.empty_cache_threshold) + + if config.debug_memory_op_backward: + global bn_layer_ct + print("========== bn backward %d ==========" % bn_layer_ct) + get_memory_usage(True) + bn_layer_ct += 1 + + if training: + input = input.contiguous() + grad_input, grad_weight, grad_bias = ext_backward_func.cudnn_batch_norm_backward( + input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve) + else: + grad_input, grad_weight, grad_bias = ext_backward_func.native_batch_norm_backward( + grad_output, input, weight, running_mean, running_var, save_mean, save_var, training, eps, + [ctx.needs_input_grad[0], ctx.needs_input_grad[3], ctx.needs_input_grad[4]] + ) + + if ctx.scheme: + ctx.scheme.if_allocate_perlayer() + return grad_input, None, None, grad_weight, grad_bias, None, None, None, None + + +class sync_batch_norm(Function): + @staticmethod + def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size, scheme): + input = input.contiguous() + + count = torch.empty(1, + dtype=running_mean.dtype, + device=input.device).fill_(input.numel() // input.size(1)) + + # calculate mean/invstd for input. + mean, invstd = torch.batch_norm_stats(input, eps) + + num_channels = input.shape[1] + # C, C, 1 -> (2C + 1) + combined = torch.cat([mean, invstd, count], dim=0) + # world_size * (2C + 1) + combined_list = [ + torch.empty_like(combined) for k in range(world_size) + ] + # Use allgather instead of allreduce since I don't trust in-place operations .. + dist.all_gather(combined_list, combined, process_group, async_op=False) + combined = torch.stack(combined_list, dim=0) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + + size = count_all.view(-1).long().sum() + if size == 1: + raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size)) + + # calculate global mean & invstd + mean, invstd = torch.batch_norm_gather_stats_with_counts( + input, + mean_all, + invstd_all, + running_mean, + running_var, + momentum, + eps, + count_all.view(-1) + ) + + quantized = quantize_activation(input, scheme) + self.saved = quantized + self.save_for_backward(weight, mean, invstd, count_all) + self.scheme = scheme + self.other_args = input.shape + self.process_group = process_group + + # apply element-wise normalization + return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps) + + @staticmethod + def backward(self, grad_output): + grad_output = grad_output.contiguous() + + quantized = self.saved + q_input_shape = self.other_args + saved_input = dequantize_activation(quantized, q_input_shape) + del quantized, self.saved + + weight, mean, invstd, count_tensor = self.saved_tensors + grad_input = grad_weight = grad_bias = None + process_group = self.process_group + + # calculate local stats as well as grad_weight / grad_bias + sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce( + grad_output, + saved_input, + mean, + invstd, + weight, + self.needs_input_grad[0], + self.needs_input_grad[1], + self.needs_input_grad[2] + ) + + if self.needs_input_grad[0]: + # synchronizing stats used to calculate input gradient. + # TODO: move div_ into batch_norm_backward_elemt kernel + num_channels = sum_dy.shape[0] + combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) + torch.distributed.all_reduce( + combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) + sum_dy, sum_dy_xmu = torch.split(combined, num_channels) + + divisor = count_tensor.sum() + mean_dy = sum_dy / divisor + mean_dy_xmu = sum_dy_xmu / divisor + # backward pass for gradient calculation + grad_input = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + mean_dy, + mean_dy_xmu + ) + + # synchronizing of grad_weight / grad_bias is not needed as distributed + # training would handle all reduce. + if weight is None or not self.needs_input_grad[1]: + grad_weight = None + + if weight is None or not self.needs_input_grad[2]: + grad_bias = None + + if self.scheme: + self.scheme.if_allocate_perlayer() + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None + + +class adaptive_avg_pool2d(Function): + @staticmethod + def forward(ctx, input, output_size): + assert output_size == (1, 1) + ctx.saved = input.shape + return torch.mean(input, dim=[2, 3], keepdim=True) + + @staticmethod + def backward(ctx, grad_output): + input_shape = ctx.saved + repeat_size = [int(x / y) for x, y in zip(input_shape, grad_output.shape)] + return grad_output.repeat(repeat_size) / np.prod(repeat_size), None + diff --git a/actnn/actnn/qbnscheme.py b/actnn/actnn/qbnscheme.py new file mode 100644 index 0000000..359f577 --- /dev/null +++ b/actnn/actnn/qbnscheme.py @@ -0,0 +1,56 @@ +import torch + +from actnn.conf import config +from actnn.qscheme import QScheme +import actnn.cpp_extension.minimax as ext_minimax +import actnn.cpp_extension.calc_precision as ext_calc_precision + + +class QBNScheme(QScheme): + layers = [] + + def __init__(self, group=0): + self.initial_bits = config.initial_bits + self.bits = config.activation_compression_bits[group] + QBNScheme.layers.append(self) + if len(QScheme.layers) > 0: + self.prev_linear = QScheme.layers[-1] + else: + self.prev_linear = None + + def compute_quantization_bits(self, input): + N = input.shape[0] + D = input.shape[1] + input_flatten = input.view(N, -1) + num_features = input_flatten.shape[1] + num_pixels = num_features // D + + # Compute min, max by groups + if num_features % config.group_size != 0: + # Padding + new_num_features = (num_features // config.group_size + 1) * config.group_size + delta = new_num_features - num_features + input_flatten = torch.cat([input_flatten, + torch.zeros([N, delta], dtype=input.dtype, device=input.device)], 1) + + input_groups = input_flatten.view(-1, config.group_size) + mn, mx = ext_minimax.minimax(input_groups) + if not config.pergroup: # No per group quantization + mn = torch.ones_like(mn) * mn.min() + mx = torch.ones_like(mx) * mx.max() + + # Average range over pixels [N] + Range_sqr = torch.norm((mx - mn).view(N, -1), dim=1).square() * (config.group_size / num_pixels) + + # greedy + C = Range_sqr.to(torch.float32).cpu() + b = torch.ones(N, dtype=torch.int32) * self.initial_bits + w = torch.ones(N, dtype=torch.int32) + b = ext_calc_precision.calc_precision(b, C, w, int(self.bits * N)) + + return input_groups.view(N, -1, config.group_size), b.cuda(), mn.view(N, -1, 1), mx.view(N, -1, 1) + + @staticmethod + def allocate_perlayer(): + for layer in QBNScheme.layers: + layer.bits = layer.prev_linear.bits diff --git a/actnn/actnn/qscheme.py b/actnn/actnn/qscheme.py new file mode 100644 index 0000000..cdcc230 --- /dev/null +++ b/actnn/actnn/qscheme.py @@ -0,0 +1,143 @@ +import torch + +import actnn +from actnn.conf import config +import actnn.cpp_extension.minimax as ext_minimax +import actnn.cpp_extension.calc_precision as ext_calc_precision + +class QScheme(object): + num_samples = 1 + num_layers = 0 + batch = None + update_scale = True + layers = [] + + def __init__(self, layer, group=0, num_locations=1, depthwise_groups=1): + self.initial_bits = config.initial_bits + self.bits = config.activation_compression_bits[group] + if config.use_gradient: + assert QScheme.num_samples > 1 + self.scales = torch.zeros(QScheme.num_samples) + else: + self.scales = torch.tensor([0.0]) + QScheme.layers.append(self) + self.C = None + self.dim = None + self.num_locations = num_locations # Kernel size + self.depthwise_groups = depthwise_groups # Depthwise separable conv + self.layer = layer + self.group = group + + # debug + self.name = 'layer_{}'.format(QScheme.num_layers) + QScheme.num_layers += 1 + + def get_scale(self): + if config.use_gradient: + assert QScheme.batch is not None + scale = self.scales[QScheme.batch].clone() + avg_scale = scale.mean() + scale[scale == 0] = avg_scale + 1e-9 + return scale + else: + return self.scales + + def set_scale(self, grad): + if QScheme.update_scale: + if config.use_gradient: + assert QScheme.batch is not None + scale = grad.view(grad.shape[0], -1).norm(dim=1).square().cpu() + self.scales[QScheme.batch] = self.scales[QScheme.batch] * 0.5 + scale * 0.5 + else: + scale = grad.view(grad.shape[0], -1).norm(dim=1).square() + self.scales = scale.mean() + + def compute_quantization_bits(self, input): + N = input.shape[0] + D = input.shape[1] + input_flatten = input.view(N, -1) + num_features = input_flatten.shape[1] + num_pixels = num_features // D + + # Compute min, max by groups + if num_features % config.group_size != 0: + # Padding + new_num_features = (num_features // config.group_size + 1) * config.group_size + delta = new_num_features - num_features + input_flatten = torch.cat([input_flatten, + torch.zeros([N, delta], dtype=input.dtype, device=input.device)], 1) + + input_groups = input_flatten.view(-1, config.group_size) # [-1, group_size] + mn, mx = ext_minimax.minimax(input_groups) + if not config.pergroup: # No per group quantization + mn = torch.ones_like(mn) * mn.min() + mx = torch.ones_like(mx) * mx.max() + + # Average range over pixels G * ||R_n||^2 / I + Range_sqr = torch.norm((mx - mn).view(N, -1), dim=1).square() * (config.group_size / num_pixels) + + # greedy + grad_sum = self.get_scale().cuda() + C = (self.num_locations / 4 / self.depthwise_groups * Range_sqr * grad_sum)\ + .to(torch.float32).cpu() + b = torch.ones(N, dtype=torch.int32) * self.initial_bits + w = torch.ones(N, dtype=torch.int32) + b = ext_calc_precision.calc_precision(b, C, w, int(self.bits * N)) # N + + self.C = C + self.dim = input.numel() // N + self.b = b + + return input_groups.view(N, -1, config.group_size), b.cuda(), mn.view(N, -1, 1), mx.view(N, -1, 1) + + @staticmethod + def allocate_perlayer(): + num_groups = len(config.activation_compression_bits) + for g in range(num_groups): + layers = [layer for layer in QScheme.layers if layer.group == g] + L = len(layers) + + if config.activation_compression_bits[g] == config.initial_bits: + C = torch.tensor([layer.C.sum() for layer in layers]) + w = torch.tensor([layer.dim for layer in layers], dtype=torch.int) + total_bits = w.sum() * config.activation_compression_bits[g] + b = torch.ones(L, dtype=torch.int32) * 8 + b = ext_calc_precision.calc_precision(b, C, w, total_bits) + + for i in range(L): + layers[i].bits = layers[i].initial_bits = b[i] + else: + Cs = [layer.C for layer in layers] + C = torch.cat(Cs, 0) + + N = Cs[0].shape[0] + + # TODO ??? + Ws = [torch.ones(N, dtype=torch.int32) * layer.dim for layer in layers] + # Ws = [torch.ones(N, dtype=torch.int32) for layer in layers] + w = torch.cat(Ws, 0) + + total_bits = w.sum() * config.activation_compression_bits[g] + b = torch.ones(N * L, dtype=torch.int32) * config.initial_bits + b = ext_calc_precision.calc_precision(b, C, w, total_bits) + for i in range(L): + bs = b[i*N : (i+1)*N] + layers[i].bits = bs.float().mean() + + def if_allocate_perlayer(self): + if not config.perlayer: + return + for layer in QScheme.layers: + if layer.C is None: + return + + first_layer = None + for layer in QScheme.layers: + if layer.layer.weight.requires_grad: + first_layer = layer + + # If myself is the last layer, then reallocate bits per layer + if config.compress_activation and config.training: + if self == first_layer: + QScheme.allocate_perlayer() + actnn.QBNScheme.allocate_perlayer() diff --git a/actnn/actnn/utils.py b/actnn/actnn/utils.py new file mode 100644 index 0000000..14c3e9b --- /dev/null +++ b/actnn/actnn/utils.py @@ -0,0 +1,81 @@ +import os +from collections import OrderedDict +import json + +import torch +import numpy as np +import json + + +def swap_to_cpu(tensor): + tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, device='cpu', pin_memory=True) + tensor_cpu.copy_(tensor, non_blocking=True) + return tensor_cpu + + +def get_memory_usage(print_info=False): + """Get accurate gpu memory usage by querying torch runtime""" + allocated = torch.cuda.memory_allocated(0) + reserved = torch.cuda.memory_reserved(0) + if print_info: + print("allocated: %.2f MB" % (allocated / 1024 / 1024), flush=True) + print("reserved: %.2f MB" % (reserved / 1024 / 1024), flush=True) + return allocated + + +def compute_tensor_bytes(tensors): + """Compute the bytes used by a list of tensors""" + if not isinstance(tensors, (list, tuple)): + tensors = [tensors] + + ret = 0 + for x in tensors: + if x.dtype in [torch.float32, torch.int]: + ret += np.prod(x.size()) * 4 + elif x.dtype in [torch.bfloat16, torch.float16, torch.int16]: + ret += np.prod(x.size()) * 2 + elif x.dtype in [torch.int8]: + ret += np.prod(x.size()) * 2 + + return ret + + +def empty_cache(ratio): + if ratio is None: + return + allocated = torch.cuda.memory_allocated(0) + reserved = torch.cuda.memory_reserved(0) + if reserved > 0 and allocated / reserved < ratio: + torch.cuda.empty_cache() + + +def disable_cache_allocator(): + os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' + + +def enable_cache_allocator(): + del os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] + + +class GlobalExpRecorder: + def __init__(self): + self.val_dict = OrderedDict() + + def record(self, key, value, float_round=6): + if isinstance(value, (np.int32, np.int64)): + value = int(value) + if isinstance(value, (float, np.float32, np.float64)): + value = round(value, float_round) + + self.val_dict[key] = value + + def dump(self, filename): + with open(filename, "a") as fout: + fout.write(json.dumps(self.val_dict) + '\n') + print("Save exp results to %s" % filename) + + def clear(): + pass + +exp_recorder = GlobalExpRecorder() + diff --git a/actnn/setup.py b/actnn/setup.py new file mode 100644 index 0000000..535fe5f --- /dev/null +++ b/actnn/setup.py @@ -0,0 +1,25 @@ +from setuptools import setup, Extension, find_packages +from torch.utils import cpp_extension + +setup(name='actnn', + ext_modules=[ + cpp_extension.CUDAExtension( + 'actnn.cpp_extension.calc_precision', + ['actnn/cpp_extension/calc_precision.cc'] + ), + cpp_extension.CUDAExtension( + 'actnn.cpp_extension.minimax', + ['actnn/cpp_extension/minimax.cc', 'actnn/cpp_extension/minimax_cuda_kernel.cu'] + ), + cpp_extension.CUDAExtension( + 'actnn.cpp_extension.backward_func', + ['actnn/cpp_extension/backward_func.cc'] + ), + cpp_extension.CUDAExtension( + 'actnn.cpp_extension.quantization', + ['actnn/cpp_extension/quantization.cc', 'actnn/cpp_extension/quantization_cuda_kernel.cu'] + ), + ], + cmdclass={'build_ext': cpp_extension.BuildExtension}, + packages=find_packages() +) diff --git a/image_classification/Dockerfile b/image_classification/Dockerfile new file mode 100644 index 0000000..455179d --- /dev/null +++ b/image_classification/Dockerfile @@ -0,0 +1,8 @@ +FROM nvcr.io/nvidia/pytorch:19.05-py3 + +RUN git clone https://github.com/NVIDIA/apex \ + && cd apex \ + && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . + +ADD . /workspace/rn50 +WORKDIR /workspace/rn50 diff --git a/image_classification/INSTALL.md b/image_classification/INSTALL.md new file mode 100644 index 0000000..947dd0d --- /dev/null +++ b/image_classification/INSTALL.md @@ -0,0 +1,12 @@ +INSTALL +==== + +```` +cd quantizers +python setup.py install +mkdir results +./train resnet18 # Exact +./train resnet18 "-c quantize --qa=True --qw=True --qg=True --persample=True --hadamard=True" # Our quantized algorithm +```` + +If the GPU memory is large, multiply `batch-size`, `lr`, and `warmup` by 2. For distributed setting, multiply `batch-size`, `lr` and `warmup` by `number of GPUs / 4`. \ No newline at end of file diff --git a/image_classification/LICENSE b/image_classification/LICENSE new file mode 100644 index 0000000..5e037a5 --- /dev/null +++ b/image_classification/LICENSE @@ -0,0 +1,11 @@ +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/image_classification/README.md b/image_classification/README.md new file mode 100644 index 0000000..383aa15 --- /dev/null +++ b/image_classification/README.md @@ -0,0 +1,48 @@ +# Image Classficiation + +## Requirement +- Put the ImageNet dataset to `~/imagenet` +- Install apex +```bash +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +``` + +## Train resnet56 on cifar10 +``` +mkdir -p results/tmp +python3 main.py --dataset cifar10 --arch preact_resnet56 --epochs 200 --num-classes 10 \ + -j 0 --weight-decay 1e-4 --batch-size 128 --label-smoothing 0 \ + --lr 0.1 --momentum 0.9 --warmup 4 \ + -c quantize --ca=True --cabits=2 --ibits=8 --calg pl \ + --workspace results/tmp --gather-checkpoints ~/data/cifar10 +``` + +## Train resnet50 on imagenet +``` +./dist-train 1 0 127.0.0.1 1 resnet50 \ + "-c quantize --ca=True --cabits=2 --ibits=8 --calg pl"\ + tmp ~/imagenet 256 +``` + + +## Check gradient variance +Download model checkpoints +``` +wget https://people.eecs.berkeley.edu/~jianfei/results.tar.gz +tar xzvf results.tar.gz +mkdir results/tmp +``` + +### Cifar 100 +``` +python3 main.py --dataset cifar10 --arch preact_resnet56 --epochs 200 --num-classes 100 -j 0 --weight-decay 1e-4 --batch-size 128 --label-smoothing 0 \ + -c quantize --ca=True --cabits=2 --ibits=8 --calg pl \ + --workspace results/tmp --evaluate --training-only \ + --resume results/cifar100/checkpoint-10.pth.tar --resume2 results/cifar100/checkpoint-10.pth.tar ~/data/cifar100 +``` + +| *quantize config* | *Overall Var* | *Val Top1* | +|--------|----------|---------| +| -c quantize --ca=True --cabits=2 --ibits=8 --calg pl | 0.03805697709321976 | | diff --git a/image_classification/__init__.py b/image_classification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_classification/dist-test b/image_classification/dist-test new file mode 100755 index 0000000..41db26e --- /dev/null +++ b/image_classification/dist-test @@ -0,0 +1,19 @@ +NNODES=$1 +NODE_RANK=$2 +MASTER_ADDR=$3 +NPROC_PER_NODE=$4 +ARCH=$5 +ARGS=$6 +ID=$7 +DIR=$8 + +BATCH_SIZE=50 +LR=0.3 +WARMUP=4 + +mkdir results/$ID +python ./multiproc.py --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port 29501 --nproc_per_node $NPROC_PER_NODE \ +./main.py --arch $ARCH --num-classes 1000 --gather-checkpoints --workspace results/$ID --batch-size $BATCH_SIZE --lr $LR --gather-checkpoints --warmup $WARMUP $ARGS \ +--evaluate --training-only \ +--resume results/imagenet/checkpoint-10.pth.tar \ +--resume2 results/imagenet/checkpoint-10.pth.tar $DIR diff --git a/image_classification/dist-train b/image_classification/dist-train new file mode 100755 index 0000000..003a67c --- /dev/null +++ b/image_classification/dist-train @@ -0,0 +1,15 @@ +NNODES=$1 +NODE_RANK=$2 +MASTER_ADDR=$3 +NPROC_PER_NODE=$4 +ARCH=$5 +ARGS=$6 +ID=$7 +DIR=$8 + +BATCH_SIZE=${9:-32} +LR=0.256 +WARMUP=4 + +mkdir results/$ID +python3 ./multiproc.py --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --nproc_per_node $NPROC_PER_NODE ./main.py --arch $ARCH --gather-checkpoints --workspace results/$ID --batch-size $BATCH_SIZE --lr $LR --warmup $WARMUP $ARGS $DIR diff --git a/image_classification/exp_mem_speed.py b/image_classification/exp_mem_speed.py new file mode 100644 index 0000000..269cbb9 --- /dev/null +++ b/image_classification/exp_mem_speed.py @@ -0,0 +1,102 @@ +import os +import json +import argparse + +def run_cmd(cmd): + print(cmd) + return os.system(cmd) + + +alg_to_config = { + "exact": "-c fanin", + "quantize": "-c quantize --ca=True --cabits=2 --ibits=8 --calg pl", +} + +network_to_batch_size = { + "preact_resnet56": [64, 32000], + "preact_resnet1001": [64, 2048], + "resnet50": [64, 1024], + "resnet152": [32, 1024], +} + +network_to_command = { + "preact_resnet56": "python3 main.py --dataset cifar10 --arch preact_resnet56 " + "--epochs 200 --num-classes 10 -j 0 --weight-decay 1e-4 --batch-size BS " + "--training-only --label-smoothing 0 CONFIG " + "--workspace results/tmp --gather-checkpoints ~/data/cifar10", + "preact_resnet1001": "python3 main.py --dataset cifar10 --arch preact_resnet1001 " + "--epochs 200 --num-classes 10 -j 0 --weight-decay 1e-4 --batch-size BS " + "--training-only --label-smoothing 0 CONFIG " + "--workspace results/tmp --gather-checkpoints ~/data/cifar10", + "resnet50": "bash dist-train 1 0 127.0.0.1 1 resnet50 'CONFIG' tmp ~/imagenet BS", + "resnet152": "bash dist-train 1 0 127.0.0.1 1 resnet152 'CONFIG' tmp ~/imagenet BS", +} + + +def run_benchmark(network, alg, batch_size, debug_mem=False, debug_speed=False): + os.environ['DEBUG_MEM'] = str(debug_mem) + os.environ['DEBUG_SPEED'] = str(debug_speed) + cmd = network_to_command[network] + cmd = cmd.replace("BS", f"{batch_size}").replace("CONFIG", alg_to_config[alg]) + return run_cmd(cmd) + + +def binary_search_max_batch(network, alg, low, high): + ret = 0 + + while low <= high: + mid = low + (high - low) // 2 + success = run_benchmark(network, alg, mid, debug_speed=True) == 0 + if success: + ret = mid + low = mid + 1 + else: + high = mid - 1 + + return ret + + +def get_ips(network, alg, batch_size): + run_benchmark(network, alg, batch_size, debug_speed=True) + line = list(open("speed_results.tsv").readlines())[-1] + return json.loads(line)['ips'] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--mode", type=str, + choices=['linear_scan', 'binary_search'], + default='linear_scan') + args = parser.parse_args() + + #networks = ['preact_resnet1001', 'resnet152', 'resnet50', 'preact_resnet56'] + #algs = ['exact', 'quantize'] + + networks = ['resnet50'] + algs = ['quantize'] + batch_sizes = [1000] + + if args.mode == 'linear_scan': + for network in networks: + for alg in algs: + for batch_size in (batch_sizes or network_to_batch_size[network]): + if run_benchmark(network, alg, batch_size, debug_mem=False, debug_speed=True) != 0: + break + elif args.mode == 'binary_search': + for network in networks: + for alg in algs: + low, high = network_to_batch_size[network][0], network_to_batch_size[network][-1] + max_batch_size = binary_search_max_batch(network, alg, low, high) + ips = get_ips(network, alg, max_batch_size) + + out_file = "max_batch_results.tsv" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "max_batch_size": max_batch_size, + "ips": ips, + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + diff --git a/image_classification/image_classification/__init__.py b/image_classification/image_classification/__init__.py new file mode 100644 index 0000000..e9cc3b5 --- /dev/null +++ b/image_classification/image_classification/__init__.py @@ -0,0 +1,7 @@ +from . import logger +from . import dataloaders +from . import training +from . import utils +from . import mixup +from . import resnet +from . import smoothing diff --git a/image_classification/image_classification/dataloaders.py b/image_classification/image_classification/dataloaders.py new file mode 100644 index 0000000..f8c0468 --- /dev/null +++ b/image_classification/image_classification/dataloaders.py @@ -0,0 +1,367 @@ +import os +import torch +import numpy as np +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from actnn import dataloader + +DATA_BACKEND_CHOICES = ['pytorch'] +# try: +# from nvidia.dali.plugin.pytorch import DALIClassificationIterator +# from nvidia.dali.pipeline import Pipeline +# import nvidia.dali.ops as ops +# import nvidia.dali.types as types +# DATA_BACKEND_CHOICES.append('dali-gpu') +# DATA_BACKEND_CHOICES.append('dali-cpu') +# except ImportError: +# print("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.") + + +# class HybridTrainPipe(Pipeline): +# def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False): +# super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed = 12 + device_id) +# if torch.distributed.is_initialized(): +# local_rank = torch.distributed.get_rank() +# world_size = torch.distributed.get_world_size() +# else: +# local_rank = 0 +# world_size = 1 +# +# self.input = ops.FileReader( +# file_root = data_dir, +# shard_id = local_rank, +# num_shards = world_size, +# random_shuffle = True) +# +# if dali_cpu: +# dali_device = "cpu" +# self.decode = ops.HostDecoderRandomCrop(device=dali_device, output_type=types.RGB, +# random_aspect_ratio=[0.75, 4./3.], +# random_area=[0.08, 1.0], +# num_attempts=100) +# else: +# dali_device = "gpu" +# # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet +# # without additional reallocations +# self.decode = ops.nvJPEGDecoderRandomCrop(device="mixed", output_type=types.RGB, device_memory_padding=211025920, host_memory_padding=140544512, +# random_aspect_ratio=[0.75, 4./3.], +# random_area=[0.08, 1.0], +# num_attempts=100) +# +# self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR) +# self.cmnp = ops.CropMirrorNormalize(device = "gpu", +# output_dtype = types.FLOAT, +# output_layout = types.NCHW, +# crop = (crop, crop), +# image_type = types.RGB, +# mean = [0.485 * 255,0.456 * 255,0.406 * 255], +# std = [0.229 * 255,0.224 * 255,0.225 * 255]) +# self.coin = ops.CoinFlip(probability = 0.5) +# +# def define_graph(self): +# rng = self.coin() +# self.jpegs, self.labels = self.input(name = "Reader") +# images = self.decode(self.jpegs) +# images = self.res(images) +# output = self.cmnp(images.gpu(), mirror = rng) +# return [output, self.labels] +# +# +# class HybridValPipe(Pipeline): +# def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size): +# super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed = 12 + device_id) +# if torch.distributed.is_initialized(): +# local_rank = torch.distributed.get_rank() +# world_size = torch.distributed.get_world_size() +# else: +# local_rank = 0 +# world_size = 1 +# +# self.input = ops.FileReader( +# file_root = data_dir, +# shard_id = local_rank, +# num_shards = world_size, +# random_shuffle = False) +# +# self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB) +# self.res = ops.Resize(device = "gpu", resize_shorter = size) +# self.cmnp = ops.CropMirrorNormalize(device = "gpu", +# output_dtype = types.FLOAT, +# output_layout = types.NCHW, +# crop = (crop, crop), +# image_type = types.RGB, +# mean = [0.485 * 255,0.456 * 255,0.406 * 255], +# std = [0.229 * 255,0.224 * 255,0.225 * 255]) +# +# def define_graph(self): +# self.jpegs, self.labels = self.input(name = "Reader") +# images = self.decode(self.jpegs) +# images = self.res(images) +# output = self.cmnp(images) +# return [output, self.labels] + + +class DALIWrapper(object): + def gen_wrapper(dalipipeline, num_classes, one_hot): + for data in dalipipeline: + input = data[0]["data"] + target = data[0]["label"].squeeze().cuda().long() + if one_hot: + target = expand(num_classes, torch.float, target) + yield input, target + dalipipeline.reset() + + def __init__(self, dalipipeline, num_classes, one_hot): + self.dalipipeline = dalipipeline + self.num_classes = num_classes + self.one_hot = one_hot + + def __iter__(self): + return DALIWrapper.gen_wrapper(self.dalipipeline, self.num_classes, self.one_hot) + +def get_dali_train_loader(dali_cpu=False): + def gdtl(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + if torch.distributed.is_initialized(): + local_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + else: + local_rank = 0 + world_size = 1 + + traindir = os.path.join(data_path, 'train') + + pipe = HybridTrainPipe(batch_size=batch_size, num_threads=workers, + device_id = local_rank, + data_dir = traindir, crop = 224, dali_cpu=dali_cpu) + + pipe.build() + train_loader = DALIClassificationIterator(pipe, size = int(pipe.epoch_size("Reader") / world_size)) + + return DALIWrapper(train_loader, num_classes, one_hot), int(pipe.epoch_size("Reader") / (world_size * batch_size)) + + return gdtl + + +def get_dali_val_loader(): + def gdvl(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + if torch.distributed.is_initialized(): + local_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + else: + local_rank = 0 + world_size = 1 + + valdir = os.path.join(data_path, 'val') + + pipe = HybridValPipe(batch_size=batch_size, num_threads=workers, + device_id = local_rank, + data_dir = valdir, + crop = 224, size = 256) + pipe.build() + val_loader = DALIClassificationIterator(pipe, size = int(pipe.epoch_size("Reader") / world_size)) + + return DALIWrapper(val_loader, num_classes, one_hot), int(pipe.epoch_size("Reader") / (world_size * batch_size)) + return gdvl + + +def fast_collate(batch): + imgs = [img[0] for img in batch] + targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) + w = imgs[0].size[0] + h = imgs[0].size[1] + tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) + for i, img in enumerate(imgs): + #nump_array = np.asarray(img, dtype=np.uint8) + nump_array = np.array(np.asarray(img, dtype=np.uint8)) + tens = torch.from_numpy(nump_array) + if(nump_array.ndim < 3): + nump_array = np.expand_dims(nump_array, axis=-1) + nump_array = np.rollaxis(nump_array, 2) + + tensor[i] += torch.from_numpy(nump_array) + + return tensor, targets + + +def expand(num_classes, dtype, tensor): + e = torch.zeros(tensor.size(0), num_classes, dtype=dtype, device=torch.device('cuda')) + e = e.scatter(1, tensor.unsqueeze(1), 1.0) + return e + +class PrefetchedWrapper(object): + def prefetched_loader(loader, num_classes, fp16, one_hot): + # if num_classes == 10 or num_classes == 100: # Cifar10 + # mean = torch.tensor([0.491 * 255, 0.482 * 255, 0.447 * 255]).cuda().view(1, 3, 1, 1) + # std = torch.tensor([0.247 * 255, 0.243 * 255, 0.262 * 255]).cuda().view(1, 3, 1, 1) + # else: + mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) + std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) + if fp16: + mean = mean.half() + std = std.half() + + stream = torch.cuda.Stream() + first = True + + for next_indices, next_data in loader: + next_input, next_target = next_data + + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + if fp16: + next_input = next_input.half() + if one_hot: + next_target = expand(num_classes, torch.half, next_target) + else: + next_input = next_input.float() + if one_hot: + next_target = expand(num_classes, torch.float, next_target) + + next_input = next_input.sub_(mean).div_(std) + + if not first: + yield input, target, indices + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target + N = input.shape[0] + indices = next_indices.copy() + + yield input, target, indices + + def __init__(self, dataloader, num_classes, fp16, one_hot): + self.dataloader = dataloader + self.fp16 = fp16 + self.epoch = 0 + self.one_hot = one_hot + self.num_classes = num_classes + + def __iter__(self): + if (self.dataloader.sampler is not None and + isinstance(self.dataloader.sampler, + torch.utils.data.distributed.DistributedSampler)): + + self.dataloader.sampler.set_epoch(self.epoch) + self.epoch += 1 + return PrefetchedWrapper.prefetched_loader(self.dataloader, self.num_classes, self.fp16, self.one_hot) + +def get_pytorch_train_loader(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + traindir = os.path.join(data_path, 'train') + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + ])) + + if torch.distributed.is_initialized(): + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = dataloader.DataLoader( + train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), + num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate, drop_last=False) + + return PrefetchedWrapper(train_loader, num_classes, fp16, one_hot), len(train_loader) + +def get_pytorch_val_loader(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + valdir = os.path.join(data_path, 'val') + val_dataset = datasets.ImageFolder( + valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + ])) + + if torch.distributed.is_initialized(): + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + else: + val_sampler = None + + val_loader = dataloader.DataLoader( + val_dataset, + sampler=val_sampler, + batch_size=batch_size, shuffle=False, + num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, + collate_fn=fast_collate) + + return PrefetchedWrapper(val_loader, num_classes, fp16, one_hot), len(val_loader) + + +def get_pytorch_train_loader_cifar10(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ]) + if num_classes == 10: + print('Loading CIFAR10') + train_dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train) + else: + print('Loading CIFAR100') + train_dataset = datasets.CIFAR100(root=data_path, train=True, download=True, transform=transform_train) + + if torch.distributed.is_initialized(): + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + # train_loader = torch.utils.data.DataLoader( + train_loader = dataloader.DataLoader( + train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), + num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, sampler=train_sampler, + collate_fn=fast_collate, drop_last=False) + + # return train_loader, len(train_loader) + return PrefetchedWrapper(train_loader, num_classes, fp16, one_hot), len(train_loader) + + +def get_pytorch_val_loader_cifar10(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + if num_classes == 10: + val_dataset = datasets.CIFAR10(root=data_path, train=False, download=True) + else: + val_dataset = datasets.CIFAR100(root=data_path, train=False, download=True) + + if torch.distributed.is_initialized(): + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + else: + val_sampler = None + + # val_loader = torch.utils.data.DataLoader( + val_loader = dataloader.DataLoader( + val_dataset, + sampler=val_sampler, + batch_size=batch_size, shuffle=False, + num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, + collate_fn=fast_collate) + + return PrefetchedWrapper(val_loader, num_classes, fp16, one_hot), len(val_loader) + + +def get_pytorch_debug_loader_cifar10(data_path, batch_size, num_classes, one_hot, workers=5, _worker_init_fn=None, fp16=False): + if num_classes == 10: + val_dataset = datasets.CIFAR10(root=data_path, train=False, download=True) + else: + val_dataset = datasets.CIFAR100(root=data_path, train=False, download=True) + n = val_dataset.data.shape[0] + n = n//batch_size * batch_size + val_dataset.data = val_dataset.data[:n] + val_dataset.targets = val_dataset.targets[:n] + + if torch.distributed.is_initialized(): + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + else: + val_sampler = None + + # val_loader = torch.utils.data.DataLoader( + val_loader = dataloader.DataLoader( + val_dataset, + sampler=val_sampler, + batch_size=batch_size, shuffle=False, + num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, + collate_fn=fast_collate) + + return PrefetchedWrapper(val_loader, num_classes, fp16, one_hot), len(val_loader) diff --git a/image_classification/image_classification/debug.py b/image_classification/image_classification/debug.py new file mode 100644 index 0000000..525b28d --- /dev/null +++ b/image_classification/image_classification/debug.py @@ -0,0 +1,213 @@ +from actnn import config, QScheme, QBNScheme +from .utils import * +import matplotlib +matplotlib.use('Agg') +from matplotlib import pyplot as plt +from tqdm import tqdm +import numpy as np +import pickle +from matplotlib.colors import LogNorm +from copy import deepcopy + + +def get_var(model_and_loss, optimizer, val_loader, num_batches=20, model_state=None): + num_samples = 3 + # print(QF.num_samples, QF.update_scale, QF.training) + model_and_loss.train() + if hasattr(model_and_loss.model, 'module'): + m = model_and_loss.model.module + else: + m = model_and_loss.model.model + + m.set_name() + weight_names = [layer.layer_name for layer in m.linear_layers] + + data_iter = enumerate(val_loader) + inputs = [] + targets = [] + indices = [] + config.compress_activation = False + QScheme.update_scale = True + + def bp(input, target): + optimizer.zero_grad() + loss, output = model_and_loss(input, target) + loss.backward() + torch.cuda.synchronize() + grad = {layer.layer_name: layer.weight.grad.detach().cpu() for layer in m.linear_layers} + return grad #, output + + def save_scale(prefix, index): + scales = [layer.scheme.scales[index] for layer in m.linear_layers] + torch.save(scales, prefix + '_scale.pt') + + # First pass + cnt = 0 + batch_grad = None + for i, (input, target, index) in tqdm(data_iter): + QScheme.batch = index + cnt += 1 + + # if i == 0: + # print(index) + # print('Old scale: ') + # scale = m.linear_layers[0].scheme.scales[index] + # print(scale) + # save_scale('old', index) + + inputs.append(input.clone().cpu()) + targets.append(target.clone().cpu()) + indices.append(index.copy()) + mean_grad = bp(input, target) + batch_grad = dict_add(batch_grad, mean_grad) + + # if i == 0: + # print('New scale: ') + # scale = m.linear_layers[0].scheme.scales[index] + # print(scale) + # save_scale('new', index) + # + # schemes = [layer.scheme for layer in m.linear_layers] + # data = [(s.input, s.output, s.grad_input, s.grad_output) + # for s in schemes] + # weights = [layer.weight for layer in m.linear_layers] + # torch.save([data, weights, output, targets], 'data.pt') + # exit(0) + + # exit(0) + if cnt == num_batches: + break + + num_batches = cnt + batch_grad = dict_mul(batch_grad, 1.0 / num_batches) + QScheme.update_scale = False + + print('=======') + print(m.linear_layers[0].scheme.scales) + print('=======') + + if model_state is not None: + model_and_loss.load_model_state(model_state) + + if config.perlayer: + config.compress_activation = True + QScheme.batch = indices[0] + grad = bp(inputs[0].cuda(), targets[0].cuda()) + QScheme.allocate_perlayer() + QBNScheme.allocate_perlayer() + + total_var = None + total_error = None + total_bias = None + sample_var = None + for i, input, target, index in tqdm(zip(range(num_batches), inputs, targets, indices)): + input = input.cuda() + target = target.cuda() + QScheme.batch = index + config.compress_activation = False + exact_grad = bp(input, target) + sample_var = dict_add(sample_var, dict_sqr(dict_minus(exact_grad, batch_grad))) + + mean_grad = None + second_momentum = None + config.compress_activation = True + for iter in range(num_samples): + grad = bp(input, target) + + mean_grad = dict_add(mean_grad, grad) + total_error = dict_add(total_error, dict_sqr(dict_minus(exact_grad, grad))) + second_momentum = dict_add(second_momentum, dict_sqr(grad)) + + mean_grad = dict_mul(mean_grad, 1.0 / num_samples) + second_momentum = dict_mul(second_momentum, 1.0 / num_samples) + + grad_bias = dict_sqr(dict_minus(mean_grad, exact_grad)) + total_bias = dict_add(total_bias, grad_bias) + + grad_var = dict_minus(second_momentum, dict_sqr(mean_grad)) + total_var = dict_add(total_var, grad_var) + + total_error = dict_mul(total_error, 1.0 / (num_samples * num_batches)) + total_bias = dict_mul(total_bias, 1.0 / num_batches) + total_var = dict_mul(total_var, 1.0 / num_batches) + + all_qg = 0 + all_b = 0 + all_s = 0 + for k in total_var: + g = (batch_grad[k]**2).sum() + sv = sample_var[k].sum() + v = total_var[k].sum() + b = total_bias[k].sum() + e = total_error[k].sum() + avg_v = v / total_var[k].numel() + + all_qg += v + all_b += b + all_s += sv + print('{}, grad norm = {}, sample var = {}, bias = {}, var = {}, avg_var = {}, error = {}'.format(k, g, sv, b, v, avg_v, e)) + + print('Overall Bias = {}, Var = {}, SampleVar = {}'.format(all_b, all_qg, all_s)) + + +def get_var_during_training(model_and_loss, optimizer, val_loader, num_batches=20): + num_samples = 3 + # print(QF.num_samples, QF.update_scale, QF.training) + model_and_loss.train() + if hasattr(model_and_loss.model, 'module'): + m = model_and_loss.model.module + else: + m = model_and_loss.model + + m.set_name() + weight_names = [layer.layer_name for layer in m.linear_layers] + + print('=======') + print(m.linear_layers[0].scheme.scales) + print('=======') + + def bp(input, target): + optimizer.zero_grad() + loss, output = model_and_loss(input, target) + loss.backward() + torch.cuda.synchronize() + grad = {layer.layer_name: layer.weight.grad.detach().cpu() for layer in m.linear_layers} + return grad + + QScheme.update_scale = False + data_iter = enumerate(val_loader) + + total_var = None + cnt = 0 + for i, (input, target, index) in tqdm(data_iter): + QScheme.batch = index + cnt += 1 + if cnt == num_batches: + break + + mean_grad = None + second_momentum = None + for iter in range(num_samples): + grad = bp(input, target) + + mean_grad = dict_add(mean_grad, grad) + second_momentum = dict_add(second_momentum, dict_sqr(grad)) + + mean_grad = dict_mul(mean_grad, 1.0 / num_samples) + second_momentum = dict_mul(second_momentum, 1.0 / num_samples) + grad_var = dict_minus(second_momentum, dict_sqr(mean_grad)) + total_var = dict_add(total_var, grad_var) + + num_batches = cnt + total_var = dict_mul(total_var, 1.0 / num_batches) + + all_qg = 0 + for k in total_var: + v = total_var[k].sum() + avg_v = v / total_var[k].numel() + + all_qg += v + print('{}, var = {}, avg_var = {}'.format(k, v, avg_v)) + + print('Overall Var = {}'.format(all_qg)) + diff --git a/image_classification/image_classification/logger.py b/image_classification/image_classification/logger.py new file mode 100644 index 0000000..0c262b9 --- /dev/null +++ b/image_classification/image_classification/logger.py @@ -0,0 +1,299 @@ +import random +import json +from collections import OrderedDict + + +class IterationMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.last = 0 + + def record(self, val, n = 1): + self.last = val + + def get_val(self): + return None + + def get_last(self): + return self.last + + +class EpochMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + + def record(self, val, n = 1): + self.val = val + + def get_val(self): + return self.val + + def get_last(self): + return None + + +class AverageMeter(object): + def __init__(self, ret_last=True, ret_val=True): + self.reset() + self.ret_last = ret_last + self.ret_val = ret_val + + def reset(self): + self.n = 0 + self.val = 0 + self.last = 0 + + def record(self, val, n = 1): + self.last = val + self.n += n + self.val += val * n + + def get_val(self): + if self.ret_val: + if self.n == 0: + return 0.0 + return self.val / self.n + else: + return None + + def get_last(self): + if self.ret_last: + return self.last + else: + return None + + +class RunningMeter(object): + def __init__(self, decay): + self.decay = decay + + def reset(self): + self.val = 0 + self.last = 0 + + def record(self, val, n = 1): + self.last = val + decay = 1 - ((1 - self.decay) ** n) + self.val = (1 - decay) * self.val + decay * val + + def get_val(self): + return self.val + + def get_last(self): + return self.last + + +class Logger(object): + def __init__(self, print_interval, backends, verbose=False): + self.epoch = -1 + self.iteration = -1 + self.val_iteration = -1 + self.metrics = OrderedDict() + self.backends = backends + self.print_interval = print_interval + self.verbose = verbose + + def log_run_tag(self, name, val): + for b in self.backends: + b.log_run_tag(name, val) + + def register_metric(self, metric_name, meter, log_level=0): + if self.verbose: + print("Registering metric: {}".format(metric_name)) + self.metrics[metric_name] = {'meter' : meter, 'level' : log_level} + + def log_metric(self, metric_name, val, n=1): + self.metrics[metric_name]['meter'].record(val, n=n) + + def start_iteration(self, val=False): + if val: + self.val_iteration += 1 + else: + self.iteration += 1 + + def end_iteration(self, val=False): + it = self.val_iteration if val else self.iteration + if (it % self.print_interval == 0): + for b in self.backends: + if val: + b.log_iteration_metric('val.it', it) + else: + b.log_iteration_metric('it', it) + + f = lambda l: filter(lambda m : m['level'] <= b.level) + for n, m in [(n, m) for n, m in self.metrics.items() if m['level'] <= b.level and n.startswith('val') == val]: + mv = m['meter'].get_last() + if mv is not None: + b.log_iteration_metric(n, mv) + + b.log_end_iteration() + + def start_epoch(self): + self.epoch += 1 + self.iteration = 0 + self.val_iteration = 0 + + for b in self.backends: + b.log_epoch_metric('ep', self.epoch) + + for n, m in [(n, m) for n, m in self.metrics.items() if m['level'] <= b.level]: + m['meter'].reset() + + def end_epoch(self): + for b in self.backends: + for n, m in [(n, m) for n, m in self.metrics.items() if m['level'] <= b.level]: + mv = m['meter'].get_val() + if mv is not None: + b.log_epoch_metric(n, mv) + b.log_end_epoch() + + def end(self): + for b in self.backends: + b.end() + + def iteration_generator_wrapper(self, gen, val = False): + for g in gen: + self.start_iteration(val = val) + yield g + self.end_iteration(val = val) + + def epoch_generator_wrapper(self, gen): + for g in gen: + self.start_epoch() + yield g + self.end_epoch() + + +class JsonBackend(object): + def __init__(self, filename, log_level=0): + print("Logger: ", filename) + self.level = log_level + self.filename = filename + self.json_log = OrderedDict([ + ('run' , OrderedDict()), + ('epoch', OrderedDict()), + ('iter' , OrderedDict()), + ('event', OrderedDict()), + ]) + + def log_run_tag(self, name, val): + self.json_log['run'][name] = val + + def log_end_epoch(self): + pass + + def log_end_iteration(self): + pass + + def log_epoch_metric(self, name, val): + if not name in self.json_log['epoch'].keys(): + self.json_log['epoch'][name] = [] + + self.json_log['epoch'][name].append(val) + + if name != 'ep': + if name in self.json_log['iter'].keys(): + self.json_log['iter'][name].append([]) + else: + if not 'it' in self.json_log['iter'].keys(): + self.json_log['iter']['it'] = [] + + self.json_log['iter']['it'].append([]) + + def log_iteration_metric(self, name, val): + if not (name in self.json_log['iter'].keys()): + self.json_log['iter'][name] = [[]] + + self.json_log['iter'][name][-1].append(val) + + def end(self): + print(json.dump(self.json_log, open(self.filename, 'w'))) + + +class StdOut1LBackend(object): + def __init__(self, iters, val_iters, epochs, log_level=0): + self.level = log_level + self.iteration = 0 + self.total_iterations = iters + self.total_val_iterations = val_iters + self.epoch = 0 + self.total_epochs = epochs + self.iteration_metrics = {} + self.epoch_metrics = {} + self.mode = 'train' + + def log_run_tag(self, name, val): + print("{} : {}".format(name, val)) + + def log_end_epoch(self): + print("Summary Epoch: {}/{};\t{}".format( + self.epoch, self.total_epochs, + "\t".join(["{} : {:.3f}".format(m,v) for m, v in self.epoch_metrics.items()]))) + + self.epoch_metrics = {} + + def log_end_iteration(self): + md = "Validation" if self.mode == 'val' else "" + ti = self.total_val_iterations if self.mode == 'val' else self.total_iterations + print("Epoch: {}/{} {} Iteration: {}/{};\t{}".format( + self.epoch, self.total_epochs, md, self.iteration, ti, + "\t".join(["{} : {:.3f}".format(m,v) for m, v in self.iteration_metrics.items()]))) + + self.iteration_metrics = {} + + def log_epoch_metric(self, name, value): + if name == 'ep': + self.epoch = value + self.iteration = 0 + else: + self.epoch_metrics[name] = value + + def log_iteration_metric(self, name, value): + if name == 'it' or name == 'val.it': + self.mode = 'train' if name == 'it' else 'val' + self.iteration = value + else: + self.iteration_metrics[name] = value + + def end(self): + pass + + + +class StdOutBackend(object): + def __init__(self, iters, epochs, log_level=0): + self.level = log_level + self.iteration = 0 + self.epoch = 0 + + def log_run_tag(self, name, val): + print("{} : {}".format(name, val)) + + def log_end_epoch(self): + pass + + def log_end_iteration(self): + pass + + def log_epoch_metric(self, name, value): + if name == 'ep': + self.epoch = value + self.iteration = 0 + else: + print("Summary Epoch: {}; {} = {:.3f}".format(self.epoch, name, value)) + + def log_iteration_metric(self, name, value): + if name == 'it' or name == 'val.it': + self.iteration = value + else: + print("Epoch: {} Iteration: {}; {} = {:.3f}".format(self.epoch, self.iteration, name, value)) + + def end(self): + pass + + diff --git a/image_classification/image_classification/mixup.py b/image_classification/image_classification/mixup.py new file mode 100644 index 0000000..ebdf8f1 --- /dev/null +++ b/image_classification/image_classification/mixup.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +import numpy as np + + +def mixup(alpha, num_classes, data, target): + with torch.no_grad(): + bs = data.size(0) + c = np.random.beta(alpha, alpha) + + perm = torch.randperm(bs).cuda() + + md = c * data + (1-c) * data[perm, :] + mt = c * target + (1-c) * target[perm, :] + return md, mt + + +class MixUpWrapper(object): + def __init__(self, alpha, num_classes, dataloader): + self.alpha = alpha + self.dataloader = dataloader + self.num_classes = num_classes + + def mixup_loader(self, loader): + for input, target in loader: + i, t = mixup(self.alpha, self.num_classes, input, target) + yield i, t + + def __iter__(self): + return self.mixup_loader(self.dataloader) + + +class NLLMultiLabelSmooth(nn.Module): + def __init__(self, smoothing = 0.0): + super(NLLMultiLabelSmooth, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + if self.training: + x = x.float() + target = target.float() + logprobs = torch.nn.functional.log_softmax(x, dim = -1) + + nll_loss = -logprobs * target + nll_loss = nll_loss.sum(-1) + + smooth_loss = -logprobs.mean(dim=-1) + + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + + return loss.mean() + else: + return torch.nn.functional.cross_entropy(x, target) diff --git a/image_classification/image_classification/plot_variance.py b/image_classification/image_classification/plot_variance.py new file mode 100644 index 0000000..d829c1f --- /dev/null +++ b/image_classification/image_classification/plot_variance.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib +matplotlib.use('Agg') +from matplotlib import pyplot as plt +from matplotlib.colors import LogNorm +import pickle +import seaborn as sns + +weight_names = pickle.load(open('layer_names.pkl', 'rb')) + +grads = np.load('error_profile_5.npy') +grads = np.maximum(grads, 0) +quantizers = grads.sum(1) +variances = grads.sum(0) + +# grads = np.minimum(grads, 1) +# grads *= 1000 +for i in range(grads.shape[0]): + for j in range(grads.shape[1]): + if j > i: + grads[i, j] = 0 + +fig, ax = plt.subplots(figsize=(20, 20)) +im = ax.imshow(grads, cmap='Blues', norm=LogNorm(vmin=0.01, vmax=10.0)) +ax.set_xticks(np.arange(len(weight_names))) +ax.set_yticks(np.arange(len(weight_names)+1)) +ax.set_xticklabels(weight_names) +weight_names.append('sample') +ax.set_yticklabels(weight_names) + +ax.tick_params(top=True, bottom=False, + labeltop=True, labelbottom=False) +plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", + rotation_mode="anchor") +cbar = ax.figure.colorbar(im, ax=ax) + +for i in range(grads.shape[0]): + for j in range(grads.shape[1]): + text = ax.text(j, i, int(grads[i, j]*10), + ha="center", va="center") + + +fig.savefig('variance_profile.pdf') + +fig, ax = plt.subplots(figsize=(20, 20)) +sns.barplot(x=np.arange(quantizers.shape[0]), y=quantizers, ax=ax) +ax.set_xticks(np.arange(len(weight_names)) + 0.5) +ax.set_xticklabels(weight_names) +plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", + rotation_mode="anchor") +ax.set_xlabel('quantizer') +ax.set_ylabel('variance') +fig.savefig('quantizers.pdf') + +fig, ax = plt.subplots(figsize=(20, 20)) +sns.barplot(x=np.arange(variances.shape[0]), y=variances, ax=ax) +weight_names.pop(-1) +ax.set_xticks(np.arange(len(weight_names)) + 0.5) +ax.set_xticklabels(weight_names) +plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", + rotation_mode="anchor") +ax.set_xlabel('parameter') +ax.set_ylabel('variance') +fig.savefig('parameter.pdf') diff --git a/image_classification/image_classification/preact_resnet.py b/image_classification/image_classification/preact_resnet.py new file mode 100644 index 0000000..5cbc923 --- /dev/null +++ b/image_classification/image_classification/preact_resnet.py @@ -0,0 +1,195 @@ +'''Pre-activation ResNet in PyTorch. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv:1603.05027 +''' +import math +import torch +import torch.nn as nn +from actnn import QModule + + +class PreActBlock(nn.Module): + expansion = 1 + M = 2 + + def __init__(self, builder, inplanes, planes, stride=1, downsample=None): + super(PreActBlock, self).__init__() + self.bn1 = builder.batchnorm(inplanes) + self.relu = builder.activation() + self.conv1 = builder.conv3x3(inplanes, planes, stride) + self.bn2 = builder.batchnorm(planes, last_bn=True) + self.conv2 = builder.conv3x3(planes, planes) + self.downsample = downsample + self.stride = stride + self.debug = False + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + + if self.downsample is not None: + residual = self.downsample(out) + + if self.debug: + out.retain_grad() + self.conv1_in = out + + out = self.conv1(out) + + if self.debug: + out.retain_grad() + self.conv1_out = out + + out = self.bn2(out) + out = self.relu(out) + + if self.debug: + out.retain_grad() + self.conv2_in = out + + out = self.conv2(out) + + if self.debug: + out.retain_grad() + self.conv2_out = out + + out += residual + + return out + + +class PreActBottleneck(nn.Module): + expansion = 4 + M = 3 + + def __init__(self, builder, inplanes, planes, stride=1, downsample=None): + super(PreActBottleneck, self).__init__() + self.bn1 = builder.batchnorm(inplanes) + self.relu = builder.activation() + self.conv1 = builder.conv1x1(inplanes, planes) + self.bn2 = builder.batchnorm(planes) + self.conv2 = builder.conv3x3(planes, planes, stride=stride) + self.bn3 = builder.batchnorm(planes, last_bn=True) + self.conv3 = builder.conv1x1(planes, planes*4) + self.downsample = downsample + self.stride = stride + self.debug = False + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + + if self.downsample is not None: + residual = self.downsample(out) + + if self.debug: + out.retain_grad() + self.conv1_in = out + + out = self.conv1(out) + + if self.debug: + out.retain_grad() + self.conv1_out = out + + out = self.bn2(out) + out = self.relu(out) + + if self.debug: + out.retain_grad() + self.conv2_in = out + + out = self.conv2(out) + + if self.debug: + out.retain_grad() + self.conv2_out = out + + out = self.bn3(out) + out = self.relu(out) + + if self.debug: + out.retain_grad() + self.conv3_in = out + + out = self.conv3(out) + + if self.debug: + out.retain_grad() + self.conv3_out = out + + out += residual + + return out + + +class PreActResNet(nn.Module): + def __init__(self, builder, block, num_blocks, num_classes=10): + super(PreActResNet, self).__init__() + self.inplanes = 16 + self.builder = builder + self.conv1 = builder.conv3x3(3, 16) + self.layer1 = self._make_layer(block, 16, num_blocks[0]) + self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) + self.bn = builder.batchnorm(64 * block.expansion) + self.relu = builder.activation() + #self.avgpool = nn.AvgPool2d(8, stride=1) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = builder.linear(64 * block.expansion, num_classes) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + self.builder.conv1x1(self.inplanes, planes * block.expansion, stride=stride) + ) + + layers = [] + layers.append(block(self.builder, self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.builder, self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.bn(x) + x = self.relu(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + def set_debug(self, debug): + for l in [self.layer1, self.layer2, self.layer3]: + for b in l: + b.debug = debug + + + def set_name(self): + self.linear_layers = [self.conv1] + self.conv1.layer_name = 'conv_0' + for lid, layer in enumerate([self.layer1, self.layer2, self.layer3]): + for bid, block in enumerate(layer): + for cid, convlayer in enumerate([block.conv1, block.conv2]): + convlayer.layer_name = 'conv_{}_{}_{}'.format(lid+1, bid+1, cid+1) + self.linear_layers.append(convlayer) + if block.downsample is not None: + block.downsample[0].layer_name = 'conv_{}_{}_skip'.format(lid+1, bid+1) + self.linear_layers.append(block.downsample[0]) + + self.fc.layer_name = 'fc' + self.linear_layers.append(self.fc) diff --git a/image_classification/image_classification/resnet.py b/image_classification/image_classification/resnet.py new file mode 100644 index 0000000..232a63d --- /dev/null +++ b/image_classification/image_classification/resnet.py @@ -0,0 +1,485 @@ +import torch +import torch.nn as nn +from .preact_resnet import PreActBlock, PreActBottleneck, PreActResNet + +from actnn import QConv2d, QLinear, QBatchNorm2d, QReLU, QSyncBatchNorm, QMaxPool2d, config + +__all__ = ['ResNet', 'build_resnet', 'resnet_versions', 'resnet_configs'] + +# ResNetBuilder {{{ + +class ResNetBuilder(object): + def __init__(self, version, config): + self.config = config + + self.L = sum(version['layers']) + self.M = version['block'].M + + def conv(self, kernel_size, in_planes, out_planes, stride=1): + if kernel_size == 3: + conv = self.config['conv'](in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + elif kernel_size == 1: + conv = self.config['conv'](in_planes, out_planes, kernel_size=1, stride=stride, + bias=False) + elif kernel_size == 5: + conv = self.config['conv'](in_planes, out_planes, kernel_size=5, stride=stride, + padding=2, bias=False) + elif kernel_size == 7: + conv = self.config['conv'](in_planes, out_planes, kernel_size=7, stride=stride, + padding=3, bias=False) + else: + return None + + if self.config['nonlinearity'] == 'relu': + nn.init.kaiming_normal_(conv.weight, + mode=self.config['conv_init'], + nonlinearity=self.config['nonlinearity']) + + return conv + + def conv3x3(self, in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + c = self.conv(3, in_planes, out_planes, stride=stride) + return c + + def conv1x1(self, in_planes, out_planes, stride=1): + """1x1 convolution with padding""" + c = self.conv(1, in_planes, out_planes, stride=stride) + return c + + def conv7x7(self, in_planes, out_planes, stride=1): + """7x7 convolution with padding""" + c = self.conv(7, in_planes, out_planes, stride=stride) + return c + + def conv5x5(self, in_planes, out_planes, stride=1): + """5x5 convolution with padding""" + c = self.conv(5, in_planes, out_planes, stride=stride) + return c + + def batchnorm(self, planes, last_bn=False): + if config.debug_remove_bn: + return nn.Identity() + bn = self.config['bn'](planes) + + gamma_init_val = 0 if last_bn and self.config['last_bn_0_init'] else 1 + nn.init.constant_(bn.weight, gamma_init_val) + nn.init.constant_(bn.bias, 0) + + return bn + + def max_pool2d(self, *args, **kwargs): + return self.config['max_pool2d'](*args, **kwargs) + + def linear(self, in_planes, out_planes): + return self.config['linear'](in_planes, out_planes) + + def activation(self): + if config.debug_remove_relu: + return nn.Identity() + return self.config['activation']() + +# ResNetBuilder }}} + +# BasicBlock {{{ +class BasicBlock(nn.Module): + M = 2 + expansion = 1 + + def __init__(self, builder, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = builder.conv3x3(inplanes, planes, stride) + self.bn1 = builder.batchnorm(planes) + self.relu = builder.activation() + self.conv2 = builder.conv3x3(planes, planes) + self.bn2 = builder.batchnorm(planes, last_bn=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + if self.bn1 is not None: + out = self.bn1(out) + + out = self.relu(out) + + out = self.conv2(out) + + if self.bn2 is not None: + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out +# BasicBlock }}} + +# Bottleneck {{{ +class Bottleneck(nn.Module): + M = 3 + expansion = 4 + + def __init__(self, builder, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = builder.conv1x1(inplanes, planes) + self.bn1 = builder.batchnorm(planes) + self.conv2 = builder.conv3x3(planes, planes, stride=stride) + self.bn2 = builder.batchnorm(planes) + self.conv3 = builder.conv1x1(planes, planes * self.expansion) + self.bn3 = builder.batchnorm(planes * self.expansion, last_bn=True) + self.relu = builder.activation() + self.downsample = downsample + self.stride = stride + self.debug = False + + def forward(self, x): + residual = x + + if self.debug: + x.retain_grad() + self.conv1_in = x + + out = self.conv1(x) + + if self.debug: + x.retain_grad() + self.conv1_out = out + + out = self.bn1(out) + + if self.debug: + x.retain_grad() + self.conv1_bn_out = out + + out = self.relu(out) + + if self.debug: + x.retain_grad() + self.conv1_relu_out = out + + if self.debug: + out.retain_grad() + self.conv2_in = out + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.debug: + out.retain_grad() + self.conv3_in = out + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + out = self.relu(out) + + return out +# Bottleneck }}} + +# ResNet {{{ +class ResNet(nn.Module): + def __init__(self, builder, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = builder.conv7x7(3, 64, stride=2) + self.bn1 = builder.batchnorm(64) + self.relu = builder.activation() + self.maxpool = builder.max_pool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(builder, block, 64, layers[0]) + self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = builder.linear(512 * block.expansion, num_classes) + + def _make_layer(self, builder, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + dconv = builder.conv1x1(self.inplanes, planes * block.expansion, + stride=stride) + dbn = builder.batchnorm(planes * block.expansion) + if dbn is not None: + downsample = nn.Sequential(dconv, dbn) + else: + downsample = dconv + + layers = [] + layers.append(block(builder, self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(builder, self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + if self.bn1 is not None: + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + def set_precision(self): # Hack + self.bn1.scheme.bits = self.conv1.scheme.bits + for block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for layer in block: + layer.bn1.scheme.bits = layer.conv1.scheme.bits + layer.bn2.scheme.bits = layer.conv2.scheme.bits + layer.bn3.scheme.bits = layer.conv3.scheme.bits + layer.bn1.scheme.conv_input_norm = layer.conv1.conv_input_norm + layer.bn2.scheme.conv_input_norm = layer.conv2.conv_input_norm + layer.bn3.scheme.conv_input_norm = layer.conv3.conv_input_norm + if layer.downsample is not None: + layer.downsample[1].scheme.bits = layer.downsample[0].scheme.bits + layer.downsample[1].scheme.conv_input_norm = layer.downsample[0].conv_input_norm + + def set_debug(self, debug): + self.debug = True + for l in [self.layer1, self.layer2, self.layer3, self.layer4]: + for b in l: + b.debug = debug + + def set_name(self): + self.linear_layers = [self.conv1] + self.conv1.layer_name = 'conv_0' + for lid, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): + for bid, block in enumerate(layer): + for cid, convlayer in enumerate([block.conv1, block.conv2, block.conv3]): + convlayer.layer_name = 'conv_{}_{}_{}'.format(lid+1, bid+1, cid+1) + self.linear_layers.append(convlayer) + if block.downsample is not None: + block.downsample[0].layer_name = 'conv_{}_{}_skip'.format(lid+1, bid+1) + self.linear_layers.append(block.downsample[0]) + + self.fc.layer_name = 'fc' + self.linear_layers.append(self.fc) + + +# ResNet }}} + + +# ResNet {{{ +class ResNetCifar(nn.Module): + def __init__(self, builder, block, layers, num_classes=10): + self.inplanes = 16 + super(ResNetCifar, self).__init__() + self.conv1 = builder.conv3x3(3, 16) + self.bn1 = builder.batchnorm(16) + self.relu = builder.activation() + self.layer1 = self._make_layer(builder, block, 16, layers[0]) + self.layer2 = self._make_layer(builder, block, 32, layers[1], stride=2) + self.layer3 = self._make_layer(builder, block, 64, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = builder.linear(64 * block.expansion, num_classes) + + def _make_layer(self, builder, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + dconv = builder.conv1x1(self.inplanes, planes * block.expansion, + stride=stride) + dbn = builder.batchnorm(planes * block.expansion) + if dbn is not None: + downsample = nn.Sequential(dconv, dbn) + else: + downsample = dconv + + layers = [] + layers.append(block(builder, self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(builder, self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + if self.bn1 is not None: + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + def set_debug(self, debug): + self.debug = True + for l in [self.layer1, self.layer2, self.layer3]: + for b in l: + b.debug = debug + + def set_name(self): + self.linear_layers = [self.conv1] + self.conv1.layer_name = 'conv_0' + for lid, layer in enumerate([self.layer1, self.layer2, self.layer3]): + for bid, block in enumerate(layer): + for cid, convlayer in enumerate([block.conv1, block.conv2]): + convlayer.layer_name = 'conv_{}_{}_{}'.format(lid+1, bid+1, cid+1) + self.linear_layers.append(convlayer) + if block.downsample is not None: + block.downsample[0].layer_name = 'conv_{}_{}_skip'.format(lid+1, bid+1) + self.linear_layers.append(block.downsample[0]) + + self.fc.layer_name = 'fc' + self.linear_layers.append(self.fc) + + +# ResNet }}} + +resnet_configs = { + 'classic' : { + 'conv' : nn.Conv2d, + 'linear' : nn.Linear, + 'bn' : nn.BatchNorm2d, + 'max_pool2d' : nn.MaxPool2d, + 'conv_init' : 'fan_out', + 'nonlinearity' : 'relu', + 'last_bn_0_init' : False, + 'activation' : lambda: nn.ReLU(inplace=True), + 'quantize_forward': False + }, + 'fanin' : { + 'conv' : nn.Conv2d, + 'linear' : nn.Linear, + 'bn' : nn.BatchNorm2d, + 'max_pool2d' : nn.MaxPool2d, + 'conv_init' : 'fan_in', + 'nonlinearity' : 'relu', + 'last_bn_0_init' : False, + 'activation' : lambda: nn.ReLU(inplace=True), + 'quantize_forward': False + }, + 'quantize' : { + 'conv' : QConv2d, + 'linear' : QLinear, + 'bn' : QBatchNorm2d, + 'max_pool2d' : QMaxPool2d, + 'conv_init' : 'fan_in', + 'nonlinearity' : 'relu', + 'last_bn_0_init' : False, + 'activation' : QReLU, + 'quantize_forward': True + }, + 'qlinear' : { + 'conv' : QConv2d, + 'linear' : QLinear, + 'bn' : nn.BatchNorm2d, + 'max_pool2d' : QMaxPool2d, + 'conv_init' : 'fan_in', + 'nonlinearity' : 'relu', + 'last_bn_0_init' : False, + 'activation' : lambda: nn.ReLU(inplace=True), + 'quantize_forward': True + }, + 'qsyncbn': { + 'conv': QConv2d, + 'linear': QLinear, + 'bn': QSyncBatchNorm, + 'max_pool2d' : QMaxPool2d, + 'conv_init': 'fan_in', + 'nonlinearity': 'relu', + 'last_bn_0_init': False, + 'activation': lambda: nn.ReLU(inplace=True), + 'quantize_forward': True + }, +} + +resnet_versions = { + 'resnet18' : { + 'net' : ResNet, + 'block' : BasicBlock, + 'layers' : [2, 2, 2, 2], + }, + 'resnet34' : { + 'net' : ResNet, + 'block' : BasicBlock, + 'layers' : [3, 4, 6, 3], + }, + 'resnet50' : { + 'net' : ResNet, + 'block' : Bottleneck, + 'layers' : [3, 4, 6, 3], + }, + 'resnet101' : { + 'net' : ResNet, + 'block' : Bottleneck, + 'layers' : [3, 4, 23, 3], + }, + 'resnet152' : { + 'net' : ResNet, + 'block' : Bottleneck, + 'layers' : [3, 8, 36, 3], + }, + 'resnet56' : { + 'net' : ResNetCifar, + 'block' : BasicBlock, + 'layers' : [9, 9, 9], + }, + 'preact_resnet20' : { + 'net' : PreActResNet, + 'block' : PreActBlock, + 'layers' : [3, 3, 3], + }, + 'preact_resnet56' : { + 'net' : PreActResNet, + 'block' : PreActBlock, + 'layers' : [9, 9, 9], + }, + 'preact_resnet110' : { + 'net' : PreActResNet, + 'block' : PreActBlock, + 'layers' : [18, 18, 18], + }, + 'preact_resnet164' : { + 'net' : PreActResNet, + 'block' : PreActBottleneck, + 'layers' : [18, 18, 18], + }, + 'preact_resnet1001' : { + 'net' : PreActResNet, + 'block' : PreActBottleneck, + 'layers' : [111, 111, 111], + }, + } + + +def build_resnet(version, config, num_classes, model_state=None): + version = resnet_versions[version] + config = resnet_configs[config] + + builder = ResNetBuilder(version, config) + print("Version: {}".format(version)) + print("Config: {}".format(config)) + model = version['net'](builder, + version['block'], + version['layers'], + num_classes) + + return model diff --git a/image_classification/image_classification/smoothing.py b/image_classification/image_classification/smoothing.py new file mode 100644 index 0000000..99f7466 --- /dev/null +++ b/image_classification/image_classification/smoothing.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, smoothing=0.0): + """ + Constructor for the LabelSmoothing module. + + :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + diff --git a/image_classification/image_classification/training.py b/image_classification/image_classification/training.py new file mode 100644 index 0000000..b5fa71a --- /dev/null +++ b/image_classification/image_classification/training.py @@ -0,0 +1,446 @@ +import time +import os +import numpy as np +import torch +import torch.nn as nn +from torch.autograd import Variable +from . import logger as log +from . import resnet as models +from . import utils +from .debug import get_var, get_var_during_training +from actnn import config, QScheme, QModule, get_memory_usage, compute_tensor_bytes, exp_recorder +from copy import copy + +try: + # from apex.parallel import DistributedDataParallel as DDP + from torch.nn.parallel import DistributedDataParallel as DDP + from apex.fp16_utils import * + from apex import amp +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") + + +MB = 1024**2 +GB = 1024**3 + +class ModelAndLoss(nn.Module): + def __init__(self, arch, num_classes, loss, pretrained_weights=None, cuda=True, fp16=False): + super(ModelAndLoss, self).__init__() + self.arch = arch + + print("=> creating model '{}'".format(arch)) + model = models.build_resnet(arch[0], arch[1], num_classes) + if arch[1] not in ['classic', 'fanin']: + print("=> convert to quantized model") + model = QModule(model) + + if pretrained_weights is not None: + print("=> using pre-trained model from a file '{}'".format(arch)) + model.load_state_dict(pretrained_weights) + + if cuda: + model = model.cuda() + if fp16: + model = network_to_half(model) + + # define loss function (criterion) and optimizer + criterion = loss() + + if cuda: + criterion = criterion.cuda() + + self.model = model + self.loss = criterion + + def forward(self, data, target): + output = self.model(data) + loss = self.loss(output, target) + + return loss, output + + def distributed(self, rank): + self.model = DDP(self.model, device_ids=[rank]) + + def load_model_state(self, state): + if not state is None: + try: + self.model.load_state_dict(state) + except: + state = {k.replace('module.', ''): state[k] for k in state} + self.model.load_state_dict(state) + + + +def get_optimizer(parameters, fp16, lr, momentum, weight_decay, + nesterov=False, + state=None, + static_loss_scale=1., dynamic_loss_scale=False, + bn_weight_decay = False): + + if bn_weight_decay: + print(" ! Weight decay applied to BN parameters ") + optimizer = torch.optim.SGD([v for n, v in parameters], lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov = nesterov) + else: + print(" ! Weight decay NOT applied to BN parameters ") + bn_params = [v for n, v in parameters if 'bn' in n] + rest_params = [v for n, v in parameters if not 'bn' in n] + print(len(bn_params)) + print(len(rest_params)) + optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay' : 0}, + {'params': rest_params, 'weight_decay' : weight_decay}], + lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov = nesterov) + if fp16: + optimizer = FP16_Optimizer(optimizer, + static_loss_scale=static_loss_scale, + dynamic_loss_scale=dynamic_loss_scale, + verbose=False) + + if not state is None: + optimizer.load_state_dict(state) + + return optimizer + + +def lr_policy(lr_fn, logger=None): + if logger is not None: + logger.register_metric('lr', log.IterationMeter(), log_level=1) + def _alr(optimizer, iteration, epoch): + lr = lr_fn(iteration, epoch) + + if logger is not None: + logger.log_metric('lr', lr) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + return _alr + + +def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None): + def _lr_fn(iteration, epoch): + if epoch < warmup_length: + lr = base_lr * (epoch + 1) / warmup_length + else: + lr = base_lr + for s in steps: + if epoch >= s: + lr *= decay_factor + return lr + + return lr_policy(_lr_fn, logger=logger) + + +def lr_linear_policy(base_lr, warmup_length, epochs, logger=None): + def _lr_fn(iteration, epoch): + if epoch < warmup_length: + lr = base_lr * (epoch + 1) / warmup_length + else: + e = epoch - warmup_length + es = epochs - warmup_length + lr = base_lr * (1-(e/es)) + return lr + + return lr_policy(_lr_fn, logger=logger) + + +def lr_cosine_policy(base_lr, warmup_length, epochs, logger=None): + def _lr_fn(iteration, epoch): + if epoch < warmup_length: + lr = base_lr * (epoch + 1) / warmup_length + else: + e = epoch - warmup_length + es = epochs - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + return lr + + return lr_policy(_lr_fn, logger=logger) + + +def lr_exponential_policy(base_lr, warmup_length, epochs, final_multiplier=0.001, logger=None): + es = epochs - warmup_length + epoch_decay = np.power(2, np.log2(final_multiplier)/es) + + def _lr_fn(iteration, epoch): + if epoch < warmup_length: + lr = base_lr * (epoch + 1) / warmup_length + else: + e = epoch - warmup_length + lr = base_lr * (epoch_decay ** e) + return lr + + return lr_policy(_lr_fn, logger=logger) + + + +def get_train_step(model_and_loss, optimizer, fp16, use_amp = False, batch_size_multiplier = 1): + def _step(input, target, optimizer_step = True): + input_var = Variable(input) + target_var = Variable(target) + + if config.debug_memory_model: + print("========== Init Data Loader ===========") + init_mem = get_memory_usage(True) + exp_recorder.record("data_loader", init_mem / GB - exp_recorder.val_dict['model_only'], 2) + + loss, output = model_and_loss(input_var, target_var) + prec1, prec5 = torch.zeros(1), torch.zeros(1) #utils.accuracy(output.data, target, topk=(1, 5)) + + if torch.distributed.is_initialized(): + reduced_loss = utils.reduce_tensor(loss.data) + #prec1 = reduce_tensor(prec1) + #prec5 = reduce_tensor(prec5) + else: + reduced_loss = loss.data + + if config.debug_memory_model: + print("========== Before Backward ===========") + before_backward = get_memory_usage(True) + act_mem = get_memory_usage() - init_mem - compute_tensor_bytes([loss, output]) + res = "Batch size: %d\tTotal Mem: %.2f MB\tAct Mem: %.2f MB" % ( + len(output), before_backward / MB, act_mem / MB) + loss.backward() + optimizer.step() + del loss + print("========== After Backward ===========") + after_backward = get_memory_usage(True) + total_mem = before_backward + (after_backward - init_mem) + res = "Batch size: %d\tTotal Mem: %.2f MB\tAct Mem: %.2f MB" % ( + len(output), total_mem / MB, act_mem / MB) + print(res) + exp_recorder.record("batch_size", len(output)) + exp_recorder.record("total", total_mem / GB, 2) + exp_recorder.record("activation", act_mem / GB, 2) + exp_recorder.dump('mem_results.tsv') + exit() + + if fp16: + optimizer.backward(loss) + elif use_amp: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if optimizer_step: + opt = optimizer.optimizer if isinstance(optimizer, FP16_Optimizer) else optimizer + for param_group in opt.param_groups: + for param in param_group['params']: + param.grad /= batch_size_multiplier + + optimizer.step() + optimizer.zero_grad() + + torch.cuda.synchronize() + + return reduced_loss, output, prec1, prec5 + + return _step + +train_step_ct = 0 +train_max_ips = 0 +train_max_batch = 0 + +def train(train_loader, model_and_loss, optimizer, lr_scheduler, fp16, logger, epoch, use_amp=False, prof=-1, batch_size_multiplier=1, register_metrics=True): + if register_metrics and logger is not None: + logger.register_metric('train.top1', log.AverageMeter(), log_level = 0) + logger.register_metric('train.top5', log.AverageMeter(), log_level = 0) + logger.register_metric('train.loss', log.AverageMeter(), log_level = 0) + logger.register_metric('train.compute_ips', log.AverageMeter(), log_level=1) + logger.register_metric('train.total_ips', log.AverageMeter(), log_level=0) + logger.register_metric('train.data_time', log.AverageMeter(), log_level=1) + logger.register_metric('train.compute_time', log.AverageMeter(), log_level=1) + + if config.debug_memory_model: + print("========== Model Only ===========") + usage = get_memory_usage(True) + exp_recorder.record("network", model_and_loss.arch[0]) + exp_recorder.record("algorithm", 'quantize' + if model_and_loss.arch[1] == 'quantize' else 'exact') + exp_recorder.record("model_only", usage / GB, 2) + + step = get_train_step(model_and_loss, optimizer, fp16, use_amp = use_amp, batch_size_multiplier = batch_size_multiplier) + + model_and_loss.train() + print('Training mode ', config.training) + end = time.time() + + optimizer.zero_grad() + + data_iter = enumerate(train_loader) + if logger is not None: + data_iter = logger.iteration_generator_wrapper(data_iter) + + for i, (input, target, index) in data_iter: # NOTE: only needed for use_gradient + QScheme.batch = index # NOTE: only needed for use_gradient + + bs = input.size(0) + lr_scheduler(optimizer, i, epoch) + data_time = time.time() - end + + if prof > 0: + if i >= prof: + break + + optimizer_step = ((i + 1) % batch_size_multiplier) == 0 + loss, _, prec1, prec5 = step(input, target, optimizer_step = optimizer_step) + + it_time = time.time() - end + + if config.debug_speed: + global train_step_ct, train_max_ips, train_max_batch + train_max_ips = max(train_max_ips, calc_ips(bs, it_time)) + train_max_batch = max(train_max_batch, len(input)) + if train_step_ct >= 3: + res = "BatchSize: %d\tIPS: %.2f\t,Cost: %.2f ms" % ( + bs, train_max_ips, 1000.0 / train_max_ips) + print(res, flush=True) + exp_recorder.record("network", model_and_loss.arch[0]) + exp_recorder.record("algorithm", 'quantize' + if model_and_loss.arch[1] == 'quantize' else 'exact') + exp_recorder.record("batch_size", train_max_batch) + exp_recorder.record("ips", train_max_ips, 1) + exp_recorder.dump('speed_results.tsv') + exit(0) + train_step_ct += 1 + + if logger is not None: + logger.log_metric('train.top1', to_python_float(prec1)) + logger.log_metric('train.top5', to_python_float(prec5)) + logger.log_metric('train.loss', to_python_float(loss)) + logger.log_metric('train.compute_ips', calc_ips(bs, it_time - data_time)) + logger.log_metric('train.total_ips', calc_ips(bs, it_time)) + logger.log_metric('train.data_time', data_time) + logger.log_metric('train.compute_time', it_time - data_time) + + end = time.time() + # if epoch > 0 and config.perlayer: + # QScheme.allocate_perlayer() + # QBNScheme.allocate_perlayer() + + #for layer in QScheme.layers: + # print(layer.name, layer.bits) + + +def get_val_step(model_and_loss): + def _step(input, target): + input_var = Variable(input) + target_var = Variable(target) + + with torch.no_grad(): + loss, output = model_and_loss(input_var, target_var) + + prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5)) + + if torch.distributed.is_initialized(): + reduced_loss = utils.reduce_tensor(loss.data) + prec1 = utils.reduce_tensor(prec1) + prec5 = utils.reduce_tensor(prec5) + else: + reduced_loss = loss.data + + torch.cuda.synchronize() + + return reduced_loss, prec1, prec5 + + return _step + + +def validate(val_loader, model_and_loss, fp16, logger, epoch, prof=-1, register_metrics=True): + if register_metrics and logger is not None: + logger.register_metric('val.top1', log.AverageMeter(), log_level = 0) + logger.register_metric('val.top5', log.AverageMeter(), log_level = 0) + logger.register_metric('val.loss', log.AverageMeter(), log_level = 0) + logger.register_metric('val.compute_ips', log.AverageMeter(), log_level = 1) + logger.register_metric('val.total_ips', log.AverageMeter(), log_level = 1) + logger.register_metric('val.data_time', log.AverageMeter(), log_level = 1) + logger.register_metric('val.compute_time', log.AverageMeter(), log_level = 1) + + step = get_val_step(model_and_loss) + + top1 = log.AverageMeter() + # switch to evaluate mode + model_and_loss.eval() + print('Training mode ', config.training) + + end = time.time() + + data_iter = enumerate(val_loader) + if not logger is None: + data_iter = logger.iteration_generator_wrapper(data_iter, val=True) + + for i, (input, target, _) in data_iter: + bs = input.size(0) + data_time = time.time() - end + if prof > 0: + if i > prof: + break + + loss, prec1, prec5 = step(input, target) + + it_time = time.time() - end + + top1.record(to_python_float(prec1), bs) + if logger is not None: + logger.log_metric('val.top1', to_python_float(prec1)) + logger.log_metric('val.top5', to_python_float(prec5)) + logger.log_metric('val.loss', to_python_float(loss)) + logger.log_metric('val.compute_ips', calc_ips(bs, it_time - data_time)) + logger.log_metric('val.total_ips', calc_ips(bs, it_time)) + logger.log_metric('val.data_time', data_time) + logger.log_metric('val.compute_time', it_time - data_time) + + end = time.time() + + return top1.get_val() + +# Train loop {{{ +def calc_ips(batch_size, time): + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + tbs = world_size * batch_size + return tbs/time + +def train_loop(model_and_loss, optimizer, new_optimizer, lr_scheduler, train_loader, val_loader, debug_loader, epochs, fp16, logger, + should_backup_checkpoint, use_amp=False, + batch_size_multiplier = 1, + best_prec1 = 0, start_epoch = 0, prof = -1, skip_training = False, skip_validation = False, save_checkpoints = True, checkpoint_dir='./', + model_state = None): + QScheme.update_scale = True + prec1 = -1 + + epoch_iter = range(start_epoch, epochs) + if logger is not None: + epoch_iter = logger.epoch_generator_wrapper(epoch_iter) + for epoch in epoch_iter: + print('Epoch ', epoch) + if not skip_training: + train(train_loader, model_and_loss, optimizer, lr_scheduler, fp16, logger, epoch, use_amp = use_amp, prof = prof, register_metrics=epoch==start_epoch, batch_size_multiplier=batch_size_multiplier) + + if not skip_validation: + prec1 = validate(val_loader, model_and_loss, fp16, logger, epoch, prof = prof, register_metrics=epoch==start_epoch) + + if save_checkpoints and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): + if not skip_training: + is_best = prec1 > best_prec1 + best_prec1 = max(prec1, best_prec1) + + if should_backup_checkpoint(epoch): + backup_filename = 'checkpoint-{}.pth.tar'.format(epoch + 1) + else: + backup_filename = None + utils.save_checkpoint({ + 'epoch': epoch + 1, + 'arch': model_and_loss.arch, + 'state_dict': model_and_loss.model.state_dict(), + 'best_prec1': best_prec1, + 'optimizer' : optimizer.state_dict(), + }, is_best, checkpoint_dir=checkpoint_dir, backup_filename=backup_filename) + + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + logger.end() + + get_var(model_and_loss, optimizer, train_loader, 10, model_state) diff --git a/image_classification/image_classification/utils.py b/image_classification/image_classification/utils.py new file mode 100644 index 0000000..48ebd76 --- /dev/null +++ b/image_classification/image_classification/utils.py @@ -0,0 +1,94 @@ +import os +import numpy as np +import torch +import shutil +import torch.distributed as dist + + +def should_backup_checkpoint(args): + def _sbc(epoch): + return args.gather_checkpoints # and (epoch < 10 or epoch % 10 == 0) + return _sbc + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', checkpoint_dir='./', backup_filename=None): + if (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0: + filename = os.path.join(checkpoint_dir, filename) + print("SAVING {}".format(filename)) + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, os.path.join(checkpoint_dir, 'model_best.pth.tar')) + if backup_filename is not None: + shutil.copyfile(filename, os.path.join(checkpoint_dir, backup_filename)) + + +def timed_generator(gen): + start = time.time() + for g in gen: + end = time.time() + t = end - start + yield g, t + start = time.time() + + +def timed_function(f): + def _timed_function(*args, **kwargs): + start = time.time() + ret = f(*args, **kwargs) + return ret, time.time() - start + return _timed_function + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + #correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + return rt + + +def dict_add(x, y): + if x is None: + return y + return {k: x[k] + y[k] for k in x} + + +def dict_minus(x, y): + return {k: x[k] - y[k] for k in x} + + +def dict_sqr(x): + return {k: x[k]**2 for k in x} + + +def dict_sqrt(x): + return {k: torch.sqrt(x[k]) for k in x} + + +def dict_mul(x, a): + return {k: x[k]*a for k in x} + + +def dict_clone(x): + return {k: x[k].clone() for k in x} + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url diff --git a/image_classification/img/.gitkeep b/image_classification/img/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/image_classification/main.py b/image_classification/main.py new file mode 100644 index 0000000..bf446ba --- /dev/null +++ b/image_classification/main.py @@ -0,0 +1,378 @@ +import argparse +import random + +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from actnn import config, QScheme, QModule + +try: + # from apex.parallel import DistributedDataParallel as DDP + from torch.nn.parallel import DistributedDataParallel as DDP + from apex.fp16_utils import * + from apex import amp +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") + +from image_classification.smoothing import LabelSmoothing +from image_classification.mixup import NLLMultiLabelSmooth, MixUpWrapper +from image_classification.dataloaders import * +from image_classification.training import * +from image_classification.utils import * + + +def add_parser_arguments(parser): + model_names = models.resnet_versions.keys() + model_configs = models.resnet_configs.keys() + + parser.add_argument('data', metavar='DIR', + help='path to dataset') + parser.add_argument('--dataset', type=str, default='imagenet') + + parser.add_argument('--data-backend', metavar='BACKEND', default='pytorch', + choices=DATA_BACKEND_CHOICES) + + parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet50)') + + parser.add_argument('--model-config', '-c', metavar='CONF', default='fanin', + choices=model_configs, + help='model configs: ' + + ' | '.join(model_configs) + '(default: classic)') + + parser.add_argument('-j', '--workers', default=5, type=int, metavar='N', + help='number of data loading workers (default: 5)') + parser.add_argument('--num-classes', default=1000, type=int, metavar='N', + help='number of classes (default: 1000)') + parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') + parser.add_argument('-b', '--batch-size', default=64, type=int, + metavar='N', help='mini-batch size (default: 256) per gpu') + + parser.add_argument('--optimizer-batch-size', default=-1, type=int, + metavar='N', help='size of a total batch size, for simulating bigger batches') + + parser.add_argument('--lr', '--learning-rate', default=0.512, type=float, + metavar='LR', help='initial learning rate') + parser.add_argument('--lr-schedule', default='cosine', type=str, metavar='SCHEDULE', choices=['step','linear','cosine']) + + parser.add_argument('--warmup', default=4, type=int, + metavar='E', help='number of warmup epochs') + + parser.add_argument('--label-smoothing', default=0.1, type=float, + metavar='S', help='label smoothing') + parser.add_argument('--mixup', default=0.0, type=float, + metavar='ALPHA', help='mixup alpha') + + parser.add_argument('--momentum', default=0.875, type=float, metavar='M', + help='momentum') + parser.add_argument('--weight-decay', '--wd', default=3.0517578125e-05, type=float, + metavar='W', help='weight decay (default: 1e-4)') + parser.add_argument('--bn-weight-decay', action='store_true', + help='use weight_decay on batch normalization learnable parameters, default: false)') + parser.add_argument('--nesterov', action='store_true', + help='use nesterov momentum, default: false)') + + parser.add_argument('--print-freq', '-p', default=100, type=int, + metavar='N', help='print frequency (default: 10)') + parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') + parser.add_argument('--resume2', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') + parser.add_argument('--pretrained-weights', default='', type=str, metavar='PATH', + help='load weights from here') + + parser.add_argument('--fp16', action='store_true', + help='Run model fp16 mode.') + parser.add_argument('--static-loss-scale', type=float, default=1, + help='Static loss scale, positive power of 2 values can improve fp16 convergence.') + parser.add_argument('--dynamic-loss-scale', action='store_true', + help='Use dynamic loss scaling. If supplied, this argument supersedes ' + + '--static-loss-scale.') + parser.add_argument('--prof', type=int, default=-1, + help='Run only N iterations') + parser.add_argument('--amp', action='store_true', + help='Run model AMP (automatic mixed precision) mode.') + + parser.add_argument("--local_rank", default=0, type=int) + + parser.add_argument('--seed', default=None, type=int, + help='random seed used for np and pytorch') + + parser.add_argument('--gather-checkpoints', action='store_true', + help='Gather checkpoints throughout the training') + + parser.add_argument('--raport-file', default='raport.json', type=str, + help='file in which to store JSON experiment raport') + + parser.add_argument('--final-weights', default='model.pth.tar', type=str, + help='file in which to store final model weights') + + parser.add_argument('--evaluate', action='store_true', help='evaluate checkpoint/model') + parser.add_argument('--training-only', action='store_true', help='do not evaluate') + + parser.add_argument('--no-checkpoints', action='store_false', dest='save_checkpoints') + + parser.add_argument('--workspace', type=str, default='./') + + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + parser.add_argument('--ca', type=str2bool, default=True, help='compress activation') + parser.add_argument('--sq', type=str2bool, default=True, help='stochastic quantization') + parser.add_argument('--cabits', type=float, default=8, help='activation number of bits') + parser.add_argument('--qat', type=int, default=8, help='quantization aware training bits') + parser.add_argument('--ibits', type=int, default=8, help='Initial precision for the allocation algorithm') + parser.add_argument('--calg', type=str, default='pl', help='Quantization algorithm, naive, pg, ps, or pl') + # parser.add_argument('--pergroup', type=str2bool, default=True, help='Per-group range') + parser.add_argument('--groupsize', type=int, default=256, help='Size for each quantization group') + # parser.add_argument('--perlayer', type=str2bool, default=True, help='Per layer quantization') + parser.add_argument('--usegradient', type=str2bool, default=True, help='Using gradient information for persample') + + +def main(args): + config.compress_activation = args.ca + config.stochastic = args.sq + if args.calg == 'naive': + config.activation_compression_bits = [int(args.cabits)] + config.pergroup = False + config.perlayer = False + config.initial_bits = int(args.cabits) + elif args.calg == 'pg': + config.activation_compression_bits = [int(args.cabits)] + config.pergroup = True + config.perlayer = False + config.initial_bits = int(args.cabits) + elif args.calg == 'ps': + config.activation_compression_bits = [args.cabits] + config.pergroup = True + config.perlayer = False + config.initial_bits = 8 + else: + config.activation_compression_bits = [args.cabits] + config.pergroup = True + config.perlayer = True + config.initial_bits = 8 + + config.qat = args.qat + config.use_gradient = args.usegradient + config.group_size = args.groupsize + + exp_start_time = time.time() + global best_prec1 + best_prec1 = 0 + + args.distributed = False + if 'WORLD_SIZE' in os.environ: + args.distributed = int(os.environ['WORLD_SIZE']) > 1 + + args.gpu = 0 + args.world_size = 1 + + if args.distributed: + args.gpu = args.local_rank % torch.cuda.device_count() + torch.cuda.set_device(args.gpu) + dist.init_process_group(backend='nccl', init_method='env://') + args.world_size = torch.distributed.get_world_size() + + if args.amp and args.fp16: + print("Please use only one of the --fp16/--amp flags") + exit(1) + + if args.seed is not None: + print("Using seed = {}".format(args.seed)) + torch.manual_seed(args.seed + args.local_rank) + torch.cuda.manual_seed(args.seed + args.local_rank) + np.random.seed(seed=args.seed + args.local_rank) + random.seed(args.seed + args.local_rank) + + def _worker_init_fn(id): + np.random.seed(seed=args.seed + args.local_rank + id) + random.seed(args.seed + args.local_rank + id) + else: + def _worker_init_fn(id): + pass + + if args.fp16: + assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." + + if args.static_loss_scale != 1.0: + if not args.fp16: + print("Warning: if --fp16 is not used, static_loss_scale will be ignored.") + + if args.optimizer_batch_size < 0: + batch_size_multiplier = 1 + else: + tbs = args.world_size * args.batch_size + if args.optimizer_batch_size % tbs != 0: + print("Warning: simulated batch size {} is not divisible by actual batch size {}".format(args.optimizer_batch_size, tbs)) + batch_size_multiplier = int(args.optimizer_batch_size/ tbs) + print("BSM: {}".format(batch_size_multiplier)) + + pretrained_weights = None + if args.pretrained_weights: + if os.path.isfile(args.pretrained_weights): + print("=> loading pretrained weights from '{}'".format(args.pretrained_weights)) + pretrained_weights = torch.load(args.pretrained_weights) + else: + print("=> no pretrained weights found at '{}'".format(args.resume)) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu)) + args.start_epoch = checkpoint['epoch'] + best_prec1 = checkpoint['best_prec1'] + model_state = checkpoint['state_dict'] + optimizer_state = checkpoint['optimizer'] + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + model_state = None + optimizer_state = None + else: + model_state = None + optimizer_state = None + + if args.resume2: + if os.path.isfile(args.resume2): + print("=> loading checkpoint '{}'".format(args.resume2)) + checkpoint2 = torch.load(args.resume2, map_location=lambda storage, loc: storage.cuda(args.gpu)) + model_state2 = checkpoint2['state_dict'] + else: + model_state2 = None + else: + model_state2 = None + + + # Create data loaders and optimizers as needed + if args.dataset == 'cifar10': + get_train_loader = get_pytorch_train_loader_cifar10 + get_val_loader = get_pytorch_val_loader_cifar10 + get_debug_loader = get_pytorch_debug_loader_cifar10 + QScheme.num_samples = 50000 # NOTE: only needed for use_gradient + elif args.data_backend == 'pytorch': + get_train_loader = get_pytorch_train_loader + get_val_loader = get_pytorch_val_loader + get_debug_loader = get_pytorch_val_loader + QScheme.num_samples = 1300000 # NOTE: only needed for use_gradient + elif args.data_backend == 'dali-gpu': + get_train_loader = get_dali_train_loader(dali_cpu=False) + get_val_loader = get_dali_val_loader() + elif args.data_backend == 'dali-cpu': + get_train_loader = get_dali_train_loader(dali_cpu=True) + get_val_loader = get_dali_val_loader() + + loss = nn.CrossEntropyLoss + if args.mixup > 0.0: + loss = lambda: NLLMultiLabelSmooth(args.label_smoothing) + elif args.label_smoothing > 0.0: + loss = lambda: LabelSmoothing(args.label_smoothing) + + model_and_loss = ModelAndLoss( + (args.arch, args.model_config), + args.num_classes, + loss, + pretrained_weights=pretrained_weights, + cuda = True, fp16 = args.fp16) + + train_loader, train_loader_len = get_train_loader(args.data, args.batch_size, args.num_classes, args.mixup > 0.0, workers=args.workers, fp16=args.fp16) + if args.mixup != 0.0: + train_loader = MixUpWrapper(args.mixup, args.num_classes, train_loader) + + val_loader, val_loader_len = get_val_loader(args.data, args.batch_size, args.num_classes, False, workers=args.workers, fp16=args.fp16) + debug_loader, debug_loader_len = get_debug_loader(args.data, args.batch_size, args.num_classes, False, workers=args.workers, fp16=args.fp16) + + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + logger = log.Logger( + args.print_freq, + [ + log.JsonBackend(os.path.join(args.workspace, args.raport_file), log_level=1), + log.StdOut1LBackend(train_loader_len, val_loader_len, args.epochs, log_level=0), + ]) + + for k, v in args.__dict__.items(): + logger.log_run_tag(k, v) + else: + logger = None + + optimizer = get_optimizer(list(model_and_loss.model.named_parameters()), + args.fp16, + args.lr, args.momentum, args.weight_decay, + nesterov = args.nesterov, + bn_weight_decay = args.bn_weight_decay, + # state=optimizer_state, + static_loss_scale = args.static_loss_scale, + dynamic_loss_scale = args.dynamic_loss_scale) + + + def new_optimizer(): + return get_optimizer(list(model_and_loss.model.named_parameters()), + args.fp16, + args.lr, args.momentum, args.weight_decay, + nesterov = args.nesterov, + bn_weight_decay = args.bn_weight_decay, + # state=optimizer_state, + static_loss_scale = args.static_loss_scale, + dynamic_loss_scale = args.dynamic_loss_scale) + + if args.lr_schedule == 'step': + lr_policy = lr_step_policy(args.lr, [30,60,80], 0.1, args.warmup, logger=logger) + elif args.lr_schedule == 'cosine': + lr_policy = lr_cosine_policy(args.lr, args.warmup, args.epochs, logger=logger) + elif args.lr_schedule == 'linear': + lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=logger) + + if args.amp: + model_and_loss, optimizer = amp.initialize( + model_and_loss, optimizer, + opt_level="O2", + loss_scale="dynamic" if args.dynamic_loss_scale else args.static_loss_scale) + + if args.distributed: + model_and_loss.distributed(args.local_rank) + + model_and_loss.load_model_state(model_state) + + print('Start epoch {}'.format(args.start_epoch)) + train_loop( + model_and_loss, optimizer, new_optimizer, + lr_policy, + train_loader, val_loader, debug_loader, args.epochs, + args.fp16, logger, should_backup_checkpoint(args), use_amp=args.amp, + batch_size_multiplier = batch_size_multiplier, + start_epoch = args.start_epoch, best_prec1 = best_prec1, prof=args.prof, + skip_training = args.evaluate, skip_validation = args.training_only, + save_checkpoints=args.save_checkpoints and not args.evaluate, checkpoint_dir=args.workspace, + model_state=model_state2) + exp_duration = time.time() - exp_start_time + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + logger.end() + print("Experiment ended") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') + + add_parser_arguments(parser) + args = parser.parse_args() + cudnn.benchmark = True + + main(args) diff --git a/image_classification/multiproc.py b/image_classification/multiproc.py new file mode 100644 index 0000000..b7e7321 --- /dev/null +++ b/image_classification/multiproc.py @@ -0,0 +1,119 @@ +import sys +import subprocess +import os +import socket +import time +from argparse import ArgumentParser, REMAINDER + +import torch + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser(description="PyTorch distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + + # Optional arguments for the launch helper + parser.add_argument("--nnodes", type=int, default=1, + help="The number of nodes to use for distributed " + "training") + parser.add_argument("--node_rank", type=int, default=0, + help="The rank of the node for multi-node distributed " + "training") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for GPU training, this is recommended to be set " + "to the number of GPUs in your system so that " + "each process can be bound to a single GPU.") + parser.add_argument("--master_addr", default="127.0.0.1", type=str, + help="Master node (rank 0)'s address, should be either " + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1") + parser.add_argument("--master_port", default=29500, type=int, + help="Master node (rank 0)'s free port that needs to " + "be used for communciation during distributed " + "training") + + # positional + parser.add_argument("training_script", type=str, + help="The full path to the single GPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + + # rest from the training program + parser.add_argument('training_script_args', nargs=REMAINDER) + return parser.parse_args() + + +def main(): + args = parse_args() + + # world size in terms of number of processes + dist_world_size = args.nproc_per_node * args.nnodes + + # set PyTorch distributed related environmental variables + current_env = os.environ.copy() + current_env["MASTER_ADDR"] = args.master_addr + current_env["MASTER_PORT"] = str(args.master_port) + current_env["WORLD_SIZE"] = str(dist_world_size) + + processes = [] + + for local_rank in range(0, args.nproc_per_node): + # each process's rank + dist_rank = args.nproc_per_node * args.node_rank + local_rank + current_env["RANK"] = str(dist_rank) + + # spawn the processes + cmd = [sys.executable, + "-u", + args.training_script, + "--local_rank={}".format(local_rank)] + args.training_script_args + + print(cmd) + + stdout = None if local_rank == 0 else open("GPU_"+str(local_rank)+".log", "w") + + process = subprocess.Popen(cmd, env=current_env, stdout=stdout) + processes.append(process) + + try: + up = True + error = False + while up and not error: + up = False + for p in processes: + ret = p.poll() + if ret is None: + up = True + elif ret != 0: + error = True + time.sleep(1) + + if error: + for p in processes: + if p.poll() is None: + p.terminate() + exit(1) + + except KeyboardInterrupt: + for p in processes: + p.terminate() + raise + except SystemExit: + for p in processes: + p.terminate() + raise + except: + for p in processes: + p.terminate() + raise + + +if __name__ == "__main__": + main() diff --git a/image_classification/quick_test.sh b/image_classification/quick_test.sh new file mode 100755 index 0000000..63abed7 --- /dev/null +++ b/image_classification/quick_test.sh @@ -0,0 +1,30 @@ +python3 main.py --dataset cifar10 --arch preact_resnet56 --epochs 200 --num-classes 100 -j 0 --weight-decay 1e-4 --batch-size 128 --label-smoothing 0 \ + -c quantize --ca=True --cabits=2 --ibits=8 --calg pl \ + --workspace results/tmp --evaluate --training-only \ + --resume results/cifar100/checkpoint-10.pth.tar --resume2 results/cifar100/checkpoint-10.pth.tar ~/data/cifar100 + +###### Check cifar memory ##### +#for BS in 512 +#do +# python3 main.py --dataset cifar10 --arch preact_resnet56 --epochs 200 --num-classes 10 -j 0 --weight-decay 1e-4 --batch-size $BS --label-smoothing 0 \ +# -c fanin \ +# --workspace results/tmp --gather-checkpoints ~/data/cifar10 +#done + +#for BS in 28000 +#do +# python3 main.py --dataset cifar10 --arch preact_resnet56 --epochs 200 --num-classes 10 -j 0 --weight-decay 1e-4 --batch-size $BS --label-smoothing 0 \ +# -c quantize --ca=True --cabits=2 --ibits=8 --calg pl \ +# --workspace results/tmp --gather-checkpoints ~/data/cifar10 +#done + +##### Resnet exact ###### +#./dist-train 1 0 127.0.0.1 1 resnet50 \ +# "-c fanin" \ +# tmp ~/imagenet 256 + +##### Resnet quantized ###### +#./dist-train 1 0 127.0.0.1 1 resnet152 \ +# "-c quantize --ca=True --cabits=2 --ibits=8 --calg pl" \ +# tmp ~/imagenet 256 + diff --git a/image_classification/run.sh b/image_classification/run.sh new file mode 100644 index 0000000..906c898 --- /dev/null +++ b/image_classification/run.sh @@ -0,0 +1,6 @@ +import os + +for bbits in [3, 4, 5, 6, 7, 8]: + for persample in [False, True]: + for biased in [False, True]: + CUDA_VISIBLE_DEVICES=0,1 ./test_cifar 2 preact_resnet56 200 "-c quantize --bbits ${bbits} --qa False --persample=True --qw False" 29500 diff --git a/image_classification/test b/image_classification/test new file mode 100755 index 0000000..4f74f6c --- /dev/null +++ b/image_classification/test @@ -0,0 +1 @@ +python ./multiproc.py --nproc_per_node 8 ./main.py --resume $1/model_best.pth.tar --epochs 91 --evaluate --raport-file test_raport.json --workspace $1/ $2 ~/data/imagenet diff --git a/image_classification/test_cifar b/image_classification/test_cifar new file mode 100755 index 0000000..d34f5bc --- /dev/null +++ b/image_classification/test_cifar @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import sys, os + +num_nodes = sys.argv[1] +batch_size = 128 // int(num_nodes) +model = sys.argv[2] +dir = sys.argv[3] +epoch = int(sys.argv[4]) +options = sys.argv[5] +port = sys.argv[6] + +cmd = 'python ./multiproc.py --master_port {port} --nproc_per_node {num_nodes} ./main.py --dataset cifar10 --arch {model} --resume results/{dir}/checkpoint-{epochs}.pth.tar --epochs {epochs_plus_one} --evaluate --raport-file test_raport.json --workspace results/{dir} --batch-size {batch_size} --label-smoothing 0 --weight-decay 1e-4 {options} ~/data/cifar10'.format( + num_nodes=num_nodes, batch_size=batch_size, model=model, epochs=epoch, epochs_plus_one=epoch+1, options=options, port=port, dir=dir) + +print(cmd) +os.system(cmd) + +#cmd = 'python compute_std.py results/{} {}'.format(model, num_nodes) +#print(cmd) +#os.system(cmd) + +# os.system("python compute_error_std.py") diff --git a/image_classification/train b/image_classification/train new file mode 100755 index 0000000..97e7b17 --- /dev/null +++ b/image_classification/train @@ -0,0 +1,2 @@ +mkdir results/$2 +python ./multiproc.py --nproc_per_node 4 ./main.py --arch $1 --gather-checkpoints --workspace results/$1 --batch-size 64 --lr 0.256 --warmup 1 $3 ~/data/imagenet diff --git a/image_classification/train_cifar b/image_classification/train_cifar new file mode 100755 index 0000000..85c9f6a --- /dev/null +++ b/image_classification/train_cifar @@ -0,0 +1,2 @@ +mkdir results/$2 +python main.py --dataset cifar10 --arch $1 --gather-checkpoints --workspace results/$2 --batch-size 128 --lr 0.1 --momentum 0.9 --label-smoothing 0 --warmup 0 --weight-decay 1e-4 --epochs 200 $3 ~/data/cifar10 diff --git a/mem_speed_benchmark/README.md b/mem_speed_benchmark/README.md new file mode 100644 index 0000000..0121e74 --- /dev/null +++ b/mem_speed_benchmark/README.md @@ -0,0 +1,31 @@ +# Benchmark Memory Usage and Training Speed on Torchvision Models + +## Prepare dataset +Put the ImageNet dataset to `~/imagenet` + +## Benchmark Memory Usage +``` +DEBUG_MEM=True python3 train.py ~/imagenet --arch ARCH -b BATCH_SIZE --alg ALGORITHM +``` + +The choices for ARCH are {resnet50, resnet152, wide_resnet101_2, densenet201} +The choices for ALGORITHM are {exact, actnn-L0, actnn-L1, actnn-L2, actnn-L3, actnn-L4, actnn-L5} + +For example, the command below run actnn-L3 on resnet50 +``` +DEBUG_MEM=True python3 train.py ~/imagenet --arch resnet50 -b 128 --alg actnn-L3 +``` + +## Benchmark Training Speed +``` +DEBUG_SPEED=True python3 train.py ~/imagenet --arch ARCH -b BATCH_SIZE --alg ALGORITHM +``` + +The choices for ARCH are {resnet50, resnet152, wide_resnet101_2, densenet201} +The choices for ALGORITHM are {exact, actnn-L0, actnn-L1, actnn-L2, actnn-L3, actnn-L4, actnn-L5} + +For example, the command below run actnn-L3 on resnet50 +``` +DEBUG_SPEED=True python3 train.py ~/imagenet --arch resnet50 -b 128 --alg actnn-L3 +``` + diff --git a/mem_speed_benchmark/exp_mem_speed.py b/mem_speed_benchmark/exp_mem_speed.py new file mode 100644 index 0000000..5a65cfc --- /dev/null +++ b/mem_speed_benchmark/exp_mem_speed.py @@ -0,0 +1,256 @@ +import argparse +import json +import os +import time + +def run_cmd(cmd): + print(cmd) + return os.system(cmd) + + +def alg_to_config(algorithm): + return "--alg %s" % algorithm + + +def network_to_command(network): + return "python3 train.py ~/imagenet --arch ARCH --b BS CONFIG".replace("ARCH", network) + + +def run_benchmark(network, alg, batch_size, debug_mem=False, debug_speed=False, input_size=None, get_macs=False): + os.environ['DEBUG_MEM'] = str(debug_mem) + os.environ['DEBUG_SPEED'] = str(debug_speed) + cmd = network_to_command(network) + cmd = cmd.replace("BS", f"{batch_size}").replace("CONFIG", alg_to_config(alg)) + + if input_size is not None: + cmd += f' --input-size {input_size}' + + if get_macs: + cmd += " --get-macs" + + ret_code = run_cmd(cmd) + + if ret_code != 0: + out_file = "speed_results.json" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "batch_size": batch_size, + "ips": -1, + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + + time.sleep(1) + run_cmd("nvidia-smi > /dev/null") + time.sleep(1) + return ret_code + + +def round_up(x): + return int((x + 3)// 4 * 4) + +def round_down(x): + return int(x // 4 * 4) + + +def binary_search_max_batch(network, alg, low, high): + ret = 0 + low, high= round_up(low), round_down(high) + + while low <= high: + mid = round_down(low + (high - low) // 2) + success = run_benchmark(network, alg, mid, debug_speed=True) == 0 + if success: + ret = mid + low = round_up(mid + 1) + else: + high = round_down(mid - 1) + + return ret + + +def binary_search_max_input_size(alg, low, high, network, batch_size): + ret = 0 + low, high= round_up(low), round_down(high) + + while low <= high: + mid = round_down(low + (high - low) // 2) + success = (run_benchmark(network, alg, input_size=mid, batch_size=batch_size, + debug_speed=True) == 0) + if success: + ret = mid + low = round_up(mid + 1) + else: + high = round_down(mid - 1) + + return ret + + +def binary_search_max_layer(alg, low, high, batch_size): + ret = 0 + low, high= round_up(low), round_down(high) + + while low <= high: + mid = round_down(low + (high - low) // 2) + network = "scaled_resnet_%d" % mid + success = (run_benchmark(network, alg, batch_size=batch_size, debug_speed=True) == 0) + if success: + ret = mid + low = round_up(mid + 1) + else: + high = round_down(mid - 1) + + return ret + + +def binary_search_max_width(alg, low, high, batch_size): + ret = 0 + low, high= round_up(low), round_down(high) + + while low <= high: + mid = round_down(low + (high - low) // 2) + network = "scaled_wide_resnet_%d" % mid + success = (run_benchmark(network, alg, batch_size=batch_size, debug_speed=True) == 0) + if success: + ret = mid + low = round_up(mid + 1) + else: + high = round_down(mid - 1) + + return ret + + + +def get_ips(network, alg, batch_size, input_size=None): + run_benchmark(network, alg, batch_size, input_size=input_size, debug_speed=True) + line = list(open("speed_results.json").readlines())[-1] + return json.loads(line)['ips'] + + +def get_macs(network, alg, batch_size, input_size=None): + run_benchmark(network, alg, batch_size, input_size=input_size, get_macs=True) + line = list(open("get_macs.json").readlines())[-1] + return json.loads(line) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--mode", type=str, default='linear_scan') + parser.add_argument("--retry", type=int, default=1) + args = parser.parse_args() + + if args.mode == 'linear_scan': + networks = ['resnet50', 'resnet152', 'densenet201', 'wide_resnet101_2'] + batch_sizes = list(range(32, 256, 16)) + list(range(256, 1280, 32)) + algs = ['exact', 'actnn-L0', 'actnn-L1', 'actnn-L2', 'actnn-L3', 'actnn-L3.1', 'actnn-L4', 'actnn-L5'] + #networks = ['resnet152'] + #batch_sizes = list(range(32, 256, 64)) + list(range(256, 1280, 32)) + #algs = ['exact'] + else: + algs = ['exact', 'actnn-L4', 'actnn-L3', 'actnn-L3.1'] + + if args.mode == 'linear_scan': + for network in networks: + for alg in algs: + failed = 0 + for batch_size in batch_sizes: + # do not run L5 too frequently + if alg == 'actnn-L5' and batch_size % 32 != 0: + continue + + if run_benchmark(network, alg, batch_size, debug_mem=False, debug_speed=True) != 0: + if failed >= args.retry: + break + failed += 1 + else: + failed = 0 + elif args.mode == 'binary_search_max_batch': + for network in networks: + for alg in algs: + low, high = 16, 1024 + max_batch_size = binary_search_max_batch(network, alg, low, high) + ips = get_ips(network, alg, max_batch_size) + + out_file = "max_batch_results.json" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "max_batch_size": max_batch_size, + "ips": ips, + "tstamp": time.time() + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + elif args.mode == 'binary_search_max_input_size': + for alg in algs: + low, high = 224, 768 + batch_size = 64 + network = 'resnet152' + max_input_size = binary_search_max_input_size(alg, low, high, network, batch_size) + ips = get_ips(network, alg, batch_size, input_size=max_input_size) + macs, params = get_macs(network, alg, batch_size, input_size=max_input_size) + + out_file = "max_input_size_results.json" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "max_input_size": max_input_size, + "ips": ips, + "macs": macs, + "params": params, + "TFLOPS": round(macs * ips / 1e12, 2), + "tstamp": time.time() + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + elif args.mode == 'binary_search_max_layer': + for alg in algs: + low, high = 152, 1024 + batch_size = 64 + max_layer = binary_search_max_layer(alg, low, high, batch_size) + network = 'scaled_resnet_%d' % max_layer + ips = get_ips(network, alg, batch_size) + macs, params = get_macs(network, alg, batch_size) + + out_file = "max_layer_results.json" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "max_layer": max_layer, + "ips": ips, + "macs": macs, + "params": params, + "TFLOPS": round(macs * ips / 1e12, 2), + "tstamp": time.time() + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + elif args.mode == 'binary_search_max_width': + for alg in algs: + low, high = 64, 512 + batch_size = 64 + max_width = binary_search_max_width(alg, low, high, batch_size=batch_size) + network = 'scaled_wide_resnet_%d' % max_width + ips = get_ips(network, alg, batch_size) + macs, params = get_macs(network, alg, batch_size) + + out_file = "max_width_results.json" + with open(out_file, "a") as fout: + val_dict = { + "network": network, + "algorithm": alg, + "max_width": max_width, + "ips": ips, + "macs": macs, + "params": params, + "TFLOPS": round(macs * ips / 1e12, 2), + "tstamp": time.time() + } + fout.write(json.dumps(val_dict) + "\n") + print(f"save results to {out_file}") + diff --git a/mem_speed_benchmark/scaled_resnet.py b/mem_speed_benchmark/scaled_resnet.py new file mode 100644 index 0000000..7c46a14 --- /dev/null +++ b/mem_speed_benchmark/scaled_resnet.py @@ -0,0 +1,17 @@ +from torchvision.models.resnet import _resnet, Bottleneck + +def scaled_resnet(name): + n_layers = int(name.split('scaled_resnet_')[1]) + assert n_layers >= 152 + added_layers = n_layers - 152 + + add_1 = int(added_layers * (8 / (8 + 36))) + add_2 = added_layers - add_1 + + return _resnet('', Bottleneck, [3, 8 + add_1, 36 + add_2, 3], False, True) + +def scaled_wide_resnet(name): + width = int(name.split('scaled_wide_resnet_')[1]) + kwargs = {'width_per_group': width} + return _resnet('', Bottleneck, [3, 8, 36, 3], False, True, **kwargs) + diff --git a/mem_speed_benchmark/train.py b/mem_speed_benchmark/train.py new file mode 100644 index 0000000..85638fd --- /dev/null +++ b/mem_speed_benchmark/train.py @@ -0,0 +1,569 @@ +""" +Modified from https://github.com/utsaslab/MONeT/blob/master/examples/imagenet.py +""" + +import argparse +import json +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +import numpy as np + +from scaled_resnet import scaled_resnet, scaled_wide_resnet + +import actnn +from actnn import config, QScheme, QModule +from actnn import get_memory_usage, compute_tensor_bytes, exp_recorder +MB = 1024**2 +GB = 1024**3 + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', + help='model architecture') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=64, type=int, + metavar='N', + help='mini-batch size (default: 64), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=2, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +parser.add_argument('--alg', type=str, default="exact", help="Memory saving algorithm") +parser.add_argument('--input-size', type=int) +parser.add_argument('--get-macs', action='store_true') + +best_acc1 = 0 + +def set_optimization_level(args): + if args.alg == 'exact': + pass + elif args.alg.startswith('actnn-'): + actnn.set_optimization_level(args.alg[6:]) + elif args.alg == 'swap': + actnn.set_optimization_level('swap') + else: + raise ValueError("Invalid algorithm: " + args.alg) + +def main(): + args = parser.parse_args() + + set_optimization_level(args) + QScheme.num_samples = 1300000 # the size of training set + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + # create model + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + if args.arch.startswith('scaled_wide_resnet'): + model = scaled_wide_resnet(args.arch) + elif args.arch.startswith('scaled_resnet'): + model = scaled_resnet(args.arch) + else: + if args.arch in ['inception_v3']: + kwargs = {"aux_logits": False} + else: + kwargs = {} + model = models.__dict__[args.arch](**kwargs) + + if not torch.cuda.is_available(): + print('using CPU, this will be slow') + elif args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = torch.nn.DataParallel(model).cuda() + + if args.alg in ['swap'] or args.alg.startswith('actnn'): + print("=> convert model") + model = QModule(model) + model.cuda() + + if config.debug_memory_model: + print("========== Model Only ===========") + usage = get_memory_usage(True) + exp_recorder.record("network", args.arch) + exp_recorder.record("algorithm", args.alg) + exp_recorder.record("model_only", usage / GB, 2) + + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + if args.input_size is None: + if args.arch in ['inception_v3']: + input_size = 299 + else: + input_size = 224 + else: + input_size = args.input_size + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(input_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = actnn.dataloader.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(input_size), + transforms.CenterCrop(input_size), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.get_macs: + from thop import profile + from thop.vision.basic_hooks import count_convNd, count_bn, count_linear + from actnn.layers import QConv2d, QBatchNorm2d, QLinear + if isinstance(model, QModule): + QScheme.batch = [1] + model = model.model + model = model.module + input = torch.randn(1, 3, input_size, input_size).cuda() + macs, params = profile(model, inputs=(input, ), + custom_ops={QConv2d: count_convNd, QBatchNorm2d: count_bn, QLinear: count_linear}) + print(f"Macs: {macs}\t Params: {params}") + out_file = "get_macs.json" + with open(out_file, 'w') as fout: + fout.write(json.dumps([macs, params])) + print(f"save results to {out_file}") + exit() + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0): + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer' : optimizer.state_dict(), + }, is_best) + + +train_step_ct = 0 +train_max_batch = 0 +train_ips_list = [] + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':.3f') + data_time = AverageMeter('Data', ':.3f') + losses = AverageMeter('Loss', ':.3f') + top1 = AverageMeter('Acc@1', ':.2f') + top5 = AverageMeter('Acc@5', ':.2f') + ips = AverageMeter('IPS', ':.1f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5, ips], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + end = time.time() + for i, (indices, (images, target)) in enumerate(train_loader): + QScheme.batch = indices # NOTE: only needed for use_gradient + + # measure data loading time + data_time.update(time.time() - end) + + if config.debug_memory_model: + print("========== Init Data Loader ===========") + init_mem = get_memory_usage(True) + exp_recorder.record("data_loader", init_mem / GB - exp_recorder.val_dict['model_only'], 2) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if torch.cuda.is_available(): + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + if config.debug_memory_model: + print("========== Before Backward ===========") + before_backward = get_memory_usage(True) + act_mem = get_memory_usage() - init_mem - compute_tensor_bytes([loss, output]) + res = "Batch size: %d\tTotal Mem: %.2f MB\tAct Mem: %.2f MB" % ( + len(output), before_backward / MB, act_mem / MB) + optimizer.zero_grad() + loss.backward() + optimizer.step() + del loss + print("========== After Backward ===========") + after_backward = get_memory_usage(True) + total_mem = before_backward + (after_backward - init_mem) + res = "Batch size: %d\tTotal Mem: %.2f MB\tAct Mem: %.2f MB" % ( + len(output), total_mem / MB, act_mem / MB) + print(res) + exp_recorder.record("batch_size", len(output)) + exp_recorder.record("total", total_mem / GB, 2) + exp_recorder.record("activation", act_mem / GB, 2) + exp_recorder.dump('mem_results.json') + exit() + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + with torch.no_grad(): + optimizer.step() + + # measure elapsed time + bs = len(images) + batch_total_time = time.time() - end + train_ips = bs / batch_total_time + batch_time.update(batch_total_time) + ips.update(train_ips) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + if config.debug_speed: + global train_step_ct, train_ips_list, train_max_batch + train_ips_list.append(train_ips) + train_max_batch = max(train_max_batch, bs) + if train_step_ct >= 4: + train_ips = np.median(train_ips_list) + res = "BatchSize: %d\tIPS: %.2f\t,Cost: %.2f ms" % ( + bs, train_ips, 1000.0 / train_ips) + print(res, flush=True) + exp_recorder.record("network", args.arch) + exp_recorder.record("algorithm", args.alg) + exp_recorder.record("batch_size", train_max_batch) + exp_recorder.record("ips", train_ips, 2) + exp_recorder.record("tstamp", time.time(), 2) + exp_recorder.dump('speed_results.json') + exit(0) + train_step_ct += 1 + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if torch.cuda.is_available(): + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + #fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + fmtstr = '{name} {val' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + if config.debug_speed: + return [[-1]] * len(topk) + + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + # Run a matmul to initialize cublas first + a = torch.ones((1, 1)).cuda() + a = (a @ a).cpu() + del a + torch.cuda.empty_cache() + + main() + diff --git a/tests/dcgan.py b/tests/dcgan.py new file mode 100644 index 0000000..04d8aac --- /dev/null +++ b/tests/dcgan.py @@ -0,0 +1,736 @@ +# -*- coding: utf-8 -*- +""" +DCGAN Tutorial +============== + +**Author**: `Nathan Inkawhich `__ + +""" + +###################################################################### +# Introduction +# ------------ +# +# This tutorial will give an introduction to DCGANs through an example. We +# will train a generative adversarial network (GAN) to generate new +# celebrities after showing it pictures of many real celebrities. Most of +# the code here is from the dcgan implementation in +# `pytorch/examples `__, and this +# document will give a thorough explanation of the implementation and shed +# light on how and why this model works. But don’t worry, no prior +# knowledge of GANs is required, but it may require a first-timer to spend +# some time reasoning about what is actually happening under the hood. +# Also, for the sake of time it will help to have a GPU, or two. Lets +# start from the beginning. +# +# Generative Adversarial Networks +# ------------------------------- +# +# What is a GAN? +# ~~~~~~~~~~~~~~ +# +# GANs are a framework for teaching a DL model to capture the training +# data’s distribution so we can generate new data from that same +# distribution. GANs were invented by Ian Goodfellow in 2014 and first +# described in the paper `Generative Adversarial +# Nets `__. +# They are made of two distinct models, a *generator* and a +# *discriminator*. The job of the generator is to spawn ‘fake’ images that +# look like the training images. The job of the discriminator is to look +# at an image and output whether or not it is a real training image or a +# fake image from the generator. During training, the generator is +# constantly trying to outsmart the discriminator by generating better and +# better fakes, while the discriminator is working to become a better +# detective and correctly classify the real and fake images. The +# equilibrium of this game is when the generator is generating perfect +# fakes that look as if they came directly from the training data, and the +# discriminator is left to always guess at 50% confidence that the +# generator output is real or fake. +# +# Now, lets define some notation to be used throughout tutorial starting +# with the discriminator. Let :math:`x` be data representing an image. +# :math:`D(x)` is the discriminator network which outputs the (scalar) +# probability that :math:`x` came from training data rather than the +# generator. Here, since we are dealing with images the input to +# :math:`D(x)` is an image of CHW size 3x64x64. Intuitively, :math:`D(x)` +# should be HIGH when :math:`x` comes from training data and LOW when +# :math:`x` comes from the generator. :math:`D(x)` can also be thought of +# as a traditional binary classifier. +# +# For the generator’s notation, let :math:`z` be a latent space vector +# sampled from a standard normal distribution. :math:`G(z)` represents the +# generator function which maps the latent vector :math:`z` to data-space. +# The goal of :math:`G` is to estimate the distribution that the training +# data comes from (:math:`p_{data}`) so it can generate fake samples from +# that estimated distribution (:math:`p_g`). +# +# So, :math:`D(G(z))` is the probability (scalar) that the output of the +# generator :math:`G` is a real image. As described in `Goodfellow’s +# paper `__, +# :math:`D` and :math:`G` play a minimax game in which :math:`D` tries to +# maximize the probability it correctly classifies reals and fakes +# (:math:`logD(x)`), and :math:`G` tries to minimize the probability that +# :math:`D` will predict its outputs are fake (:math:`log(1-D(G(x)))`). +# From the paper, the GAN loss function is +# +# .. math:: \underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big] +# +# In theory, the solution to this minimax game is where +# :math:`p_g = p_{data}`, and the discriminator guesses randomly if the +# inputs are real or fake. However, the convergence theory of GANs is +# still being actively researched and in reality models do not always +# train to this point. +# +# What is a DCGAN? +# ~~~~~~~~~~~~~~~~ +# +# A DCGAN is a direct extension of the GAN described above, except that it +# explicitly uses convolutional and convolutional-transpose layers in the +# discriminator and generator, respectively. It was first described by +# Radford et. al. in the paper `Unsupervised Representation Learning With +# Deep Convolutional Generative Adversarial +# Networks `__. The discriminator +# is made up of strided +# `convolution `__ +# layers, `batch +# norm `__ +# layers, and +# `LeakyReLU `__ +# activations. The input is a 3x64x64 input image and the output is a +# scalar probability that the input is from the real data distribution. +# The generator is comprised of +# `convolutional-transpose `__ +# layers, batch norm layers, and +# `ReLU `__ activations. The +# input is a latent vector, :math:`z`, that is drawn from a standard +# normal distribution and the output is a 3x64x64 RGB image. The strided +# conv-transpose layers allow the latent vector to be transformed into a +# volume with the same shape as an image. In the paper, the authors also +# give some tips about how to setup the optimizers, how to calculate the +# loss functions, and how to initialize the model weights, all of which +# will be explained in the coming sections. +# + +from __future__ import print_function +import argparse +import os +import random +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim as optim +import torch.utils.data +import torchvision.datasets as dset +import torchvision.transforms as transforms +import torchvision.utils as vutils +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.animation as animation +# from IPython.display import HTML +import actnn + +# Set random seed for reproducibility +manualSeed = 999 +# manualSeed = random.randint(1, 10000) # use if you want new results +print("Random Seed: ", manualSeed) +random.seed(manualSeed) +torch.manual_seed(manualSeed) + +###################################################################### +# Inputs +# ------ +# +# Let’s define some inputs for the run: +# +# - **dataroot** - the path to the root of the dataset folder. We will +# talk more about the dataset in the next section +# - **workers** - the number of worker threads for loading the data with +# the DataLoader +# - **batch_size** - the batch size used in training. The DCGAN paper +# uses a batch size of 128 +# - **image_size** - the spatial size of the images used for training. +# This implementation defaults to 64x64. If another size is desired, +# the structures of D and G must be changed. See +# `here `__ for more +# details +# - **nc** - number of color channels in the input images. For color +# images this is 3 +# - **nz** - length of latent vector +# - **ngf** - relates to the depth of feature maps carried through the +# generator +# - **ndf** - sets the depth of feature maps propagated through the +# discriminator +# - **num_epochs** - number of training epochs to run. Training for +# longer will probably lead to better results but will also take much +# longer +# - **lr** - learning rate for training. As described in the DCGAN paper, +# this number should be 0.0002 +# - **beta1** - beta1 hyperparameter for Adam optimizers. As described in +# paper, this number should be 0.5 +# - **ngpu** - number of GPUs available. If this is 0, code will run in +# CPU mode. If this number is greater than 0 it will run on that number +# of GPUs +# + +# Root directory for dataset +dataroot = "/data/jianfei/celeba" + +# Number of workers for dataloader +workers = 2 + +# Batch size during training +batch_size = 128 + +# Spatial size of training images. All images will be resized to this +# size using a transformer. +image_size = 64 + +# Number of channels in the training images. For color images this is 3 +nc = 3 + +# Size of z latent vector (i.e. size of generator input) +nz = 100 + +# Size of feature maps in generator +ngf = 64 + +# Size of feature maps in discriminator +ndf = 64 + +# Number of training epochs +num_epochs = 5 + +# Learning rate for optimizers +lr = 0.0002 + +# Beta1 hyperparam for Adam optimizers +beta1 = 0.5 + +# Number of GPUs available. Use 0 for CPU mode. +ngpu = 1 + +###################################################################### +# Data +# ---- +# +# In this tutorial we will use the `Celeb-A Faces +# dataset `__ which can +# be downloaded at the linked site, or in `Google +# Drive `__. +# The dataset will download as a file named *img_align_celeba.zip*. Once +# downloaded, create a directory named *celeba* and extract the zip file +# into that directory. Then, set the *dataroot* input for this notebook to +# the *celeba* directory you just created. The resulting directory +# structure should be: +# +# :: +# +# /path/to/celeba +# -> img_align_celeba +# -> 188242.jpg +# -> 173822.jpg +# -> 284702.jpg +# -> 537394.jpg +# ... +# +# This is an important step because we will be using the ImageFolder +# dataset class, which requires there to be subdirectories in the +# dataset’s root folder. Now, we can create the dataset, create the +# dataloader, set the device to run on, and finally visualize some of the +# training data. +# + +# We can use an image folder dataset the way we have it setup. +# Create the dataset +dataset = dset.ImageFolder(root=dataroot, + transform=transforms.Compose([ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])) +# Create the dataloader +dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, + shuffle=True, num_workers=workers) + +# Decide which device we want to run on +device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") + +# Plot some training images +real_batch = next(iter(dataloader)) +plt.figure(figsize=(8, 8)) +plt.axis("off") +plt.title("Training Images") +plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(), (1, 2, 0))) + + +###################################################################### +# Implementation +# -------------- +# +# With our input parameters set and the dataset prepared, we can now get +# into the implementation. We will start with the weight initialization +# strategy, then talk about the generator, discriminator, loss functions, +# and training loop in detail. +# +# Weight Initialization +# ~~~~~~~~~~~~~~~~~~~~~ +# +# From the DCGAN paper, the authors specify that all model weights shall +# be randomly initialized from a Normal distribution with mean=0, +# stdev=0.02. The ``weights_init`` function takes an initialized model as +# input and reinitializes all convolutional, convolutional-transpose, and +# batch normalization layers to meet this criteria. This function is +# applied to the models immediately after initialization. +# + +# custom weights initialization called on netG and netD +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +###################################################################### +# Generator +# ~~~~~~~~~ +# +# The generator, :math:`G`, is designed to map the latent space vector +# (:math:`z`) to data-space. Since our data are images, converting +# :math:`z` to data-space means ultimately creating a RGB image with the +# same size as the training images (i.e. 3x64x64). In practice, this is +# accomplished through a series of strided two dimensional convolutional +# transpose layers, each paired with a 2d batch norm layer and a relu +# activation. The output of the generator is fed through a tanh function +# to return it to the input data range of :math:`[-1,1]`. It is worth +# noting the existence of the batch norm functions after the +# conv-transpose layers, as this is a critical contribution of the DCGAN +# paper. These layers help with the flow of gradients during training. An +# image of the generator from the DCGAN paper is shown below. +# +# .. figure:: /_static/img/dcgan_generator.png +# :alt: dcgan_generator +# +# Notice, the how the inputs we set in the input section (*nz*, *ngf*, and +# *nc*) influence the generator architecture in code. *nz* is the length +# of the z input vector, *ngf* relates to the size of the feature maps +# that are propagated through the generator, and *nc* is the number of +# channels in the output image (set to 3 for RGB images). Below is the +# code for the generator. +# + +# Generator Code + +class Generator(nn.Module): + def __init__(self, ngpu): + super(Generator, self).__init__() + self.ngpu = ngpu + self.main = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(ngf * 8), + nn.ReLU(True), + # state size. (ngf*8) x 4 x 4 + nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 4), + nn.ReLU(True), + # state size. (ngf*4) x 8 x 8 + nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 2), + nn.ReLU(True), + # state size. (ngf*2) x 16 x 16 + nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), + nn.ReLU(True), + # state size. (ngf) x 32 x 32 + nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), + nn.Tanh() + # state size. (nc) x 64 x 64 + ) + + def forward(self, input): + return self.main(input) + + +###################################################################### +# Now, we can instantiate the generator and apply the ``weights_init`` +# function. Check out the printed model to see how the generator object is +# structured. +# + +# Create the generator +netG = Generator(ngpu) +netG = actnn.QModule(netG) +netG = netG.to(device) + +# Handle multi-gpu if desired +if (device.type == 'cuda') and (ngpu > 1): + netG = nn.DataParallel(netG, list(range(ngpu))) + +# Apply the weights_init function to randomly initialize all weights +# to mean=0, stdev=0.2. +netG.apply(weights_init) + +# Print the model +print(netG) + + +###################################################################### +# Discriminator +# ~~~~~~~~~~~~~ +# +# As mentioned, the discriminator, :math:`D`, is a binary classification +# network that takes an image as input and outputs a scalar probability +# that the input image is real (as opposed to fake). Here, :math:`D` takes +# a 3x64x64 input image, processes it through a series of Conv2d, +# BatchNorm2d, and LeakyReLU layers, and outputs the final probability +# through a Sigmoid activation function. This architecture can be extended +# with more layers if necessary for the problem, but there is significance +# to the use of the strided convolution, BatchNorm, and LeakyReLUs. The +# DCGAN paper mentions it is a good practice to use strided convolution +# rather than pooling to downsample because it lets the network learn its +# own pooling function. Also batch norm and leaky relu functions promote +# healthy gradient flow which is critical for the learning process of both +# :math:`G` and :math:`D`. +# + +######################################################################### +# Discriminator Code + +class Discriminator(nn.Module): + def __init__(self, ngpu): + super(Discriminator, self).__init__() + self.ngpu = ngpu + self.main = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf) x 32 x 32 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*2) x 16 x 16 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*4) x 8 x 8 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 8), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*8) x 4 x 4 + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), + nn.Sigmoid() + ) + + def forward(self, input): + return self.main(input) + + +###################################################################### +# Now, as with the generator, we can create the discriminator, apply the +# ``weights_init`` function, and print the model’s structure. +# + +# Create the Discriminator +netD = Discriminator(ngpu) +netD = actnn.QModule(netD) +netD = netD.to(device) + +# Handle multi-gpu if desired +if (device.type == 'cuda') and (ngpu > 1): + netD = nn.DataParallel(netD, list(range(ngpu))) + +# Apply the weights_init function to randomly initialize all weights +# to mean=0, stdev=0.2. +netD.apply(weights_init) + +# Print the model +print(netD) + +###################################################################### +# Loss Functions and Optimizers +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# With :math:`D` and :math:`G` setup, we can specify how they learn +# through the loss functions and optimizers. We will use the Binary Cross +# Entropy loss +# (`BCELoss `__) +# function which is defined in PyTorch as: +# +# .. math:: \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] +# +# Notice how this function provides the calculation of both log components +# in the objective function (i.e. :math:`log(D(x))` and +# :math:`log(1-D(G(z)))`). We can specify what part of the BCE equation to +# use with the :math:`y` input. This is accomplished in the training loop +# which is coming up soon, but it is important to understand how we can +# choose which component we wish to calculate just by changing :math:`y` +# (i.e. GT labels). +# +# Next, we define our real label as 1 and the fake label as 0. These +# labels will be used when calculating the losses of :math:`D` and +# :math:`G`, and this is also the convention used in the original GAN +# paper. Finally, we set up two separate optimizers, one for :math:`D` and +# one for :math:`G`. As specified in the DCGAN paper, both are Adam +# optimizers with learning rate 0.0002 and Beta1 = 0.5. For keeping track +# of the generator’s learning progression, we will generate a fixed batch +# of latent vectors that are drawn from a Gaussian distribution +# (i.e. fixed_noise) . In the training loop, we will periodically input +# this fixed_noise into :math:`G`, and over the iterations we will see +# images form out of the noise. +# + +# Initialize BCELoss function +criterion = nn.BCELoss() + +# Create batch of latent vectors that we will use to visualize +# the progression of the generator +fixed_noise = torch.randn(64, nz, 1, 1, device=device) + +# Establish convention for real and fake labels during training +real_label = 1. +fake_label = 0. + +# Setup Adam optimizers for both G and D +optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) +optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) + +###################################################################### +# Training +# ~~~~~~~~ +# +# Finally, now that we have all of the parts of the GAN framework defined, +# we can train it. Be mindful that training GANs is somewhat of an art +# form, as incorrect hyperparameter settings lead to mode collapse with +# little explanation of what went wrong. Here, we will closely follow +# Algorithm 1 from Goodfellow’s paper, while abiding by some of the best +# practices shown in `ganhacks `__. +# Namely, we will “construct different mini-batches for real and fake” +# images, and also adjust G’s objective function to maximize +# :math:`logD(G(z))`. Training is split up into two main parts. Part 1 +# updates the Discriminator and Part 2 updates the Generator. +# +# **Part 1 - Train the Discriminator** +# +# Recall, the goal of training the discriminator is to maximize the +# probability of correctly classifying a given input as real or fake. In +# terms of Goodfellow, we wish to “update the discriminator by ascending +# its stochastic gradient”. Practically, we want to maximize +# :math:`log(D(x)) + log(1-D(G(z)))`. Due to the separate mini-batch +# suggestion from ganhacks, we will calculate this in two steps. First, we +# will construct a batch of real samples from the training set, forward +# pass through :math:`D`, calculate the loss (:math:`log(D(x))`), then +# calculate the gradients in a backward pass. Secondly, we will construct +# a batch of fake samples with the current generator, forward pass this +# batch through :math:`D`, calculate the loss (:math:`log(1-D(G(z)))`), +# and *accumulate* the gradients with a backward pass. Now, with the +# gradients accumulated from both the all-real and all-fake batches, we +# call a step of the Discriminator’s optimizer. +# +# **Part 2 - Train the Generator** +# +# As stated in the original paper, we want to train the Generator by +# minimizing :math:`log(1-D(G(z)))` in an effort to generate better fakes. +# As mentioned, this was shown by Goodfellow to not provide sufficient +# gradients, especially early in the learning process. As a fix, we +# instead wish to maximize :math:`log(D(G(z)))`. In the code we accomplish +# this by: classifying the Generator output from Part 1 with the +# Discriminator, computing G’s loss *using real labels as GT*, computing +# G’s gradients in a backward pass, and finally updating G’s parameters +# with an optimizer step. It may seem counter-intuitive to use the real +# labels as GT labels for the loss function, but this allows us to use the +# :math:`log(x)` part of the BCELoss (rather than the :math:`log(1-x)` +# part) which is exactly what we want. +# +# Finally, we will do some statistic reporting and at the end of each +# epoch we will push our fixed_noise batch through the generator to +# visually track the progress of G’s training. The training statistics +# reported are: +# +# - **Loss_D** - discriminator loss calculated as the sum of losses for +# the all real and all fake batches (:math:`log(D(x)) + log(D(G(z)))`). +# - **Loss_G** - generator loss calculated as :math:`log(D(G(z)))` +# - **D(x)** - the average output (across the batch) of the discriminator +# for the all real batch. This should start close to 1 then +# theoretically converge to 0.5 when G gets better. Think about why +# this is. +# - **D(G(z))** - average discriminator outputs for the all fake batch. +# The first number is before D is updated and the second number is +# after D is updated. These numbers should start near 0 and converge to +# 0.5 as G gets better. Think about why this is. +# +# **Note:** This step might take a while, depending on how many epochs you +# run and if you removed some data from the dataset. +# + +# Training Loop + +# Lists to keep track of progress +img_list = [] +G_losses = [] +D_losses = [] +iters = 0 + +print("Starting Training Loop...") +# For each epoch +for epoch in range(num_epochs): + # For each batch in the dataloader + for i, data in enumerate(dataloader, 0): + + ############################ + # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) + ########################### + ## Train with all-real batch + netD.zero_grad() + # Format batch + real_cpu = data[0].to(device) + b_size = real_cpu.size(0) + label = torch.full((b_size,), real_label, dtype=torch.float, device=device) + # Forward pass real batch through D + output = netD(real_cpu).view(-1) + # Calculate loss on all-real batch + errD_real = criterion(output, label) + # Calculate gradients for D in backward pass + errD_real.backward() + D_x = output.mean().item() + + ## Train with all-fake batch + # Generate batch of latent vectors + noise = torch.randn(b_size, nz, 1, 1, device=device) + # Generate fake image batch with G + fake = netG(noise) + label.fill_(fake_label) + # Classify all fake batch with D + output = netD(fake.detach()).view(-1) + # Calculate D's loss on the all-fake batch + errD_fake = criterion(output, label) + # Calculate the gradients for this batch + errD_fake.backward() + D_G_z1 = output.mean().item() + # Add the gradients from the all-real and all-fake batches + errD = errD_real + errD_fake + # Update D + optimizerD.step() + + ############################ + # (2) Update G network: maximize log(D(G(z))) + ########################### + netG.zero_grad() + label.fill_(real_label) # fake labels are real for generator cost + # Since we just updated D, perform another forward pass of all-fake batch through D + output = netD(fake).view(-1) + # Calculate G's loss based on this output + errG = criterion(output, label) + # Calculate gradients for G + errG.backward() + D_G_z2 = output.mean().item() + # Update G + optimizerG.step() + + # Output training stats + if i % 50 == 0: + print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' + % (epoch, num_epochs, i, len(dataloader), + errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) + + # Save Losses for plotting later + G_losses.append(errG.item()) + D_losses.append(errD.item()) + + # Check how the generator is doing by saving G's output on fixed_noise + if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)): + with torch.no_grad(): + fake = netG(fixed_noise).detach().cpu() + img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) + + iters += 1 + +###################################################################### +# Results +# ------- +# +# Finally, lets check out how we did. Here, we will look at three +# different results. First, we will see how D and G’s losses changed +# during training. Second, we will visualize G’s output on the fixed_noise +# batch for every epoch. And third, we will look at a batch of real data +# next to a batch of fake data from G. +# +# **Loss versus training iteration** +# +# Below is a plot of D & G’s losses versus training iterations. +# + +plt.figure(figsize=(10, 5)) +plt.title("Generator and Discriminator Loss During Training") +plt.plot(G_losses, label="G") +plt.plot(D_losses, label="D") +plt.xlabel("iterations") +plt.ylabel("Loss") +plt.legend() +plt.show() +plt.savefig('loss.png') + +###################################################################### +# **Visualization of G’s progression** +# +# Remember how we saved the generator’s output on the fixed_noise batch +# after every epoch of training. Now, we can visualize the training +# progression of G with an animation. Press the play button to start the +# animation. +# + +# %%capture +# fig = plt.figure(figsize=(8, 8)) +# plt.axis("off") +# ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list] +# ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True) + +# HTML(ani.to_jshtml()) + +###################################################################### +# **Real Images vs. Fake Images** +# +# Finally, lets take a look at some real images and fake images side by +# side. +# + +# Grab a batch of real images from the dataloader +real_batch = next(iter(dataloader)) + +# Plot the real images +plt.figure(figsize=(15, 15)) +plt.subplot(1, 2, 1) +plt.axis("off") +plt.title("Real Images") +plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0))) +plt.savefig('real_images.png') + +# Plot the fake images from the last epoch +plt.subplot(1, 2, 2) +plt.axis("off") +plt.title("Fake Images") +plt.imshow(np.transpose(img_list[-1], (1, 2, 0))) +plt.show() +plt.savefig('fake_images.png') + +###################################################################### +# Where to Go Next +# ---------------- +# +# We have reached the end of our journey, but there are several places you +# could go from here. You could: +# +# - Train for longer to see how good the results get +# - Modify this model to take a different dataset and possibly change the +# size of the images and the model architecture +# - Check out some other cool GAN projects +# `here `__ +# - Create GANs that generate +# `music `__ +# diff --git a/tests/test_act_quantized_ops.py b/tests/test_act_quantized_ops.py new file mode 100644 index 0000000..fca9a3b --- /dev/null +++ b/tests/test_act_quantized_ops.py @@ -0,0 +1,590 @@ +"""Test the activation quantized ops""" + +import math + +import numpy as np +import torch +from torch import nn, autograd +from torch.nn import init, functional as F +from torch.autograd.function import Function + +from timeit_v2 import py_benchmark + +from actnn import QScheme, QBNScheme, config, get_memory_usage, compute_tensor_bytes +from actnn.ops import ext_backward_func, ext_quantization +from actnn.ops import conv2d as quantized_conv2d, batch_norm as quantized_batch_norm, \ + adaptive_avg_pool2d as quantized_adaptive_avg_pool2d + + +def test_relu_correctness(): + print("========== ReLU Correctness Test ==========") + + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(128, 56, 56, 31).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + + output = func(data) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad]] + + output_ref, grad_data_ref = test_implementation(F.relu) + output_us, grad_data_us = test_implementation(ext_quantization.act_quantized_relu) + + np.testing.assert_allclose(output_ref, output_us) + np.testing.assert_allclose(grad_data_ref, grad_data_us) + + +def test_relu_memory(): + print("========== ReLU Memory Test ==========") + + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(128, 56, 56, 32).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + + before = get_memory_usage() + + for i in range(10): + data = func(data) + + after = get_memory_usage() + + return after - before + + usage_ref = test_implementation(F.relu) + usage_us = test_implementation(ext_quantization.act_quantized_relu) + + print("Exact. Usage: %.2f MB" % (usage_ref / 2 ** 20)) + print("Quantized. Usage: %.2f MB" % (usage_us / 2 ** 20)) + + +def test_relu_speed(): + print("========== ReLU Speed Test ==========") + + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + + data_np = np.random.randn(256, 56, 56, 32).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + + stmt = "func(data)" + t_forward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + + output = func(data) + head = torch.ones_like(output) + stmt = "output.backward(head, retain_graph=True)" + t_backward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + + return t_forward, t_backward + + forward_ref, backward_ref = test_implementation(F.relu) + forward_us, backward_us = test_implementation(ext_quantization.act_quantized_relu) + + print("Exact. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_ref * 1e3, backward_ref * 1e3, (forward_ref + backward_ref) * 1e3)) + print("Quantized. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_us * 1e3, backward_us * 1e3, (forward_us + backward_us) * 1e3)) + + +def test_adaptive_avg_pool2d_correctness(): + """Test the correctness of computation results""" + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 4, 28, 28, 256, 256, 3, 1, 1, 1, 1 + data_np = np.random.randn(N, CI, H, W).astype('float32') + head_np = np.random.randn(N, CI, 1, 1).astype('float32') + output_size = 1, 1 + + def test_implementation(func): + torch.manual_seed(0) + data = torch.tensor(data_np).to("cuda").requires_grad_() + head = torch.tensor(head_np).to("cuda") + + output = func(data, output_size) + output.backward(head) + + return [x.detach().cpu().numpy() for x in [output, data.grad]] + + output_ref, grad_data_ref = test_implementation(F.adaptive_avg_pool2d) + output_us, grad_data_us = test_implementation(quantized_adaptive_avg_pool2d.apply) + + atol = 1e-4 + rtol = 1e-4 + print("========== AdaptiveAvgPool2d Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol, rtol=rtol) + + +def test_adaptive_avg_pool2d_memory(): + """Test the memory usage""" + # arguments and test data + N, H, W, CI = 1024, 4, 4, 1024 + data_np = np.random.randn(N, CI, H, W).astype('float32') + output_size = (1, 1) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + output = func(data, output_size) + for i in range(10): + output = func(output, output_size) + + return get_memory_usage() - compute_tensor_bytes([data, output]) + + usage_ref = test_implementation(F.adaptive_avg_pool2d) + usage_us = test_implementation(quantized_adaptive_avg_pool2d.apply) + + print("========== AdaptiveAvgPool2d Memory Test ==========") + print("Exact. Usage: %.3f MB" % (usage_ref / 2 ** 20)) + print("Quantized. Usage: %.2f MB" % (usage_us / 2 ** 20)) + + +def test_max_pool2d_correctness(): + """Test the correctness of computation results""" + # arguments and test data + N, H, W, CI, kernel_size, stride, padding, dilation = 4, 28, 28, 8, 3, 2, 1, 1 + ceil_mode, return_indices = False, False + + print("========== MaxPool2d Correctness Test ==========") + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(N, CI, H, W).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + + output = func(data, (kernel_size, kernel_size), (stride, stride), (padding, padding), + (dilation, dilation), ceil_mode, return_indices) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad]] + + output_ref, grad_data_ref = test_implementation(F.max_pool2d) + output_us, grad_data_us = test_implementation(ext_quantization.act_quantized_max_pool2d) + + atol = 1e-4 + rtol = 1e-4 + np.testing.assert_allclose(output_ref, output_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol, rtol=rtol) + + +def test_max_pool2d_memory(): + """Test the memory usage""" + # arguments and test data + N, H, W, CI, kernel_size, stride, padding, dilation = 128, 28, 28, 8, 3, 2, 1, 1 + ceil_mode, return_indices = False, False + + print("========== MaxPool2d Memory Test ==========") + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + + data_np = np.random.randn(N, CI, H, W).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + output = func(data, (kernel_size, kernel_size), (stride, stride), (padding, padding), + (dilation, dilation), ceil_mode, return_indices) + + return get_memory_usage() - compute_tensor_bytes([output, data]) + + usage_ref = test_implementation(F.max_pool2d) + usage_us = test_implementation(ext_quantization.act_quantized_max_pool2d) + print("Exact. Usage: %.3f MB" % (usage_ref / 2 ** 20)) + print("Quantized. Usage: %.3f MB" % (usage_us / 2 ** 20)) + + +def test_max_pool2d_speed(): + """Test the correctness of computation results""" + # arguments and test data + N, H, W, CI, kernel_size, stride, padding, dilation = 128, 28, 28, 128, 3, 2, 1, 1 + ceil_mode, return_indices = False, False + + + print("========== MaxPool2d Speed Test ==========") + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(N, CI, H, W).astype(dtype) + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + + stmt = "func(data, (kernel_size, kernel_size), (stride, stride), (padding, padding),"\ + "(dilation, dilation), ceil_mode, return_indices)" + t_forward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + + output = func(data, (kernel_size, kernel_size), (stride, stride), (padding, padding), + (dilation, dilation), ceil_mode, return_indices) + head = torch.ones_like(output) + + stmt = "output.backward(head, retain_graph=True)" + t_backward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + return t_forward, t_backward + + forward_ref, backward_ref = test_implementation(F.max_pool2d) + forward_us, backward_us = test_implementation(ext_quantization.act_quantized_max_pool2d) + + print("Exact. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_ref * 1e3, backward_ref * 1e3, (forward_ref + backward_ref) * 1e3)) + print("Quantized. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_us * 1e3, backward_us * 1e3, (forward_us + backward_us) * 1e3)) + + +def test_upsample_memory(): + """Test the memory usage""" + # arguments and test data + N, H, W, CI = 128, 28, 28, 8 + size, scale_factor, mode, align_corners = None, 2, 'bilinear', False + data_np = np.random.randn(N, CI, H, W).astype('float32') + + def test_implementation(func): + data = torch.tensor(data_np).to("cuda").requires_grad_() + output = func(data, size, scale_factor, mode, align_corners) + output = func(output, size, scale_factor, mode, align_corners) + output = func(output, size, scale_factor, mode, align_corners) + + return get_memory_usage() - compute_tensor_bytes([output, data]) + + usage_ref = test_implementation(F.interpolate) + print("========== Upsample Memory Test ==========") + print("Exact. Usage: %.3f MB" % (usage_ref / 2 ** 20)) + + +def test_bn_correctness(): + # arguments and test data + N, H, W, CI = 16, 28, 28, 256 + data_np = np.random.randn(N, CI, H, W).astype('float32') * 0.01 + running_mean_np = np.random.randn(CI).astype('float32') + running_var_np = np.random.randn(CI).astype('float32') + bn_weight_np = np.random.randn(CI).astype('float32') + bn_bias_np = np.random.randn(CI).astype('float32') + training = False + + bn_scheme = QBNScheme() + config.compress_activation = False + + def test_implementation(func): + torch.manual_seed(0) + data = torch.tensor(data_np).to("cuda").requires_grad_() + running_mean = torch.tensor(running_mean_np).to("cuda") + running_var = torch.tensor(running_var_np).to("cuda") + bn_weight = torch.tensor(bn_weight_np).to("cuda").requires_grad_() + bn_bias = torch.tensor(bn_bias_np).to("cuda").requires_grad_() + + if func == F.batch_norm: + output = func(data, running_mean, running_var, bn_weight, bn_bias, training, 0.1, 1e-5) + else: + output = func(data, running_mean, running_var, bn_weight, bn_bias, training, 0.1, 1e-5, bn_scheme) + + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, bn_weight.grad, bn_bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.batch_norm) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(quantized_batch_norm.apply) + + atol = 1e-3 + rtol = 1e-3 + print("========== BN Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol, rtol=rtol) + + +def test_conv2d_correctness(): + """Test the correctness of computation results""" + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 4, 28, 28, 256, 256, 3, 1, 1, 1, 1 + + print("========== Conv2d Correctness Test ==========") + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(N, CI, H, W).astype(dtype) + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size).astype(dtype) + bias_np = np.random.randn(CO).astype(dtype) + + def test_implementation(func, scheme): + torch.manual_seed(0) + data = torch.tensor(data_np).to("cuda").requires_grad_() + weight = torch.tensor(weight_np).to("cuda").requires_grad_() + bias = torch.tensor(bias_np).to("cuda").requires_grad_() + + output = func(data, weight, bias, stride, padding, dilation, groups, scheme) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + config.activation_compression_bits = [16] + config.initial_bits = 16 + config.perlayer = False + config.use_gradient = False + scheme = QScheme(None) + + config.simulate = True + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(quantized_conv2d.apply, scheme) + config.simulate = False + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(quantized_conv2d.apply, scheme) + + atol = 1e-2 + rtol = 1e-2 + assert output_ref.dtype == output_us.dtype + np.testing.assert_allclose(output_ref, output_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol, rtol=rtol) + + +def test_conv2d_correctness_per_group_only(): + """Test the correctness of computation results + + NOTE: This test will fail on large shapes or low bits. + To make this test pass, we should disable stochastic noise. + """ + + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 2, 16, 16, 4, 4, 1, 1, 1, 1, 1 + + print("========== Conv2d Correctness Test (per group only) ==========") + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + + data_np = np.random.randn(N, CI, H, W).astype(dtype) + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size).astype(dtype) + bias_np = np.random.randn(CO).astype(dtype) + + def test_implementation(func, scheme): + torch.manual_seed(0) + data = torch.tensor(data_np).to("cuda").requires_grad_() + weight = torch.tensor(weight_np).to("cuda").requires_grad_() + bias = torch.tensor(bias_np).to("cuda").requires_grad_() + + output = func(data, weight, bias, stride, padding, dilation, groups, scheme) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + config.activation_compression_bits = [8] + config.perlayer = False + config.use_gradient = False + + config.simulate = True + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(quantized_conv2d.apply, None) + config.simulate = False + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(quantized_conv2d.apply, None) + + atol = 1e-1 + rtol = 1e-1 + assert output_ref.dtype == output_us.dtype + np.testing.assert_allclose(output_ref, output_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol, rtol=rtol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol, rtol=rtol) + + +def test_conv2d_speed(): + """Test the speed of convolution layer""" + + + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 128, 28, 28, 256, 256, 3, 1, 1, 1, 1 + + print("========== Conv2d Speed Test ==========") + + for dtype in ['float32', 'float16']: + print(f"test {dtype}...") + data_np = np.random.randn(N, CI, H, W).astype(dtype) + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size).astype(dtype) + bias_np = np.random.randn(CO).astype(dtype) + + scheme = QScheme(None) + + def test_implementation(func, scheme): + data = torch.tensor(data_np).to("cuda").requires_grad_() + weight = torch.tensor(weight_np).to("cuda").requires_grad_() + bias = torch.tensor(bias_np).to("cuda").requires_grad_() + + if func == quantized_conv2d.apply: + output = func(data, weight, bias, stride, padding, dilation, groups, scheme) + stmt = "func(data, weight, bias, stride, padding, dilation, groups, scheme)" + else: + output = func(data, weight, bias, stride, padding, dilation, groups) + stmt = "func(data, weight, bias, stride, padding, dilation, groups)" + + t_forward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + + head = torch.ones_like(output) + stmt = "output.backward(head, retain_graph=True)" + t_backward = py_benchmark(stmt, {**globals(), **locals()}, + setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") + + return t_forward, t_backward + + config.activation_compression_bits = [16] + config.initial_bits = 16 + config.perlayer = False + config.use_gradient = False + config.simulate = False + scheme = QScheme(None) + + forward_ref, backward_ref = test_implementation(F.conv2d, None) + forward_us, backward_us = test_implementation(quantized_conv2d.apply, scheme) + + print("Exact. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_ref * 1e3, backward_ref * 1e3, (forward_ref + backward_ref) * 1e3)) + print("Quantized. forward: %.2f ms\tbackward: %.2f ms\tsum: %.2f ms" % + (forward_us * 1e3, backward_us * 1e3, (forward_us + backward_us) * 1e3)) + + +def test_conv2d_memory_analytical(): + """Compute the memory of activation analytically""" + + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 256, 28, 28, 256, 256, 3, 1, 1, 1, 1 + data_np = np.random.randn(N, CI, H, W).astype('float32') + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size).astype('float32') + bias_np = np.random.randn(CO).astype('float32') + running_mean = np.zeros((CO,), dtype='float32') + running_var = np.ones((CO,), dtype='float32') + bn_weight = np.random.randn(CO).astype('float32') + bn_bias = np.random.randn(CO).astype('float32') + + scheme = QScheme(num_locations=kernel_size**2) + bn_scheme = QBNScheme() + + def test_implementation(conv_func, relu_func, bn_func, n_layers=10): + data = torch.tensor(data_np).to("cuda") + + # allocate input and weights + data = torch.tensor(data_np).to("cuda").requires_grad_(False) + weights = [] + running_means = [] + running_vars = [] + bn_weights = [] + bn_biass = [] + for i in range(n_layers): + weights.append(torch.tensor(weight_np).to("cuda").requires_grad_()) + running_means.append(torch.tensor(running_mean).to("cuda")) + running_vars.append(torch.tensor(running_var).to("cuda")) + bn_weights.append(torch.tensor(bn_weight).to("cuda").requires_grad_()) + bn_biass.append(torch.tensor(bn_bias).to("cuda").requires_grad_()) + + before_size = get_memory_usage(False) + + # forward n convolution layers + output = data + for i in range(n_layers): + if conv_func == quantized_conv2d.apply: + output = conv_func(output, weights[i], None, stride, padding, dilation, groups, scheme) + output = bn_func(output, running_means[i], running_vars[i], bn_weights[i], bn_biass[i], True, 0.1, 1e-5, bn_scheme) + else: + output = conv_func(output, weights[i], None, stride, padding, dilation, groups) + output = bn_func(output, running_means[i], running_vars[i], bn_weights[i], bn_biass[i], True, 0.1, 1e-5) + output = relu_func(output) + + output = output.sum() + + after_size = get_memory_usage(False) + output_size = compute_tensor_bytes(output) + + return after_size / 1024**2, (after_size - before_size - output_size) / 1024**2 + + total_size_ref, act_size_ref = test_implementation(F.conv2d, lambda x: F.relu(x, inplace=True), F.batch_norm) + config.simulate = True + total_size_sim, act_size_sim = test_implementation(quantized_conv2d.apply, + ext_quantization.act_quantized_relu, quantized_batch_norm.apply) + config.simulate = False + total_size_us, act_size_us = test_implementation(quantized_conv2d.apply, + ext_quantization.act_quantized_relu, quantized_batch_norm.apply) + + print("========== Conv2d Activation Memory Test (bits = %d) ==========" % (config.activation_compression_bits)) + print("Exact. Total: %7.2f MB\tAct: %7.2f MB" % (total_size_ref, act_size_ref)) + print("Simulation. Total: %7.2f MB\tAct: %7.2f MB" % (total_size_sim, act_size_sim)) + print("Quantized. Total: %7.2f MB\tAct: %7.2f MB" % (total_size_us, act_size_us)) + + +def test_conv2d_memory_max_batch_size(): + """Find the maximum batch size by gradually increasing the batch size until hitting Out-of-memory error""" + + for device in ["cuda"]: + def test_implementation(func, n_layers, batch_sizes): + def run_batch_size(batch_size): + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = batch_size, 28, 28, 256, 256, 3, 1, 1, 1, 1 + data_np = np.random.uniform(size=(N, CI, H, W)).astype('float32') + weight_np = np.random.uniform(size=(CO, CI // groups, kernel_size, kernel_size)).astype('float32') + bias_np = np.random.uniform(size=(CO,)).astype('float32') + + # allocate input and weights + data = torch.tensor(data_np).to("cuda").requires_grad_(False) + weights = [] + for i in range(n_layers): + weight = torch.tensor(weight_np).to("cuda").requires_grad_() + weights.append(weight) + + before_size = get_memory_usage(False) + + # forward n convolution layers + output = data + for i in range(n_layers): + output = func(output, weights[i], None, stride, padding, dilation, groups) + output = output.sum() + + after_size = get_memory_usage(False) + output_size = compute_tensor_bytes(output) + + return after_size / 1024**2, (after_size - before_size - output_size) / 1024**2 + + # try gradually increased batch sizes + try: + for i, batch_size in enumerate(batch_sizes): + total_size_ref, act_size_ref = run_batch_size(batch_size) + print("batch_size: %4d\t" % batch_size, end="") + print("total_memory: %7.2f MB\tact_memory: %7.2f MB" % (total_size_ref, act_size_ref)) + except RuntimeError: + pass + finally: + print("Maximum batch size: %d" % (batch_sizes[i-1])) + + print("========== Conv2d Batch Size Test ==========") + print("---> Exact") + test_implementation(F.conv2d, n_layers=50, batch_sizes=[100, 200, 250, 300, 350, 400, 450, 500, 1000]) + print("---> Quantized") + test_implementation(act_quantized_conv2d.apply, n_layers=50, batch_sizes=[100, 200, 250, 500, 1000, 2200, 2300, 2400, 3000, 4000]) + + +if __name__ == "__main__": + test_relu_correctness() + test_relu_memory() + test_relu_speed() + + #test_adaptive_avg_pool2d_correctness() + #test_adaptive_avg_pool2d_memory() + + #test_max_pool2d_correctness() + #test_max_pool2d_memory() + #test_max_pool2d_speed() + + #test_upsample_memory() + + #test_bn_correctness() + + test_conv2d_correctness() + #test_conv2d_correctness_per_group_only() + + #test_conv2d_speed() + + #config.activation_compression_bits = 2 + #test_conv2d_memory_analytical() + + #config.activation_compression_bits = 2 + #test_conv2d_memory_max_batch_size() diff --git a/tests/test_backward_func.py b/tests/test_backward_func.py new file mode 100644 index 0000000..d5a5bd9 --- /dev/null +++ b/tests/test_backward_func.py @@ -0,0 +1,396 @@ +"""Test calling c++ backward func from Python""" +import math + +import numpy as np +import torch +from torch import nn, autograd +from torch.nn import init, functional as F +from torch.nn.modules.utils import _single, _pair, _triple + +from actnn.cpp_extension.backward_func import (cudnn_convolution_backward, + cudnn_convolution_transpose_backward) + + +class conv1d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + stride = (1, stride) + padding = (0, padding) + dilation = (1, dilation) + + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, dilation, groups) + + input = input.unsqueeze(2) + weight = weight.unsqueeze(2) + out = F.conv2d(input, weight, bias, stride, padding, dilation, groups) + return out.squeeze(2) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + + input = input.unsqueeze(2) + weight = weight.unsqueeze(2) + grad_output = grad_output.unsqueeze(2) + + stride, padding, dilation, groups = ctx.other_args + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + + grad_input, grad_weight = cudnn_convolution_backward( + input, grad_output, weight, padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + else: + grad_bias = None + + grad_input = grad_input.squeeze(2) + grad_weight = grad_weight.squeeze(2) + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +class conv2d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, dilation, groups) + return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride, padding, dilation, groups = ctx.other_args + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + + grad_input, grad_weight = cudnn_convolution_backward( + input, grad_output, weight, padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + else: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +class conv3d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, dilation, groups) + return F.conv3d(input, weight, bias, stride, padding, dilation, groups) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride, padding, dilation, groups = ctx.other_args + padding = _triple(padding) + stride = _triple(stride) + dilation = _triple(dilation) + + grad_input, grad_weight = cudnn_convolution_backward( + input, grad_output, weight, padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3, 4]) + else: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +class conv_transpose1d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): + stride = (1, stride) + padding = (0, padding) + dilation = (1, dilation) + + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, output_padding, dilation, groups) + + input = input.unsqueeze(2) + weight = weight.unsqueeze(2) + out = F.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation) + return out.squeeze(2) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride, padding, output_padding, dilation, groups = ctx.other_args + + input = input.unsqueeze(2) + weight = weight.unsqueeze(2) + grad_output = grad_output.unsqueeze(2) + + padding = _pair(padding) + output_padding = _pair(output_padding) + stride = _pair(stride) + dilation = _pair(dilation) + + grad_input, grad_weight = cudnn_convolution_transpose_backward( + input, grad_output, weight, padding, output_padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + else: + grad_bias = None + + grad_input = grad_input.squeeze(2) + grad_weight = grad_weight.squeeze(2) + + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + + +class conv_transpose2d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, output_padding, dilation, groups) + return F.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride, padding, output_padding, dilation, groups = ctx.other_args + padding = _pair(padding) + output_padding = _pair(output_padding) + stride = _pair(stride) + dilation = _pair(dilation) + + grad_input, grad_weight = cudnn_convolution_transpose_backward( + input, grad_output, weight, padding, output_padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + else: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + +class conv_transpose3d_explicit_backward(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): + ctx.save_for_backward(input, weight, bias) + ctx.other_args = (stride, padding, output_padding, dilation, groups) + return F.conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride, padding, output_padding, dilation, groups = ctx.other_args + padding = _triple(padding) + output_padding = _triple(output_padding) + stride = _triple(stride) + dilation = _triple(dilation) + + grad_input, grad_weight = cudnn_convolution_transpose_backward( + input, grad_output, weight, padding, output_padding, stride, dilation, groups, + False, False, False, [ctx.needs_input_grad[0], ctx.needs_input_grad[1]]) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3, 4]) + else: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + +def test_conv1d_correctness(): + # arguments and test data + N, H, CI, CO, kernel_size, stride, padding, dilation, groups = 4, 28, 64, 128, 3, 1, 1, 1, 1 + data_np = np.random.randn(N, CI, H).astype('float32') + weight_np = np.random.randn(CO, CI // groups, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv1d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv1d_explicit_backward.apply) + + atol = 1e-5 + print("========== Conv1d Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +def test_conv2d_correctness(): + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 4, 28, 28, 64, 128, 3, 1, 1, 1, 1 + data_np = np.random.randn(N, CI, H, W).astype('float32') + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv2d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv2d_explicit_backward.apply) + + atol = 1e-5 + print("========== Conv2d Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +def test_conv3d_correctness(): + # arguments and test data + N, D, H, W, CI, CO, kernel_size, stride, padding, dilation, groups = 4, 16, 28, 28, 64, 128, 3, 1, 1, 1, 1 + data_np = np.random.randn(N, CI, D, H, W).astype('float32') + weight_np = np.random.randn(CO, CI // groups, kernel_size, kernel_size, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv3d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv3d_explicit_backward.apply) + + atol = 5e-4 + print("========== Conv3d Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +def test_conv1d_transpose_correctness(): + # arguments and test data + N, H, CI, CO, kernel_size, stride, padding, output_padding, dilation, groups =\ + 4, 28, 64, 128, 3, 1, 1, 0, 1, 1 + data_np = np.random.randn(N, CI, H).astype('float32') + weight_np = np.random.randn(CI, CO // groups, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, output_padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv_transpose1d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv_transpose1d_explicit_backward.apply) + + atol = 2e-4 + print("========== Conv1dTranspose Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +def test_conv2d_transpose_correctness(): + # arguments and test data + N, H, W, CI, CO, kernel_size, stride, padding, output_padding, dilation, groups =\ + 4, 28, 28, 64, 128, 3, 1, 1, 0, 1, 1 + data_np = np.random.randn(N, CI, H, W).astype('float32') + weight_np = np.random.randn(CI, CO // groups, kernel_size, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, output_padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv_transpose2d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv_transpose2d_explicit_backward.apply) + + atol = 2e-4 + print("========== Conv2dTranspose Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +def test_conv3d_transpose_correctness(): + # arguments and test data + N, D, H, W, CI, CO, kernel_size, stride, padding, output_padding, dilation, groups =\ + 4, 16, 28, 28, 64, 128, 3, 1, 1, 0, 1, 1 + data_np = np.random.randn(N, CI, D, H, W).astype('float32') + weight_np = np.random.randn(CI, CO // groups, kernel_size, kernel_size, kernel_size).astype('float32') + bias_np = np.random.rand(CO).astype('float32') + + + def test_implementation(func): + data = torch.tensor(data_np).to('cuda').requires_grad_() + weight = torch.tensor(weight_np).to('cuda').requires_grad_() + bias = torch.tensor(bias_np).to('cuda').requires_grad_() + + output = func(data, weight, bias, stride, padding, output_padding, dilation, groups) + output.backward(torch.ones_like(output)) + + return [x.detach().cpu().numpy() for x in [output, data.grad, weight.grad, bias.grad]] + + output_ref, grad_data_ref, grad_weight_ref, grad_bias_ref = test_implementation(F.conv_transpose3d) + output_us, grad_data_us, grad_weight_us, grad_bias_us = test_implementation(conv_transpose3d_explicit_backward.apply) + + atol = 2e-4 + print("========== Conv3dTranspose Correctness Test ==========") + np.testing.assert_allclose(output_ref, output_us, atol=atol) + np.testing.assert_allclose(grad_data_ref, grad_data_us, atol=atol) + np.testing.assert_allclose(grad_weight_ref, grad_weight_us, atol=atol) + np.testing.assert_allclose(grad_bias_ref, grad_bias_us, atol=atol) + + +if __name__ == "__main__": + test_conv1d_correctness() + test_conv2d_correctness() + test_conv3d_correctness() + + test_conv1d_transpose_correctness() + test_conv2d_transpose_correctness() + test_conv3d_transpose_correctness() + diff --git a/tests/test_conv.py b/tests/test_conv.py new file mode 100644 index 0000000..7670425 --- /dev/null +++ b/tests/test_conv.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from actnn import config, QConv1d, QConv2d, QConv3d, QConvTranspose2d, QConvTranspose3d + +torch.manual_seed(0) + + +def test(layer, qlayer, x, y): + with torch.no_grad(): + qlayer.weight.copy_(layer.weight) + qlayer.bias.copy_(layer.bias) + + # print(qlayer.weight.shape) + # print(x.shape, y) + ce = nn.CrossEntropyLoss().cuda() + + def get_grad(model): + pred = model(x) + pred = F.relu(pred) + pred = pred.view(pred.shape[0], pred.shape[1], -1).mean(2) + loss = ce(pred, y) + model.weight.grad = None + model.bias.grad = None + loss.backward() + return model.weight.grad.cpu().numpy() + + true_grad = get_grad(layer) + grads = [] + for i in range(10): + grads.append(get_grad(qlayer)) + + grads = np.stack(grads, 0) + grad_mean = grads.mean(0) + grad_std = grads.std(0) + + bias = np.linalg.norm(grad_mean - true_grad) + print('Grad = {}, Bias = {}, Std = {}'.format(np.linalg.norm(true_grad), bias, np.linalg.norm(grad_std))) + + +config.activation_compression_bits = [2] +# config.perlayer = False +# config.initial_bits = 2 +# config.pergroup = False +in_channels = 100 +out_channels = 4 +kernel_size = 3 +stride = 2 +groups = 2 + +# layer = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# qlayer = QConv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# x = torch.rand([10, in_channels, 2000]).cuda() +# y = torch.empty(10, dtype=torch.long).random_(4).cuda() +# test(layer, qlayer, x, y) +# +# layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# qlayer = QConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# x = torch.rand([10, in_channels, 160, 200]).cuda() +# y = torch.empty(10, dtype=torch.long).random_(4).cuda() +# test(layer, qlayer, x, y) +# +# layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# qlayer = QConv2d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +# test(layer, qlayer, x, y) + +layer = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +qlayer = QConv3d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +x = torch.rand([10, in_channels, 10, 12, 8]).cuda() +y = torch.empty(10, dtype=torch.long).random_(4).cuda() +test(layer, qlayer, x, y) + +layer = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +qlayer = QConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, groups=groups).cuda() +test(layer, qlayer, x, y) diff --git a/tests/test_static.py b/tests/test_static.py new file mode 100644 index 0000000..f6371db --- /dev/null +++ b/tests/test_static.py @@ -0,0 +1 @@ +print('%dd%d' % (1, 2)) \ No newline at end of file diff --git a/tests/timeit_v2.py b/tests/timeit_v2.py new file mode 100644 index 0000000..e39bcb6 --- /dev/null +++ b/tests/timeit_v2.py @@ -0,0 +1,257 @@ +# timeit_v2.py: Copied from the default library with the following two modifiations +# 1. Add 'finish' argument to timeit for calling cuda synchronization. +# 2. Add accurate measurment utility function py_benchmark + +"""Tool for measuring execution time of small code snippets. + +This module avoids a number of common traps for measuring execution +times. See also Tim Peters' introduction to the Algorithms chapter in +the Python Cookbook, published by O'Reilly. + +Library usage: see the Timer class. + +Command line usage: + python timeit.py [-n N] [-r N] [-s S] [-p] [-h] [--] [statement] + +Options: + -n/--number N: how many times to execute 'statement' (default: see below) + -r/--repeat N: how many times to repeat the timer (default 5) + -s/--setup S: statement to be executed once initially (default 'pass'). + Execution time of this setup statement is NOT timed. + -p/--process: use time.process_time() (default is time.perf_counter()) + -v/--verbose: print raw timing results; repeat for more digits precision + -u/--unit: set the output time unit (nsec, usec, msec, or sec) + -h/--help: print this usage message and exit + --: separate options from statement, use when statement starts with - + statement: statement to be timed (default 'pass') + +A multi-line statement may be given by specifying each line as a +separate argument; indented lines are possible by enclosing an +argument in quotes and using leading spaces. Multiple -s options are +treated similarly. + +If -n is not given, a suitable number of loops is calculated by trying +successive powers of 10 until the total time is at least 0.2 seconds. + +Note: there is a certain baseline overhead associated with executing a +pass statement. It differs between versions. The code here doesn't try +to hide it, but you should be aware of it. The baseline overhead can be +measured by invoking the program without arguments. + +Classes: + + Timer + +Functions: + + timeit(string, string) -> float + repeat(string, string) -> list + default_timer() -> float +""" + +import gc +import sys +import time +import itertools + +__all__ = ["Timer", "timeit", "repeat", "default_timer"] + +dummy_src_name = "" +default_number = 1000000 +default_repeat = 5 +default_timer = time.perf_counter + +_globals = globals + +# Don't change the indentation of the template; the reindent() calls +# in Timer.__init__() depend on setup being indented 4 spaces and stmt +# being indented 8 spaces. +template = """ +def inner(_it, _timer{init}): + {setup} + _t0 = _timer() + for _i in _it: + {stmt} + {finish} + _t1 = _timer() + return _t1 - _t0 +""" + +def reindent(src, indent): + """Helper to reindent a multi-line statement.""" + return src.replace("\n", "\n" + " "*indent) + +class Timer: + """Class for timing execution speed of small code snippets. + + The constructor takes a statement to be timed, an additional + statement used for setup, and a timer function. Both statements + default to 'pass'; the timer function is platform-dependent (see + module doc string). If 'globals' is specified, the code will be + executed within that namespace (as opposed to inside timeit's + namespace). + + To measure the execution time of the first statement, use the + timeit() method. The repeat() method is a convenience to call + timeit() multiple times and return a list of results. + + The statements may contain newlines, as long as they don't contain + multi-line string literals. + """ + + def __init__(self, stmt="pass", setup="pass", finish='pass', timer=default_timer, + globals=None): + """Constructor. See class doc string.""" + self.timer = timer + local_ns = {} + global_ns = _globals() if globals is None else globals + init = '' + if isinstance(setup, str): + # Check that the code can be compiled outside a function + compile(setup, dummy_src_name, "exec") + stmtprefix = setup + '\n' + setup = reindent(setup, 4) + elif callable(setup): + local_ns['_setup'] = setup + init += ', _setup=_setup' + stmtprefix = '' + setup = '_setup()' + else: + raise ValueError("setup is neither a string nor callable") + if isinstance(stmt, str): + # Check that the code can be compiled outside a function + compile(stmtprefix + stmt, dummy_src_name, "exec") + stmt = reindent(stmt, 8) + elif callable(stmt): + local_ns['_stmt'] = stmt + init += ', _stmt=_stmt' + stmt = '_stmt()' + else: + raise ValueError("stmt is neither a string nor callable") + + assert isinstance(finish, str) + compile(setup + '\n' + stmt + '\n' + finish, dummy_src_name, 'exec') + finish = reindent(finish, 4) + + src = template.format(stmt=stmt, setup=setup, init=init, finish=finish) + self.src = src # Save for traceback display + code = compile(src, dummy_src_name, "exec") + exec(code, global_ns, local_ns) + self.inner = local_ns["inner"] + + def print_exc(self, file=None): + """Helper to print a traceback from the timed code. + + Typical use: + + t = Timer(...) # outside the try/except + try: + t.timeit(...) # or t.repeat(...) + except: + t.print_exc() + + The advantage over the standard traceback is that source lines + in the compiled template will be displayed. + + The optional file argument directs where the traceback is + sent; it defaults to sys.stderr. + """ + import linecache, traceback + if self.src is not None: + linecache.cache[dummy_src_name] = (len(self.src), + None, + self.src.split("\n"), + dummy_src_name) + # else the source is already stored somewhere else + + traceback.print_exc(file=file) + + def timeit(self, number=default_number): + """Time 'number' executions of the main statement. + + To be precise, this executes the setup statement once, and + then returns the time it takes to execute the main statement + a number of times, as a float measured in seconds. The + argument is the number of times through the loop, defaulting + to one million. The main statement, the setup statement and + the timer function to be used are passed to the constructor. + """ + it = itertools.repeat(None, number) + gcold = gc.isenabled() + gc.disable() + try: + timing = self.inner(it, self.timer) + finally: + if gcold: + gc.enable() + return timing + + def repeat(self, repeat=default_repeat, number=default_number): + """Call timeit() a few times. + + This is a convenience function that calls the timeit() + repeatedly, returning a list of results. The first argument + specifies how many times to call timeit(), defaulting to 5; + the second argument specifies the timer argument, defaulting + to one million. + + Note: it's tempting to calculate mean and standard deviation + from the result vector and report these. However, this is not + very useful. In a typical case, the lowest value gives a + lower bound for how fast your machine can run the given code + snippet; higher values in the result vector are typically not + caused by variability in Python's speed, but by other + processes interfering with your timing accuracy. So the min() + of the result is probably the only number you should be + interested in. After that, you should look at the entire + vector and apply common sense rather than statistics. + """ + r = [] + for i in range(repeat): + t = self.timeit(number) + r.append(t) + return r + + def autorange(self, callback=None): + """Return the number of loops and time taken so that total time >= 0.2. + + Calls the timeit method with increasing numbers from the sequence + 1, 2, 5, 10, 20, 50, ... until the time taken is at least 0.2 + second. Returns (number, time_taken). + + If *callback* is given and is not None, it will be called after + each trial with two arguments: ``callback(number, time_taken)``. + """ + i = 1 + while True: + for j in 1, 2, 5: + number = i * j + time_taken = self.timeit(number) + if callback: + callback(number, time_taken) + if time_taken >= 0.2: + return (number, time_taken) + i *= 10 + +def timeit(stmt="pass", setup="pass", finish='pass', timer=default_timer, + number=default_number, globals=None): + """Convenience function to create Timer object and call timeit method.""" + return Timer(stmt, setup, finish, timer, globals).timeit(number) + +def repeat(stmt="pass", setup="pass", finish='pass', timer=default_timer, + repeat=default_repeat, number=default_number, globals=None): + """Convenience function to create Timer object and call repeat method.""" + return Timer(stmt, setup, finish, timer, globals).repeat(repeat, number) + +def py_benchmark(stmt, context, min_repeat_second=1, setup='pass', finish='pass'): + total_time = 0 + number = 10 + + eval(stmt, context) # warmup + total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context) + while total_time < min_repeat_second: + number = int(number * (min_repeat_second / total_time)) + 1 + total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context) + + return total_time / number + diff --git a/tests/trigger_error.py b/tests/trigger_error.py new file mode 100644 index 0000000..821c4bd --- /dev/null +++ b/tests/trigger_error.py @@ -0,0 +1,26 @@ +"""Trigger the autograd error""" +import torch +from torch import nn, autograd + +class identity(autograd.Function): + @staticmethod + def forward(ctx, data): + # correct + #ctx.save_for_backward(data) + + # correct + #ctx.save_for_backward(data + 1) + + # RuntimeError: No grad accumulator for a saved leaf! + ctx.save_for_backward(data.view((1, -1))) + return data + + @staticmethod + def backward(ctx, data): + print(ctx.saved_tensors) + return data + + +a = torch.ones((10,)).requires_grad_() +b = identity.apply(a).sum() +b.backward()